diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2022-03-28 18:12:42 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2022-03-28 18:12:42 -0600 |
commit | f375d71ca101db9541b1e70476999b574634556d (patch) | |
tree | 43ee09ff5c2fffcb02437b3fa9d5cd02ad62ed40 /numpy/array_api/linalg.py | |
parent | f306e941ffb0bc49604f5507aa8ea614d933cfd0 (diff) | |
download | numpy-f375d71ca101db9541b1e70476999b574634556d.tar.gz |
Properly restrict the input dtypes for the array_api trace, svdvals, and vecdot
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r-- | numpy/array_api/linalg.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index cb04113b6..8398147ba 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) |