diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2022-02-18 08:07:25 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-18 20:07:25 +0400 |
commit | 88158b66e396b59d2d382600806e3d28c7bf5509 (patch) | |
tree | 9c432ff7759ab060dc61917959ec0c01523ae42a /networkx/utils | |
parent | 1adc82c0041b520bfe840f43a0dede1373eb697e (diff) | |
download | networkx-88158b66e396b59d2d382600806e3d28c7bf5509.tar.gz |
Add support for `numpy.random.Generator` (#5336)
* Add numpy.random.Generator to create_random_state.
* Modify PythonRandomInterface to support Generator + testing.
* Fix failing test.
Diffstat (limited to 'networkx/utils')
-rw-r--r-- | networkx/utils/misc.py | 29 | ||||
-rw-r--r-- | networkx/utils/tests/test_misc.py | 35 |
2 files changed, 54 insertions, 10 deletions
diff --git a/networkx/utils/misc.py b/networkx/utils/misc.py index 33e8e5d0..1a5d52c9 100644 --- a/networkx/utils/misc.py +++ b/networkx/utils/misc.py @@ -419,13 +419,15 @@ def to_tuple(x): def create_random_state(random_state=None): - """Returns a numpy.random.RandomState instance depending on input. + """Returns a numpy.random.RandomState or numpy.random.Generator instance + depending on input. Parameters ---------- - random_state : int or RandomState instance or None optional (default=None) + random_state : int or NumPy RandomState or Generator instance, optional (default=None) If int, return a numpy.random.RandomState instance set with seed=int. - if numpy.random.RandomState instance, return it. + if `numpy.random.RandomState` instance, return it. + if `numpy.random.Generator` instance, return it. if None or numpy.random, return the global random number generator used by numpy.random. """ @@ -435,8 +437,11 @@ def create_random_state(random_state=None): return random_state if isinstance(random_state, int): return np.random.RandomState(random_state) + if isinstance(random_state, np.random.Generator): + return random_state msg = ( - f"{random_state} cannot be used to generate a numpy.random.RandomState instance" + f"{random_state} cannot be used to create a numpy.random.RandomState or\n" + "numpy.random.Generator instance" ) raise ValueError(msg) @@ -455,16 +460,24 @@ class PythonRandomInterface: self._rng = rng def random(self): - return self._rng.random_sample() + return self._rng.random() def uniform(self, a, b): - return a + (b - a) * self._rng.random_sample() + return a + (b - a) * self._rng.random() def randrange(self, a, b=None): + if isinstance(self._rng, np.random.Generator): + return self._rng.integers(a, b) return self._rng.randint(a, b) + # NOTE: the numpy implementations of `choice` don't support strings, so + # this cannot be replaced with self._rng.choice def choice(self, seq): - return seq[self._rng.randint(0, len(seq))] + if isinstance(self._rng, np.random.Generator): + idx = self._rng.integers(0, len(seq)) + else: + idx = self._rng.randint(0, len(seq)) + return seq[idx] def gauss(self, mu, sigma): return self._rng.normal(mu, sigma) @@ -479,6 +492,8 @@ class PythonRandomInterface: return self._rng.choice(list(seq), size=(k,), replace=False) def randint(self, a, b): + if isinstance(self._rng, np.random.Generator): + return self._rng.integers(a, b + 1) return self._rng.randint(a, b + 1) # exponential as expovariate with 1/argument, diff --git a/networkx/utils/tests/test_misc.py b/networkx/utils/tests/test_misc.py index f8b975ca..b38645e4 100644 --- a/networkx/utils/tests/test_misc.py +++ b/networkx/utils/tests/test_misc.py @@ -220,6 +220,9 @@ def test_create_random_state(): assert isinstance(create_random_state(None), rs) assert isinstance(create_random_state(np.random), rs) assert isinstance(create_random_state(rs(1)), rs) + # Support for numpy.random.Generator + rng = np.random.default_rng() + assert isinstance(create_random_state(rng), np.random.Generator) pytest.raises(ValueError, create_random_state, "a") assert np.all(rs(1).rand(10) == create_random_state(1).rand(10)) @@ -243,25 +246,51 @@ def test_create_py_random_state(): assert isinstance(PythonRandomInterface(), nprs) -def test_PythonRandomInterface(): +def test_PythonRandomInterface_RandomState(): np = pytest.importorskip("numpy") + rs = np.random.RandomState rng = PythonRandomInterface(rs(42)) rs42 = rs(42) # make sure these functions are same as expected outcome assert rng.randrange(3, 5) == rs42.randint(3, 5) - assert np.all(rng.choice([1, 2, 3]) == rs42.choice([1, 2, 3])) + assert rng.choice([1, 2, 3]) == rs42.choice([1, 2, 3]) assert rng.gauss(0, 1) == rs42.normal(0, 1) assert rng.expovariate(1.5) == rs42.exponential(1 / 1.5) assert np.all(rng.shuffle([1, 2, 3]) == rs42.shuffle([1, 2, 3])) assert np.all( rng.sample([1, 2, 3], 2) == rs42.choice([1, 2, 3], (2,), replace=False) ) - assert rng.randint(3, 5) == rs42.randint(3, 6) + assert np.all( + [rng.randint(3, 5) for _ in range(100)] + == [rs42.randint(3, 6) for _ in range(100)] + ) assert rng.random() == rs42.random_sample() +def test_PythonRandomInterface_Generator(): + np = pytest.importorskip("numpy") + + rng = np.random.default_rng(42) + pri = PythonRandomInterface(np.random.default_rng(42)) + + # make sure these functions are same as expected outcome + assert pri.randrange(3, 5) == rng.integers(3, 5) + assert pri.choice([1, 2, 3]) == rng.choice([1, 2, 3]) + assert pri.gauss(0, 1) == rng.normal(0, 1) + assert pri.expovariate(1.5) == rng.exponential(1 / 1.5) + assert np.all(pri.shuffle([1, 2, 3]) == rng.shuffle([1, 2, 3])) + assert np.all( + pri.sample([1, 2, 3], 2) == rng.choice([1, 2, 3], (2,), replace=False) + ) + assert np.all( + [pri.randint(3, 5) for _ in range(100)] + == [rng.integers(3, 6) for _ in range(100)] + ) + assert pri.random() == rng.random() + + @pytest.mark.parametrize( ("iterable_type", "expected"), ((list, 1), (tuple, 1), (str, "["), (set, 1)) ) |