summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-01 15:48:19 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-01 15:48:19 -0600
commitcad21e94b58b125a4264f154e91a1730dcf550da (patch)
treeae6e402e0e8d60a03a493b4d7d11d2417041c13a /numpy/_array_api
parentf6015d2754dde04342ca2a0d719ca7f01d6e0dcb (diff)
downloadnumpy-cad21e94b58b125a4264f154e91a1730dcf550da.tar.gz
Update the linear algebra functions in the array API namespace
For now, only the functions in from the main spec namespace are implemented. The remaining linear algebra functions are part of an extension in the spec, and will be implemented in a future pull request. This is because the linear algebra functions are relatively complicated, so they will be easier to review separately. This also updates those functions that do remain for now to be more compliant with the spec.
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/__init__.py15
-rw-r--r--numpy/_array_api/_linear_algebra_functions.py181
2 files changed, 31 insertions, 165 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py
index 56699b09d..e39a2c7d0 100644
--- a/numpy/_array_api/__init__.py
+++ b/numpy/_array_api/__init__.py
@@ -47,8 +47,8 @@ A few notes about the current state of this submodule:
- np.argmin and np.argmax do not implement the keepdims keyword argument.
- - Some linear algebra functions in the spec are still a work in progress (to
- be added soon). These will be updated once the spec is.
+ - The linear algebra extension in the spec will be added in a future pull
+request.
- Some tests in the test suite are still not fully correct in that they test
all datatypes whereas certain functions are only defined for a subset of
@@ -132,13 +132,14 @@ from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, at
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
-from ._linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose
+# einsum is not yet implemented in the array API spec.
-__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']
+# from ._linear_algebra_functions import einsum
+# __all__ += ['einsum']
-# from ._linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
-#
-# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
+from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
+
+__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot']
from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py
index 99a386866..461770641 100644
--- a/numpy/_array_api/_linear_algebra_functions.py
+++ b/numpy/_array_api/_linear_algebra_functions.py
@@ -1,70 +1,16 @@
from __future__ import annotations
from ._array_object import ndarray
+from ._dtypes import _numeric_dtypes
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from ._types import Literal, Optional, Tuple, Union, array
+ from ._types import Optional, Sequence, Tuple, Union, array
import numpy as np
-# def cholesky():
-# """
-# Array API compatible wrapper for :py:func:`np.cholesky <numpy.cholesky>`.
-#
-# See its docstring for more information.
-# """
-# return np.cholesky()
-
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
- """
- Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
+# einsum is not yet implemented in the array API spec.
- See its docstring for more information.
- """
- return ndarray._new(np.cross(x1._array, x2._array, axis=axis))
-
-def det(x: array, /) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
-
- See its docstring for more information.
- """
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.det(x._array))
-
-def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- """
- Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
-
- See its docstring for more information.
- """
- return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset))
-
-# def dot():
-# """
-# Array API compatible wrapper for :py:func:`np.dot <numpy.dot>`.
-#
-# See its docstring for more information.
-# """
-# return np.dot()
-#
-# def eig():
-# """
-# Array API compatible wrapper for :py:func:`np.eig <numpy.eig>`.
-#
-# See its docstring for more information.
-# """
-# return np.eig()
-#
-# def eigvalsh():
-# """
-# Array API compatible wrapper for :py:func:`np.eigvalsh <numpy.eigvalsh>`.
-#
-# See its docstring for more information.
-# """
-# return np.eigvalsh()
-#
# def einsum():
# """
# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
@@ -73,114 +19,27 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) ->
# """
# return np.einsum()
-def inv(x: array, /) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
-
- See its docstring for more information.
- """
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.inv(x._array))
-
-# def lstsq():
-# """
-# Array API compatible wrapper for :py:func:`np.lstsq <numpy.lstsq>`.
-#
-# See its docstring for more information.
-# """
-# return np.lstsq()
-#
-# def matmul():
-# """
-# Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
-#
-# See its docstring for more information.
-# """
-# return np.matmul()
-#
-# def matrix_power():
-# """
-# Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
-#
-# See its docstring for more information.
-# """
-# return np.matrix_power()
-#
-# def matrix_rank():
-# """
-# Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
-#
-# See its docstring for more information.
-# """
-# return np.matrix_rank()
-
-def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array:
- """
- Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
-
- See its docstring for more information.
- """
- # Note: this is different from the default behavior
- if axis == None and x.ndim > 2:
- x = ndarray._new(x._array.flatten())
- # Note: this function is being imported from a nondefault namespace
- return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord))
-
-def outer(x1: array, x2: array, /) -> array:
+def matmul(x1: array, x2: array, /) -> array:
"""
- Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
See its docstring for more information.
"""
- return ndarray._new(np.outer(x1._array, x2._array))
+ # Note: the restriction to numeric dtypes only is different from
+ # np.matmul.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in matmul')
-# def pinv():
-# """
-# Array API compatible wrapper for :py:func:`np.pinv <numpy.pinv>`.
-#
-# See its docstring for more information.
-# """
-# return np.pinv()
-#
-# def qr():
-# """
-# Array API compatible wrapper for :py:func:`np.qr <numpy.qr>`.
-#
-# See its docstring for more information.
-# """
-# return np.qr()
-#
-# def slogdet():
-# """
-# Array API compatible wrapper for :py:func:`np.slogdet <numpy.slogdet>`.
-#
-# See its docstring for more information.
-# """
-# return np.slogdet()
-#
-# def solve():
-# """
-# Array API compatible wrapper for :py:func:`np.solve <numpy.solve>`.
-#
-# See its docstring for more information.
-# """
-# return np.solve()
-#
-# def svd():
-# """
-# Array API compatible wrapper for :py:func:`np.svd <numpy.svd>`.
-#
-# See its docstring for more information.
-# """
-# return np.svd()
+ return ndarray._new(np.matmul(x1._array, x2._array))
-def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- """
- Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
+# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
+def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
+ # Note: the restriction to numeric dtypes only is different from
+ # np.tensordot.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in tensordot')
- See its docstring for more information.
- """
- return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset)))
+ return ndarray._new(np.tensordot(x1._array, x2._array, axes=axes))
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
"""
@@ -189,3 +48,9 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
See its docstring for more information.
"""
return ndarray._new(np.transpose(x._array, axes=axes))
+
+# Note: vecdot is not in NumPy
+def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
+ if axis is None:
+ axis = -1
+ return tensordot(x1, x2, axes=((axis,), (axis,)))