summaryrefslogtreecommitdiff
path: root/numpy/_array_api/linear_algebra_functions.py
blob: 820dfffba3206b654f976929f2be02efc383d385 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# def cholesky():
#     from .. import cholesky
#     return cholesky()

def cross(x1, x2, /, *, axis=-1):
    """Return the cross product of ``x1`` and ``x2`` along ``axis``.

    Thin wrapper that forwards to ``cross`` from the parent package
    (``numpy``, given this module lives in ``numpy/_array_api``).
    """
    from .. import cross as _parent_cross
    return _parent_cross(x1, x2, axis=axis)

def det(x, /):
    """Return the determinant of ``x``.

    Forwards to ``det`` from the parent package's ``linalg`` submodule.
    """
    # The implementation lives in ``linalg``, not in the parent package's
    # default (top-level) namespace.
    from ..linalg import det as _linalg_det
    result = _linalg_det(x)
    return result

def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
    """Return a diagonal of ``x`` taken from the ``(axis1, axis2)`` plane.

    All options are forwarded unchanged, by keyword, to ``diagonal`` from
    the parent package.
    """
    from .. import diagonal as _parent_diagonal
    return _parent_diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

# def dot():
#     from .. import dot
#     return dot()
#
# def eig():
#     from .. import eig
#     return eig()
#
# def eigvalsh():
#     from .. import eigvalsh
#     return eigvalsh()
#
# def einsum():
#     from .. import einsum
#     return einsum()

def inv(x):
    """Return the inverse of ``x``.

    Forwards to ``inv`` from the parent package's ``linalg`` submodule.
    """
    # The implementation lives in ``linalg``, not in the parent package's
    # default (top-level) namespace.
    from ..linalg import inv as _linalg_inv
    result = _linalg_inv(x)
    return result

# def lstsq():
#     from .. import lstsq
#     return lstsq()
#
# def matmul():
#     from .. import matmul
#     return matmul()
#
# def matrix_power():
#     from .. import matrix_power
#     return matrix_power()
#
# def matrix_rank():
#     from .. import matrix_rank
#     return matrix_rank()

def norm(x, /, *, axis=None, keepdims=False, ord=None):
    """Return the norm of ``x``.

    Forwards to ``norm`` from the parent package's ``linalg`` submodule,
    with one deviation: when ``axis`` is None and ``x`` has more than two
    dimensions, the norm is computed over the flattened array (all
    elements). In that case, if ``keepdims`` is True, the result is
    reshaped to ``x.ndim`` singleton dimensions so the reduced axes are
    kept, as ``keepdims`` promises.
    """
    # Note: this function is being imported from a nondefault namespace
    from ..linalg import norm
    # Note: this is different from the default behavior
    if axis is None and x.ndim > 2:
        # ravel avoids the unconditional copy that flatten() makes.
        result = norm(x.ravel(), axis=axis, keepdims=keepdims, ord=ord)
        if keepdims:
            # Norming the 1-D view leaves shape (1,); restore one singleton
            # dimension per axis of the original input.
            result = result.reshape((1,) * x.ndim)
        return result
    return norm(x, axis=axis, keepdims=keepdims, ord=ord)

def outer(x1, x2, /):
    """Return the outer product of ``x1`` and ``x2``.

    Thin wrapper around ``outer`` from the parent package.
    """
    from .. import outer as _parent_outer
    product = _parent_outer(x1, x2)
    return product

# def pinv():
#     from .. import pinv
#     return pinv()
#
# def qr():
#     from .. import qr
#     return qr()
#
# def slogdet():
#     from .. import slogdet
#     return slogdet()
#
# def solve():
#     from .. import solve
#     return solve()
#
# def svd():
#     from .. import svd
#     return svd()

def trace(x, /, *, axis1=0, axis2=1, offset=0):
    """Return the sum along a diagonal of ``x`` in the ``(axis1, axis2)`` plane.

    All options are forwarded unchanged, by keyword, to ``trace`` from the
    parent package.
    """
    from .. import trace as _parent_trace
    return _parent_trace(x, offset=offset, axis1=axis1, axis2=axis2)

def transpose(x, /, *, axes=None):
    """Return ``x`` with its axes permuted according to ``axes``.

    Thin wrapper around ``transpose`` from the parent package; ``axes`` is
    passed through as-is.
    """
    from .. import transpose as _parent_transpose
    permuted = _parent_transpose(x, axes=axes)
    return permuted

# __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']

# Public API: only the functions actually implemented in this module.
__all__ = [
    'cross',
    'det',
    'diagonal',
    'inv',
    'norm',
    'outer',
    'trace',
    'transpose',
]