diff options
Diffstat (limited to 'numpy/_array_api')
-rw-r--r-- | numpy/_array_api/__init__.py | 15 | ||||
-rw-r--r-- | numpy/_array_api/_linear_algebra_functions.py | 181 |
2 files changed, 31 insertions, 165 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py index 56699b09d..e39a2c7d0 100644 --- a/numpy/_array_api/__init__.py +++ b/numpy/_array_api/__init__.py @@ -47,8 +47,8 @@ A few notes about the current state of this submodule: - np.argmin and np.argmax do not implement the keepdims keyword argument. - - Some linear algebra functions in the spec are still a work in progress (to - be added soon). These will be updated once the spec is. + - The linear algebra extension in the spec will be added in a future pull +request. - Some tests in the test suite are still not fully correct in that they test all datatypes whereas certain functions are only defined for a subset of @@ -132,13 +132,14 @@ from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, at __all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] -from ._linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose +# einsum is not yet implemented in the array API spec. -__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose'] +# from ._linear_algebra_functions import einsum +# __all__ += ['einsum'] -# from ._linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose -# -# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose'] +from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot + +__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot'] from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index 99a386866..461770641 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,70 +1,16 @@ from __future__ import annotations from ._array_object import ndarray +from ._dtypes import _numeric_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Literal, Optional, Tuple, Union, array + from ._types import Optional, Sequence, Tuple, Union, array import numpy as np -# def cholesky(): -# """ -# Array API compatible wrapper for :py:func:`np.cholesky <numpy.cholesky>`. -# -# See its docstring for more information. -# """ -# return np.cholesky() - -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - """ - Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`. +# einsum is not yet implemented in the array API spec. - See its docstring for more information. - """ - return ndarray._new(np.cross(x1._array, x2._array, axis=axis)) - -def det(x: array, /) -> array: - """ - Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`. - - See its docstring for more information. - """ - # Note: this function is being imported from a nondefault namespace - return ndarray._new(np.linalg.det(x._array)) - -def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: - """ - Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`. - - See its docstring for more information. - """ - return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset)) - -# def dot(): -# """ -# Array API compatible wrapper for :py:func:`np.dot <numpy.dot>`. -# -# See its docstring for more information. -# """ -# return np.dot() -# -# def eig(): -# """ -# Array API compatible wrapper for :py:func:`np.eig <numpy.eig>`. -# -# See its docstring for more information. -# """ -# return np.eig() -# -# def eigvalsh(): -# """ -# Array API compatible wrapper for :py:func:`np.eigvalsh <numpy.eigvalsh>`. -# -# See its docstring for more information. -# """ -# return np.eigvalsh() -# # def einsum(): # """ # Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`. @@ -73,114 +19,27 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> # """ # return np.einsum() -def inv(x: array, /) -> array: - """ - Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`. - - See its docstring for more information. - """ - # Note: this function is being imported from a nondefault namespace - return ndarray._new(np.linalg.inv(x._array)) - -# def lstsq(): -# """ -# Array API compatible wrapper for :py:func:`np.lstsq <numpy.lstsq>`. -# -# See its docstring for more information. -# """ -# return np.lstsq() -# -# def matmul(): -# """ -# Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`. -# -# See its docstring for more information. -# """ -# return np.matmul() -# -# def matrix_power(): -# """ -# Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`. -# -# See its docstring for more information. -# """ -# return np.matrix_power() -# -# def matrix_rank(): -# """ -# Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`. -# -# See its docstring for more information. -# """ -# return np.matrix_rank() - -def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array: - """ - Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. - - See its docstring for more information. - """ - # Note: this is different from the default behavior - if axis == None and x.ndim > 2: - x = ndarray._new(x._array.flatten()) - # Note: this function is being imported from a nondefault namespace - return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord)) - -def outer(x1: array, x2: array, /) -> array: +def matmul(x1: array, x2: array, /) -> array: """ - Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`. + Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`. See its docstring for more information. """ - return ndarray._new(np.outer(x1._array, x2._array)) + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') -# def pinv(): -# """ -# Array API compatible wrapper for :py:func:`np.pinv <numpy.pinv>`. -# -# See its docstring for more information. -# """ -# return np.pinv() -# -# def qr(): -# """ -# Array API compatible wrapper for :py:func:`np.qr <numpy.qr>`. -# -# See its docstring for more information. -# """ -# return np.qr() -# -# def slogdet(): -# """ -# Array API compatible wrapper for :py:func:`np.slogdet <numpy.slogdet>`. -# -# See its docstring for more information. -# """ -# return np.slogdet() -# -# def solve(): -# """ -# Array API compatible wrapper for :py:func:`np.solve <numpy.solve>`. -# -# See its docstring for more information. -# """ -# return np.solve() -# -# def svd(): -# """ -# Array API compatible wrapper for :py:func:`np.svd <numpy.svd>`. -# -# See its docstring for more information. -# """ -# return np.svd() + return ndarray._new(np.matmul(x1._array, x2._array)) -def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: - """ - Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`. +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') - See its docstring for more information. - """ - return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset))) + return ndarray._new(np.tensordot(x1._array, x2._array, axes=axes)) def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: """ @@ -189,3 +48,9 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: See its docstring for more information. """ return ndarray._new(np.transpose(x._array, axes=axes)) + +# Note: vecdot is not in NumPy +def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array: + if axis is None: + axis = -1 + return tensordot(x1, x2, axes=((axis,), (axis,))) |