summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/__init__.py15
-rw-r--r--numpy/_array_api/_linear_algebra_functions.py181
2 files changed, 31 insertions, 165 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py
index 56699b09d..e39a2c7d0 100644
--- a/numpy/_array_api/__init__.py
+++ b/numpy/_array_api/__init__.py
@@ -47,8 +47,8 @@ A few notes about the current state of this submodule:
- np.argmin and np.argmax do not implement the keepdims keyword argument.
- - Some linear algebra functions in the spec are still a work in progress (to
- be added soon). These will be updated once the spec is.
+ - The linear algebra extension in the spec will be added in a future pull
+request.
- Some tests in the test suite are still not fully correct in that they test
all datatypes whereas certain functions are only defined for a subset of
@@ -132,13 +132,14 @@ from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, at
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
-from ._linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose
+# einsum is not yet implemented in the array API spec.
-__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']
+# from ._linear_algebra_functions import einsum
+# __all__ += ['einsum']
-# from ._linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
-#
-# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
+from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
+
+__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot']
from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py
index 99a386866..461770641 100644
--- a/numpy/_array_api/_linear_algebra_functions.py
+++ b/numpy/_array_api/_linear_algebra_functions.py
@@ -1,70 +1,16 @@
from __future__ import annotations
from ._array_object import ndarray
+from ._dtypes import _numeric_dtypes
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from ._types import Literal, Optional, Tuple, Union, array
+ from ._types import Optional, Sequence, Tuple, Union, array
import numpy as np
-# def cholesky():
-# """
-# Array API compatible wrapper for :py:func:`np.cholesky <numpy.cholesky>`.
-#
-# See its docstring for more information.
-# """
-# return np.cholesky()
-
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
- """
- Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
+# einsum is not yet implemented in the array API spec.
- See its docstring for more information.
- """
- return ndarray._new(np.cross(x1._array, x2._array, axis=axis))
-
-def det(x: array, /) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
-
- See its docstring for more information.
- """
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.det(x._array))
-
-def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- """
- Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
-
- See its docstring for more information.
- """
- return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset))
-
-# def dot():
-# """
-# Array API compatible wrapper for :py:func:`np.dot <numpy.dot>`.
-#
-# See its docstring for more information.
-# """
-# return np.dot()
-#
-# def eig():
-# """
-# Array API compatible wrapper for :py:func:`np.eig <numpy.eig>`.
-#
-# See its docstring for more information.
-# """
-# return np.eig()
-#
-# def eigvalsh():
-# """
-# Array API compatible wrapper for :py:func:`np.eigvalsh <numpy.eigvalsh>`.
-#
-# See its docstring for more information.
-# """
-# return np.eigvalsh()
-#
# def einsum():
# """
# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
@@ -73,114 +19,27 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) ->
# """
# return np.einsum()
-def inv(x: array, /) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
-
- See its docstring for more information.
- """
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.inv(x._array))
-
-# def lstsq():
-# """
-# Array API compatible wrapper for :py:func:`np.lstsq <numpy.lstsq>`.
-#
-# See its docstring for more information.
-# """
-# return np.lstsq()
-#
-# def matmul():
-# """
-# Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
-#
-# See its docstring for more information.
-# """
-# return np.matmul()
-#
-# def matrix_power():
-# """
-# Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
-#
-# See its docstring for more information.
-# """
-# return np.matrix_power()
-#
-# def matrix_rank():
-# """
-# Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
-#
-# See its docstring for more information.
-# """
-# return np.matrix_rank()
-
-def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
-
- See its docstring for more information.
- """
- # Note: this is different from the default behavior
- if axis == None and x.ndim > 2:
- x = ndarray._new(x._array.flatten())
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord))
-
-def outer(x1: array, x2: array, /) -> array:
+def matmul(x1: array, x2: array, /) -> array:
"""
- Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
See its docstring for more information.
"""
- return ndarray._new(np.outer(x1._array, x2._array))
+ # Note: the restriction to numeric dtypes only is different from
+ # np.matmul.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in matmul')
-# def pinv():
-# """
-# Array API compatible wrapper for :py:func:`np.pinv <numpy.pinv>`.
-#
-# See its docstring for more information.
-# """
-# return np.pinv()
-#
-# def qr():
-# """
-# Array API compatible wrapper for :py:func:`np.qr <numpy.qr>`.
-#
-# See its docstring for more information.
-# """
-# return np.qr()
-#
-# def slogdet():
-# """
-# Array API compatible wrapper for :py:func:`np.slogdet <numpy.slogdet>`.
-#
-# See its docstring for more information.
-# """
-# return np.slogdet()
-#
-# def solve():
-# """
-# Array API compatible wrapper for :py:func:`np.solve <numpy.solve>`.
-#
-# See its docstring for more information.
-# """
-# return np.solve()
-#
-# def svd():
-# """
-# Array API compatible wrapper for :py:func:`np.svd <numpy.svd>`.
-#
-# See its docstring for more information.
-# """
-# return np.svd()
+ return ndarray._new(np.matmul(x1._array, x2._array))
-def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- """
- Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
+# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
+def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
+ # Note: the restriction to numeric dtypes only is different from
+ # np.tensordot.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in tensordot')
- See its docstring for more information.
- """
- return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset)))
+ return ndarray._new(np.tensordot(x1._array, x2._array, axes=axes))
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
"""
@@ -189,3 +48,9 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
See its docstring for more information.
"""
return ndarray._new(np.transpose(x._array, axes=axes))
+
+# Note: vecdot is not in NumPy
+def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
+ if axis is None:
+ axis = -1
+ return tensordot(x1, x2, axes=((axis,), (axis,)))