summaryrefslogtreecommitdiff
path: root/numpy/array_api/linalg.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2022-03-28 18:12:42 -0600
committerAaron Meurer <asmeurer@gmail.com>2022-03-28 18:12:42 -0600
commitf375d71ca101db9541b1e70476999b574634556d (patch)
tree43ee09ff5c2fffcb02437b3fa9d5cd02ad62ed40 /numpy/array_api/linalg.py
parentf306e941ffb0bc49604f5507aa8ea614d933cfd0 (diff)
downloadnumpy-f375d71ca101db9541b1e70476999b574634556d.tar.gz
Properly restrict the input dtypes for the array_api trace, svdvals, and vecdot
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r--numpy/array_api/linalg.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
index cb04113b6..8398147ba 100644
--- a/numpy/array_api/linalg.py
+++ b/numpy/array_api/linalg.py
@@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svdvals')
return Array._new(np.linalg.svd(x._array, compute_uv=False))
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trace')
# Note: trace always operates on the last two axes, whereas np.trace
# operates on the first two axes by default
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))