diff options
author | Roy Smart <roytsmart@gmail.com> | 2023-02-24 17:16:32 -0700 |
---|---|---|
committer | Roy Smart <roytsmart@gmail.com> | 2023-03-06 02:43:57 -0700 |
commit | 3b89f5227e7d8f8eddd2959982efb6739efdb729 (patch) | |
tree | ed8fcef1a659acb52d791fcc82dc12b737c72def /numpy/core/tests | |
parent | b4f6b0b9d2af3e5a7ed0bfb51c8d692298bc2dde (diff) | |
download | numpy-3b89f5227e7d8f8eddd2959982efb6739efdb729.tar.gz |
ENH: Modify `numpy.logspace` so that the `base` argument broadcasts correctly against `start` and `stop`.
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_function_base.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py index 21583dd44..79f1ecfc9 100644 --- a/numpy/core/tests/test_function_base.py +++ b/numpy/core/tests/test_function_base.py @@ -1,3 +1,4 @@ +import pytest from numpy import ( logspace, linspace, geomspace, dtype, array, sctypes, arange, isnan, ndarray, sqrt, nextafter, stack, errstate @@ -65,6 +66,33 @@ class TestLogspace: t5 = logspace(start, stop, 6, axis=-1) assert_equal(t5, t2.T) + @pytest.mark.parametrize("axis", [0, 1, -1]) + def test_base_array(self, axis: int): + start = 1 + stop = 2 + num = 6 + base = array([1, 2]) + t1 = logspace(start, stop, num=num, base=base, axis=axis) + t2 = stack( + [logspace(start, stop, num=num, base=_base) for _base in base], + axis=(axis + 1) % t1.ndim, + ) + assert_equal(t1, t2) + + @pytest.mark.parametrize("axis", [0, 1, -1]) + def test_stop_base_array(self, axis: int): + start = 1 + stop = array([2, 3]) + num = 6 + base = array([1, 2]) + t1 = logspace(start, stop, num=num, base=base, axis=axis) + t2 = stack( + [logspace(start, _stop, num=num, base=_base) + for _stop, _base in zip(stop, base)], + axis=(axis + 1) % t1.ndim, + ) + assert_equal(t1, t2) + def test_dtype(self): y = logspace(0, 6, dtype='float32') assert_equal(y.dtype, dtype('float32')) |