diff --git a/nnfwtbn/tests/test_toydata.py b/nnfwtbn/tests/test_toydata.py new file mode 100644 index 0000000000000000000000000000000000000000..a67fecbcdc581a5ac9f898a3bda147108d63e4d3 --- /dev/null +++ b/nnfwtbn/tests/test_toydata.py @@ -0,0 +1,53 @@ + +import unittest +import numpy as np +from nnfwtbn.toydata import rand + + +class ToyDataTestBase(unittest.TestCase): + + def test_rand_length(self): + """ + Check that the returned array has the requested length. + """ + self.assertEqual(len(rand()), 1) + self.assertEqual(len(rand(size=1)), 1) + self.assertEqual(len(rand(312)), 312) + + def test_rand_repeatable(self): + """ + Check that the method returns the same array twice when called with + the same seed. + """ + numbers_1 = list(rand(143)) + numbers_2 = list(rand(143)) + self.assertEqual(numbers_1, numbers_2) + + def test_rand_seed(self): + """ + Check that using different seeds returns different values. + """ + numbers_1 = list(rand(143, seed=1)) + numbers_2 = list(rand(143, seed=2)) + self.assertNotEqual(numbers_1, numbers_2) + + def test_rand_independent(self): + """ + Check that setting the numpy seed does not affect return value. + """ + np.random.seed(1234) + numbers_1 = list(rand(143)) + + np.random.seed(4321) + numbers_2 = list(rand(143)) + + self.assertEqual(numbers_1, numbers_2) + + def test_rand_values(self): + """ + Check for specific values. + """ + a, b, c = rand(3) + self.assertAlmostEqual(a, 0.90141859) + self.assertAlmostEqual(b, 0.85225178) + self.assertAlmostEqual(c, 0.93632300) diff --git a/nnfwtbn/toydata.py b/nnfwtbn/toydata.py new file mode 100644 index 0000000000000000000000000000000000000000..2dffb59d85c44e7d9331221dfecf8847e13d0da4 --- /dev/null +++ b/nnfwtbn/toydata.py @@ -0,0 +1,33 @@ +""" +This module implements method to generate a deterministic, physics-inspired +toy dataset. The dataset is intended for documentations and examples. The +module does not rely on external random number generators (seeding numpy +might break user code). +""" + +import numpy as np +from numba import njit + +@njit +def rand(size=1, seed=1991): + """ + Returns a numpy array with random floats between 0 and 1. Calling the + method twice (with the same seed) returns the same array. + + The parameter n defines the length of the returned array. + """ + # This method uses a Linear congruential generator with the same + # parameters used in Virtual Pascal. + # https://en.wikipedia.org/wiki/Linear_congruential_generator + m = 2**32 + a = 134775813 + c = 1 + + output = np.empty(size) + seed = (a * seed + c) % m # make the output look more random + output[0] = (a * seed + c) % m + + for i in range(1, size): + output[i] = (a * output[i - 1] + c) % m + + return output / m diff --git a/requirements.txt b/requirements.txt index c0077b061bcd7461817d8383d03c0a421648dc56..fcadea7ecc2055e4ccb7a66d2be3136986f6a9df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy matplotlib seaborn pandas +numba