diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-09 22:01:40 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-10 22:47:44 +0300 |
commit | 2e8b24e30345c75eb9101ea29e96c9fde48add7f (patch) | |
tree | fbb618cdc14b99a72a777324e3ea77c81bf2a781 /numpy/linalg/linalg.py | |
parent | f0a78c76e1fffdc222e527423871d8adfe7433f6 (diff) | |
download | numpy-2e8b24e30345c75eb9101ea29e96c9fde48add7f.tar.gz |
ENH: linalg: add helper routines for gufuncs
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 34fe9b550..f1568875a 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -22,12 +22,14 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ intc, single, double, csingle, cdouble, inexact, complexfloating, \ newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \ maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \ - isfinite, size, finfo, absolute, log, exp + isfinite, size, finfo, absolute, log, exp, errstate, geterrobj from numpy.lib import triu from numpy.linalg import lapack_lite from numpy.matrixlib.defmatrix import matrix_power from numpy.compat import asbytes +from numpy.core import _umath_linalg + # For Python2/3 compatibility _N = asbytes('N') _V = asbytes('V') @@ -67,6 +69,40 @@ class LinAlgError(Exception): """ pass +# Dealing with errors in _umath_linalg + +_linalg_error_extobj = None + +def _determine_error_states(): + global _linalg_error_extobj + errobj = geterrobj() + bufsize = errobj[0] + + with errstate(invalid='call', over='ignore', + divide='ignore', under='ignore'): + invalid_call_errmask = geterrobj()[1] + + _linalg_error_extobj = [bufsize, invalid_call_errmask, None] + +_determine_error_states() + +def _raise_linalgerror_singular(err, flag): + raise LinAlgError("Singular matrix") + +def _raise_linalgerror_nonposdef(err, flag): + raise LinAlgError("Matrix is not positive definite") + +def _raise_linalgerror_eigenvalues_nonconvergence(err, flag): + raise LinAlgError("Eigenvalues did not converge") + +def _raise_linalgerror_svd_nonconvergence(err, flag): + raise LinAlgError("SVD did not converge") + +def get_linalg_error_extobj(callback): + extobj = list(_linalg_error_extobj) + extobj[2] = callback + return extobj + def _makearray(a): new = asarray(a) wrap = getattr(a, "__array_prepare__", new.__array_wrap__) @@ -158,11 +194,22 @@ def _assertRank2(*arrays): raise LinAlgError('%d-dimensional array given. Array must be ' 'two-dimensional' % len(a.shape)) +def _assertRankAtLeast2(*arrays): + for a in arrays: + if len(a.shape) < 2: + raise LinAlgError('%d-dimensional array given. Array must be ' + 'at least two-dimensional' % len(a.shape)) + def _assertSquareness(*arrays): for a in arrays: if max(a.shape) != min(a.shape): raise LinAlgError('Array must be square') +def _assertNdSquareness(*arrays): + for a in arrays: + if max(a.shape[-2:]) != min(a.shape[-2:]): + raise LinAlgError('Last 2 dimensions of the array must be square') + def _assertFinite(*arrays): for a in arrays: if not (isfinite(a).all()): |