summaryrefslogtreecommitdiff
path: root/networkx/utils
diff options
context:
space:
mode:
authorRoss Barnowski <rossbar@berkeley.edu>2022-02-18 08:07:25 -0800
committerGitHub <noreply@github.com>2022-02-18 20:07:25 +0400
commit88158b66e396b59d2d382600806e3d28c7bf5509 (patch)
tree9c432ff7759ab060dc61917959ec0c01523ae42a /networkx/utils
parent1adc82c0041b520bfe840f43a0dede1373eb697e (diff)
downloadnetworkx-88158b66e396b59d2d382600806e3d28c7bf5509.tar.gz
Add support for `numpy.random.Generator` (#5336)
* Add numpy.random.Generator to create_random_state. * Modify PythonRandomInterface to support Generator + testing. * Fix failing test.
Diffstat (limited to 'networkx/utils')
-rw-r--r--networkx/utils/misc.py29
-rw-r--r--networkx/utils/tests/test_misc.py35
2 files changed, 54 insertions, 10 deletions
diff --git a/networkx/utils/misc.py b/networkx/utils/misc.py
index 33e8e5d0..1a5d52c9 100644
--- a/networkx/utils/misc.py
+++ b/networkx/utils/misc.py
@@ -419,13 +419,15 @@ def to_tuple(x):
def create_random_state(random_state=None):
- """Returns a numpy.random.RandomState instance depending on input.
+ """Returns a numpy.random.RandomState or numpy.random.Generator instance
+ depending on input.
Parameters
----------
- random_state : int or RandomState instance or None optional (default=None)
+ random_state : int or NumPy RandomState or Generator instance, optional (default=None)
If int, return a numpy.random.RandomState instance set with seed=int.
- if numpy.random.RandomState instance, return it.
+ if `numpy.random.RandomState` instance, return it.
+ if `numpy.random.Generator` instance, return it.
if None or numpy.random, return the global random number generator used
by numpy.random.
"""
@@ -435,8 +437,11 @@ def create_random_state(random_state=None):
return random_state
if isinstance(random_state, int):
return np.random.RandomState(random_state)
+ if isinstance(random_state, np.random.Generator):
+ return random_state
msg = (
- f"{random_state} cannot be used to generate a numpy.random.RandomState instance"
+ f"{random_state} cannot be used to create a numpy.random.RandomState or\n"
+ "numpy.random.Generator instance"
)
raise ValueError(msg)
@@ -455,16 +460,24 @@ class PythonRandomInterface:
self._rng = rng
def random(self):
- return self._rng.random_sample()
+ return self._rng.random()
def uniform(self, a, b):
- return a + (b - a) * self._rng.random_sample()
+ return a + (b - a) * self._rng.random()
def randrange(self, a, b=None):
+ if isinstance(self._rng, np.random.Generator):
+ return self._rng.integers(a, b)
return self._rng.randint(a, b)
+ # NOTE: the numpy implementations of `choice` don't support strings, so
+ # this cannot be replaced with self._rng.choice
def choice(self, seq):
- return seq[self._rng.randint(0, len(seq))]
+ if isinstance(self._rng, np.random.Generator):
+ idx = self._rng.integers(0, len(seq))
+ else:
+ idx = self._rng.randint(0, len(seq))
+ return seq[idx]
def gauss(self, mu, sigma):
return self._rng.normal(mu, sigma)
@@ -479,6 +492,8 @@ class PythonRandomInterface:
return self._rng.choice(list(seq), size=(k,), replace=False)
def randint(self, a, b):
+ if isinstance(self._rng, np.random.Generator):
+ return self._rng.integers(a, b + 1)
return self._rng.randint(a, b + 1)
# exponential as expovariate with 1/argument,
diff --git a/networkx/utils/tests/test_misc.py b/networkx/utils/tests/test_misc.py
index f8b975ca..b38645e4 100644
--- a/networkx/utils/tests/test_misc.py
+++ b/networkx/utils/tests/test_misc.py
@@ -220,6 +220,9 @@ def test_create_random_state():
assert isinstance(create_random_state(None), rs)
assert isinstance(create_random_state(np.random), rs)
assert isinstance(create_random_state(rs(1)), rs)
+ # Support for numpy.random.Generator
+ rng = np.random.default_rng()
+ assert isinstance(create_random_state(rng), np.random.Generator)
pytest.raises(ValueError, create_random_state, "a")
assert np.all(rs(1).rand(10) == create_random_state(1).rand(10))
@@ -243,25 +246,51 @@ def test_create_py_random_state():
assert isinstance(PythonRandomInterface(), nprs)
-def test_PythonRandomInterface():
+def test_PythonRandomInterface_RandomState():
np = pytest.importorskip("numpy")
+
rs = np.random.RandomState
rng = PythonRandomInterface(rs(42))
rs42 = rs(42)
# make sure these functions are same as expected outcome
assert rng.randrange(3, 5) == rs42.randint(3, 5)
- assert np.all(rng.choice([1, 2, 3]) == rs42.choice([1, 2, 3]))
+ assert rng.choice([1, 2, 3]) == rs42.choice([1, 2, 3])
assert rng.gauss(0, 1) == rs42.normal(0, 1)
assert rng.expovariate(1.5) == rs42.exponential(1 / 1.5)
assert np.all(rng.shuffle([1, 2, 3]) == rs42.shuffle([1, 2, 3]))
assert np.all(
rng.sample([1, 2, 3], 2) == rs42.choice([1, 2, 3], (2,), replace=False)
)
- assert rng.randint(3, 5) == rs42.randint(3, 6)
+ assert np.all(
+ [rng.randint(3, 5) for _ in range(100)]
+ == [rs42.randint(3, 6) for _ in range(100)]
+ )
assert rng.random() == rs42.random_sample()
+def test_PythonRandomInterface_Generator():
+ np = pytest.importorskip("numpy")
+
+ rng = np.random.default_rng(42)
+ pri = PythonRandomInterface(np.random.default_rng(42))
+
+ # make sure these functions are same as expected outcome
+ assert pri.randrange(3, 5) == rng.integers(3, 5)
+ assert pri.choice([1, 2, 3]) == rng.choice([1, 2, 3])
+ assert pri.gauss(0, 1) == rng.normal(0, 1)
+ assert pri.expovariate(1.5) == rng.exponential(1 / 1.5)
+ assert np.all(pri.shuffle([1, 2, 3]) == rng.shuffle([1, 2, 3]))
+ assert np.all(
+ pri.sample([1, 2, 3], 2) == rng.choice([1, 2, 3], (2,), replace=False)
+ )
+ assert np.all(
+ [pri.randint(3, 5) for _ in range(100)]
+ == [rng.integers(3, 6) for _ in range(100)]
+ )
+ assert pri.random() == rng.random()
+
+
@pytest.mark.parametrize(
("iterable_type", "expected"), ((list, 1), (tuple, 1), (str, "["), (set, 1))
)