diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 34fe9b550..f1568875a 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -22,12 +22,14 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ intc, single, double, csingle, cdouble, inexact, complexfloating, \ newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \ maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \ - isfinite, size, finfo, absolute, log, exp + isfinite, size, finfo, absolute, log, exp, errstate, geterrobj from numpy.lib import triu from numpy.linalg import lapack_lite from numpy.matrixlib.defmatrix import matrix_power from numpy.compat import asbytes +from numpy.core import _umath_linalg + # For Python2/3 compatibility _N = asbytes('N') _V = asbytes('V') @@ -67,6 +69,40 @@ class LinAlgError(Exception): """ pass +# Dealing with errors in _umath_linalg + +_linalg_error_extobj = None + +def _determine_error_states(): + global _linalg_error_extobj + errobj = geterrobj() + bufsize = errobj[0] + + with errstate(invalid='call', over='ignore', + divide='ignore', under='ignore'): + invalid_call_errmask = geterrobj()[1] + + _linalg_error_extobj = [bufsize, invalid_call_errmask, None] + +_determine_error_states() + +def _raise_linalgerror_singular(err, flag): + raise LinAlgError("Singular matrix") + +def _raise_linalgerror_nonposdef(err, flag): + raise LinAlgError("Matrix is not positive definite") + +def _raise_linalgerror_eigenvalues_nonconvergence(err, flag): + raise LinAlgError("Eigenvalues did not converge") + +def _raise_linalgerror_svd_nonconvergence(err, flag): + raise LinAlgError("SVD did not converge") + +def get_linalg_error_extobj(callback): + extobj = list(_linalg_error_extobj) + extobj[2] = callback + return extobj + def _makearray(a): new = asarray(a) wrap = getattr(a, "__array_prepare__", new.__array_wrap__) @@ -158,11 +194,22 @@ def _assertRank2(*arrays): raise LinAlgError('%d-dimensional array given. Array must be ' 'two-dimensional' % len(a.shape)) +def _assertRankAtLeast2(*arrays): + for a in arrays: + if len(a.shape) < 2: + raise LinAlgError('%d-dimensional array given. Array must be ' + 'at least two-dimensional' % len(a.shape)) + def _assertSquareness(*arrays): for a in arrays: if max(a.shape) != min(a.shape): raise LinAlgError('Array must be square') +def _assertNdSquareness(*arrays): + for a in arrays: + if max(a.shape[-2:]) != min(a.shape[-2:]): + raise LinAlgError('Last 2 dimensions of the array must be square') + def _assertFinite(*arrays): for a in arrays: if not (isfinite(a).all()): |