summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorRoy Smart <roytsmart@gmail.com>2023-02-24 17:16:32 -0700
committerRoy Smart <roytsmart@gmail.com>2023-03-06 02:43:57 -0700
commit3b89f5227e7d8f8eddd2959982efb6739efdb729 (patch)
treeed8fcef1a659acb52d791fcc82dc12b737c72def /numpy/core/tests
parentb4f6b0b9d2af3e5a7ed0bfb51c8d692298bc2dde (diff)
downloadnumpy-3b89f5227e7d8f8eddd2959982efb6739efdb729.tar.gz
ENH: Modify `numpy.logspace` so that the `base` argument broadcasts correctly against `start` and `stop`.
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_function_base.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py
index 21583dd44..79f1ecfc9 100644
--- a/numpy/core/tests/test_function_base.py
+++ b/numpy/core/tests/test_function_base.py
@@ -1,3 +1,4 @@
+import pytest
from numpy import (
logspace, linspace, geomspace, dtype, array, sctypes, arange, isnan,
ndarray, sqrt, nextafter, stack, errstate
@@ -65,6 +66,33 @@ class TestLogspace:
t5 = logspace(start, stop, 6, axis=-1)
assert_equal(t5, t2.T)
+ @pytest.mark.parametrize("axis", [0, 1, -1])
+ def test_base_array(self, axis: int):
+ start = 1
+ stop = 2
+ num = 6
+ base = array([1, 2])
+ t1 = logspace(start, stop, num=num, base=base, axis=axis)
+ t2 = stack(
+ [logspace(start, stop, num=num, base=_base) for _base in base],
+ axis=(axis + 1) % t1.ndim,
+ )
+ assert_equal(t1, t2)
+
+ @pytest.mark.parametrize("axis", [0, 1, -1])
+ def test_stop_base_array(self, axis: int):
+ start = 1
+ stop = array([2, 3])
+ num = 6
+ base = array([1, 2])
+ t1 = logspace(start, stop, num=num, base=base, axis=axis)
+ t2 = stack(
+ [logspace(start, _stop, num=num, base=_base)
+ for _stop, _base in zip(stop, base)],
+ axis=(axis + 1) % t1.ndim,
+ )
+ assert_equal(t1, t2)
+
def test_dtype(self):
y = logspace(0, 6, dtype='float32')
assert_equal(y.dtype, dtype('float32'))