summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-02-10 13:34:45 +0100
committerSebastian Berg <sebastianb@nvidia.com>2023-02-14 20:19:28 +0100
commitc1fa0d981258063e00e8976c41d34a6b94b12516 (patch)
tree6ce24a148c4cef4863f6962b03372fb0d103246d /numpy/random
parent482d3fadbfe23a1d2a2bb179e90369e0d08b11be (diff)
downloadnumpy-c1fa0d981258063e00e8976c41d34a6b94b12516.tar.gz
API: Add `rng.spawn()`, `bit_gen.spawn()`, and `bit_gen.seed_seq`
This makes the seed sequence interface more public facing by: 1. Adding `BitGenerator.seed_seq` to give clear access to `_seed_seq` 2. Add `spawn()` to both the generator and the bit generator as convenience methods for spawning new instances. I somewhat remember that we always meant to consider making this more public and adding such convenient methods, but did not do so originally. So, now, I do wonder whether it is time to make this fully public? It would be nice to follow up at some point with a bit of best practices. This also doesn't add it to the `RandomState`, although doing it via `RandomState._bit_generator` is of course valid. Can we define as this kind of access as stable enough that downstream libraries could use it? I fear that backcompat with `RandomState` might make adopting newer things like spawning hard for libraries?
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/_generator.pyi1
-rw-r--r--numpy/random/_generator.pyx19
-rw-r--r--numpy/random/bit_generator.pyi3
-rw-r--r--numpy/random/bit_generator.pyx38
-rw-r--r--numpy/random/tests/test_direct.py40
5 files changed, 101 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyi b/numpy/random/_generator.pyi
index f0d814fef..23c04e472 100644
--- a/numpy/random/_generator.pyi
+++ b/numpy/random/_generator.pyi
@@ -72,6 +72,7 @@ class Generator:
def __reduce__(self) -> tuple[Callable[[str], Generator], tuple[str], dict[str, Any]]: ...
@property
def bit_generator(self) -> BitGenerator: ...
+ def spawn(self, n_children: int) -> list[Generator]: ...
def bytes(self, length: int) -> bytes: ...
@overload
def standard_normal( # type: ignore[misc]
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 83a4b2ad5..854a25af8 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -238,6 +238,25 @@ cdef class Generator:
"""
return self._bit_generator
+ def spawn(self, int n_children):
+ """
+ Create new independent child generators.
+
+ This is a convenience method to safely spawn new random number
+ generators via the `numpy.random.SeedSequence.spawn` mechanism.
+ The original seed sequence is used by the bit generator and accessible
+ as ``Generator.bit_generator.seed_seq``.
+
+ Please see `numpy.random.SeedSequence` for additional notes on
+ spawning children.
+
+ Returns
+ -------
+ child_generators : list of Generators
+
+ """
+ return [type(self)(g) for g in self._bit_generator.spawn(n_children)]
+
def random(self, size=None, dtype=np.float64, out=None):
"""
random(size=None, dtype=np.float64, out=None)
diff --git a/numpy/random/bit_generator.pyi b/numpy/random/bit_generator.pyi
index e6e3b10cd..8b9779cad 100644
--- a/numpy/random/bit_generator.pyi
+++ b/numpy/random/bit_generator.pyi
@@ -96,6 +96,9 @@ class BitGenerator(abc.ABC):
def state(self) -> Mapping[str, Any]: ...
@state.setter
def state(self, value: Mapping[str, Any]) -> None: ...
+ @property
+ def seed_seq(self) -> ISeedSequence: ...
+ def spawn(self, n_children: int) -> list[BitGenerator]: ...
@overload
def random_raw(self, size: None = ..., output: Literal[True] = ...) -> int: ... # type: ignore[misc]
@overload
diff --git a/numpy/random/bit_generator.pyx b/numpy/random/bit_generator.pyx
index 47804c487..f96dbe3c9 100644
--- a/numpy/random/bit_generator.pyx
+++ b/numpy/random/bit_generator.pyx
@@ -551,6 +551,44 @@ cdef class BitGenerator():
def state(self, value):
raise NotImplementedError('Not implemented in base BitGenerator')
+ @property
+ def seed_seq(self):
+ """
+ Get the seed sequence used to initialize the bit generator.
+
+ Returns
+ -------
+ seed_seq : ISeedSequence
+ The SeedSequence object used to initialize the BitGenerator.
+ This is normally a `np.random.SeedSequence` instance.
+
+ """
+ return self._seed_seq
+
+ def spawn(self, int n_children):
+ """
+ Create new independent child bit generators.
+
+ This is a convenience method to safely spawn new random number
+ generators via the `numpy.random.SeedSequence.spawn` mechanism.
+ The original seed sequence is accessible as ``bit_generator.seed_seq``.
+
+ Please see `numpy.random.SeedSequence` for additional notes on
+ spawning children.
+
+ Returns
+ -------
+ child_bit_generators : list of BitGenerators
+
+ """
+ if not isinstance(self._seed_seq, ISpawnableSeedSequence):
+ raise TypeError(
+ "The underlying SeedSequence does not implement spawning. "
+ "You must ensure a custom SeedSequence used for initializing "
+ "the random state implements spawning (and registers it).")
+
+ return [type(self)(seed=s) for s in self._seed_seq.spawn(n_children)]
+
def random_raw(self, size=None, output=True):
"""
random_raw(self, size=None)
diff --git a/numpy/random/tests/test_direct.py b/numpy/random/tests/test_direct.py
index 58d966adf..fa2ae866b 100644
--- a/numpy/random/tests/test_direct.py
+++ b/numpy/random/tests/test_direct.py
@@ -148,6 +148,46 @@ def test_seedsequence():
assert len(dummy.spawn(10)) == 10
+def test_generator_spawning():
+ """ Test spawning new generators and bit_generators directly.
+ """
+ rng = np.random.default_rng()
+ seq = rng.bit_generator.seed_seq
+ new_ss = seq.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(5)]
+ assert [c.spawn_key for c in new_ss] == expected_keys
+
+ new_bgs = rng.bit_generator.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(5, 10)]
+ assert [bg.seed_seq.spawn_key for bg in new_bgs] == expected_keys
+
+ new_rngs = rng.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(10, 15)]
+ found_keys = [rng.bit_generator.seed_seq.spawn_key for rng in new_rngs]
+ assert found_keys == expected_keys
+
+ # Sanity check that streams are actually different:
+ assert new_rngs[0].uniform() != new_rngs[1].uniform()
+
+
+def test_non_spawnable():
+ from numpy.random.bit_generator import ISeedSequence
+
+ class FakeSeedSequence:
+ def generate_state(self, n_words, dtype=np.uint32):
+ return np.zeros(n_words, dtype=dtype)
+
+ ISeedSequence.register(FakeSeedSequence)
+
+ rng = np.random.default_rng(FakeSeedSequence())
+
+ with pytest.raises(TypeError, match="The underlying SeedSequence"):
+ rng.spawn(5)
+
+ with pytest.raises(TypeError, match="The underlying SeedSequence"):
+ rng.bit_generator.spawn(5)
+
+
class Base:
dtype = np.uint64
data2 = data1 = {}