diff options
| author | Sebastian Berg <sebastianb@nvidia.com> | 2023-02-10 13:34:45 +0100 |
|---|---|---|
| committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-02-14 20:19:28 +0100 |
| commit | c1fa0d981258063e00e8976c41d34a6b94b12516 (patch) | |
| tree | 6ce24a148c4cef4863f6962b03372fb0d103246d /numpy/random | |
| parent | 482d3fadbfe23a1d2a2bb179e90369e0d08b11be (diff) | |
| download | numpy-c1fa0d981258063e00e8976c41d34a6b94b12516.tar.gz | |
API: Add `rng.spawn()`, `bit_gen.spawn()`, and `bit_gen.seed_seq`
This makes the seed sequence interface more public facing by:
1. Adding `BitGenerator.seed_seq` to give clear access to `_seed_seq`
2. Add `spawn()` to both the generator and the bit generator as
convenience methods for spawning new instances.
I somewhat remember that we always meant to consider making this
more public and adding such convenient methods, but did not do
so originally.
So, now, I do wonder whether it is time to make this fully public?
It would be nice to follow up at some point with a bit of best practices.
This also doesn't add it to the `RandomState`, although doing it via
`RandomState._bit_generator` is of course valid.
Can we define as this kind of access as stable enough that downstream
libraries could use it? I fear that backcompat with `RandomState`
might make adopting newer things like spawning hard for libraries?
Diffstat (limited to 'numpy/random')
| -rw-r--r-- | numpy/random/_generator.pyi | 1 | ||||
| -rw-r--r-- | numpy/random/_generator.pyx | 19 | ||||
| -rw-r--r-- | numpy/random/bit_generator.pyi | 3 | ||||
| -rw-r--r-- | numpy/random/bit_generator.pyx | 38 | ||||
| -rw-r--r-- | numpy/random/tests/test_direct.py | 40 |
5 files changed, 101 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyi b/numpy/random/_generator.pyi index f0d814fef..23c04e472 100644 --- a/numpy/random/_generator.pyi +++ b/numpy/random/_generator.pyi @@ -72,6 +72,7 @@ class Generator: def __reduce__(self) -> tuple[Callable[[str], Generator], tuple[str], dict[str, Any]]: ... @property def bit_generator(self) -> BitGenerator: ... + def spawn(self, n_children: int) -> list[Generator]: ... def bytes(self, length: int) -> bytes: ... @overload def standard_normal( # type: ignore[misc] diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index 83a4b2ad5..854a25af8 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -238,6 +238,25 @@ cdef class Generator: """ return self._bit_generator + def spawn(self, int n_children): + """ + Create new independent child generators. + + This is a convenience method to safely spawn new random number + generators via the `numpy.random.SeedSequence.spawn` mechanism. + The original seed sequence is used by the bit generator and accessible + as ``Generator.bit_generator.seed_seq``. + + Please see `numpy.random.SeedSequence` for additional notes on + spawning children. + + Returns + ------- + child_generators : list of Generators + + """ + return [type(self)(g) for g in self._bit_generator.spawn(n_children)] + def random(self, size=None, dtype=np.float64, out=None): """ random(size=None, dtype=np.float64, out=None) diff --git a/numpy/random/bit_generator.pyi b/numpy/random/bit_generator.pyi index e6e3b10cd..8b9779cad 100644 --- a/numpy/random/bit_generator.pyi +++ b/numpy/random/bit_generator.pyi @@ -96,6 +96,9 @@ class BitGenerator(abc.ABC): def state(self) -> Mapping[str, Any]: ... @state.setter def state(self, value: Mapping[str, Any]) -> None: ... + @property + def seed_seq(self) -> ISeedSequence: ... + def spawn(self, n_children: int) -> list[BitGenerator]: ... @overload def random_raw(self, size: None = ..., output: Literal[True] = ...) -> int: ... # type: ignore[misc] @overload diff --git a/numpy/random/bit_generator.pyx b/numpy/random/bit_generator.pyx index 47804c487..f96dbe3c9 100644 --- a/numpy/random/bit_generator.pyx +++ b/numpy/random/bit_generator.pyx @@ -551,6 +551,44 @@ cdef class BitGenerator(): def state(self, value): raise NotImplementedError('Not implemented in base BitGenerator') + @property + def seed_seq(self): + """ + Get the seed sequence used to initialize the bit generator. + + Returns + ------- + seed_seq : ISeedSequence + The SeedSequence object used to initialize the BitGenerator. + This is normally a `np.random.SeedSequence` instance. + + """ + return self._seed_seq + + def spawn(self, int n_children): + """ + Create new independent child bit generators. + + This is a convenience method to safely spawn new random number + generators via the `numpy.random.SeedSequence.spawn` mechanism. + The original seed sequence is accessible as ``bit_generator.seed_seq``. + + Please see `numpy.random.SeedSequence` for additional notes on + spawning children. + + Returns + ------- + child_bit_generators : list of BitGenerators + + """ + if not isinstance(self._seed_seq, ISpawnableSeedSequence): + raise TypeError( + "The underlying SeedSequence does not implement spawning. " + "You must ensure a custom SeedSequence used for initializing " + "the random state implements spawning (and registers it).") + + return [type(self)(seed=s) for s in self._seed_seq.spawn(n_children)] + def random_raw(self, size=None, output=True): """ random_raw(self, size=None) diff --git a/numpy/random/tests/test_direct.py b/numpy/random/tests/test_direct.py index 58d966adf..fa2ae866b 100644 --- a/numpy/random/tests/test_direct.py +++ b/numpy/random/tests/test_direct.py @@ -148,6 +148,46 @@ def test_seedsequence(): assert len(dummy.spawn(10)) == 10 +def test_generator_spawning(): + """ Test spawning new generators and bit_generators directly. + """ + rng = np.random.default_rng() + seq = rng.bit_generator.seed_seq + new_ss = seq.spawn(5) + expected_keys = [seq.spawn_key + (i,) for i in range(5)] + assert [c.spawn_key for c in new_ss] == expected_keys + + new_bgs = rng.bit_generator.spawn(5) + expected_keys = [seq.spawn_key + (i,) for i in range(5, 10)] + assert [bg.seed_seq.spawn_key for bg in new_bgs] == expected_keys + + new_rngs = rng.spawn(5) + expected_keys = [seq.spawn_key + (i,) for i in range(10, 15)] + found_keys = [rng.bit_generator.seed_seq.spawn_key for rng in new_rngs] + assert found_keys == expected_keys + + # Sanity check that streams are actually different: + assert new_rngs[0].uniform() != new_rngs[1].uniform() + + +def test_non_spawnable(): + from numpy.random.bit_generator import ISeedSequence + + class FakeSeedSequence: + def generate_state(self, n_words, dtype=np.uint32): + return np.zeros(n_words, dtype=dtype) + + ISeedSequence.register(FakeSeedSequence) + + rng = np.random.default_rng(FakeSeedSequence()) + + with pytest.raises(TypeError, match="The underlying SeedSequence"): + rng.spawn(5) + + with pytest.raises(TypeError, match="The underlying SeedSequence"): + rng.bit_generator.spawn(5) + + class Base: dtype = np.uint64 data2 = data1 = {} |
