summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-09 22:01:40 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-10 22:47:44 +0300
commit2e8b24e30345c75eb9101ea29e96c9fde48add7f (patch)
treefbb618cdc14b99a72a777324e3ea77c81bf2a781 /numpy/linalg/linalg.py
parentf0a78c76e1fffdc222e527423871d8adfe7433f6 (diff)
downloadnumpy-2e8b24e30345c75eb9101ea29e96c9fde48add7f.tar.gz
ENH: linalg: add helper routines for gufuncs
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py49
1 files changed, 48 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 34fe9b550..f1568875a 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -22,12 +22,14 @@ from numpy.core import array, asarray, zeros, empty, transpose, \
intc, single, double, csingle, cdouble, inexact, complexfloating, \
newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \
- isfinite, size, finfo, absolute, log, exp
+ isfinite, size, finfo, absolute, log, exp, errstate, geterrobj
from numpy.lib import triu
from numpy.linalg import lapack_lite
from numpy.matrixlib.defmatrix import matrix_power
from numpy.compat import asbytes
+from numpy.core import _umath_linalg
+
# For Python2/3 compatibility
_N = asbytes('N')
_V = asbytes('V')
@@ -67,6 +69,40 @@ class LinAlgError(Exception):
"""
pass
+# Dealing with errors in _umath_linalg
+
+_linalg_error_extobj = None
+
+def _determine_error_states():
+ global _linalg_error_extobj
+ errobj = geterrobj()
+ bufsize = errobj[0]
+
+ with errstate(invalid='call', over='ignore',
+ divide='ignore', under='ignore'):
+ invalid_call_errmask = geterrobj()[1]
+
+ _linalg_error_extobj = [bufsize, invalid_call_errmask, None]
+
+_determine_error_states()
+
+def _raise_linalgerror_singular(err, flag):
+ raise LinAlgError("Singular matrix")
+
+def _raise_linalgerror_nonposdef(err, flag):
+ raise LinAlgError("Matrix is not positive definite")
+
+def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):
+ raise LinAlgError("Eigenvalues did not converge")
+
+def _raise_linalgerror_svd_nonconvergence(err, flag):
+ raise LinAlgError("SVD did not converge")
+
+def get_linalg_error_extobj(callback):
+ extobj = list(_linalg_error_extobj)
+ extobj[2] = callback
+ return extobj
+
def _makearray(a):
new = asarray(a)
wrap = getattr(a, "__array_prepare__", new.__array_wrap__)
@@ -158,11 +194,22 @@ def _assertRank2(*arrays):
raise LinAlgError('%d-dimensional array given. Array must be '
'two-dimensional' % len(a.shape))
+def _assertRankAtLeast2(*arrays):
+ for a in arrays:
+ if len(a.shape) < 2:
+ raise LinAlgError('%d-dimensional array given. Array must be '
+ 'at least two-dimensional' % len(a.shape))
+
def _assertSquareness(*arrays):
for a in arrays:
if max(a.shape) != min(a.shape):
raise LinAlgError('Array must be square')
+def _assertNdSquareness(*arrays):
+ for a in arrays:
+ if max(a.shape[-2:]) != min(a.shape[-2:]):
+ raise LinAlgError('Last 2 dimensions of the array must be square')
+
def _assertFinite(*arrays):
for a in arrays:
if not (isfinite(a).all()):