diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-07-10 14:19:42 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-08-04 19:14:25 +0200 |
commit | fa55f4c463806599bccf145baf22e13ff79f9a68 (patch) | |
tree | a7842184034542b4b9f7c9eb51871f7f41c1d457 | |
parent | 90ececac5755b0b54d7c8c2c5d71caaeb5c0b45c (diff) | |
download | numpy-fa55f4c463806599bccf145baf22e13ff79f9a68.tar.gz |
ENH: inv/solve work with empty inner and others empty outer array
This makes the inverse of a 0x0 array simply be 0x0 again. It
also modifies the no-empty array check in favor of a no-empty
*inner* array, since the gufuncs seem to handle the other case
fine.
-rw-r--r-- | numpy/linalg/linalg.py | 50 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 48 |
2 files changed, 82 insertions, 16 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 760ca31b0..4b0d3d86d 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -19,10 +19,11 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv', import warnings from numpy.core import ( - array, asarray, zeros, empty, transpose, intc, single, double, csingle, - cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, add, - multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size, - finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax + array, asarray, zeros, empty, empty_like, transpose, intc, single, double, + csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, + add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size, + finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, + broadcast ) from numpy.lib import triu, asfarray from numpy.linalg import lapack_lite, _umath_linalg @@ -215,9 +216,9 @@ def _assertFinite(*arrays): if not (isfinite(a).all()): raise LinAlgError("Array must not contain infs or NaNs") -def _assertNonEmpty(*arrays): +def _assertNoEmpty2d(*arrays): for a in arrays: - if size(a) == 0: + if a.size == 0 and product(a.shape[-2:]) == 0: raise LinAlgError("Arrays cannot be empty") @@ -350,15 +351,28 @@ def solve(a, b): """ a, _ = _makearray(a) - _assertNonEmpty(a) _assertRankAtLeast2(a) _assertNdSquareness(a) b, wrap = _makearray(b) t, result_t = _commonType(a, b) - if len(b.shape) == len(a.shape) - 1: + # We use the b = (..., M,) logic, only if the number of extra dimensions + # match exactly + if b.ndim == a.ndim - 1: + if a.shape[-1] == 0 and b.shape[-1] == 0: + # Legal, but the ufunc cannot handle the 0-sized inner dims + # let the ufunc handle all wrong cases. + a = a.reshape(a.shape[:-1]) + bc = broadcast(a, b) + return wrap(empty(bc.shape, dtype=result_t)) + gufunc = _umath_linalg.solve1 else: + if a.shape[-1] == 0 and b.shape[-2] == 0: + a = a.reshape(a.shape[:-1] + (1,)) + bc = broadcast(a, b) + return wrap(empty(bc.shape, dtype=result_t)) + gufunc = _umath_linalg.solve signature = 'DD->D' if isComplexType(t) else 'dd->d' @@ -492,10 +506,14 @@ def inv(a): """ a, wrap = _makearray(a) - _assertNonEmpty(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) + + if a.shape[-1] == 0: + # The inner array is 0x0, the ufunc cannot handle this case + return wrap(empty_like(a, dtype=result_t)) + signature = 'D->D' if isComplexType(t) else 'd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_singular) ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) @@ -718,7 +736,7 @@ def qr(a, mode='reduced'): a, wrap = _makearray(a) _assertRank2(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) m, n = a.shape t, result_t = _commonType(a) a = _fastCopyAndTranspose(t, a) @@ -863,7 +881,7 @@ def eigvals(a): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) _assertFinite(a) @@ -940,7 +958,7 @@ def eigvalsh(a, UPLO='L'): gufunc = _umath_linalg.eigvalsh_up a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) @@ -1279,7 +1297,7 @@ def svd(a, full_matrices=1, compute_uv=1): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) t, result_t = _commonType(a) @@ -1556,7 +1574,7 @@ def pinv(a, rcond=1e-15 ): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) a = a.conjugate() u, s, vt = svd(a, 0) m = u.shape[0] @@ -1643,7 +1661,7 @@ def slogdet(a): """ a = asarray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) @@ -1697,7 +1715,7 @@ def det(a): """ a = asarray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 881311c94..7f102634e 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -204,6 +204,39 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + class ArraySubclass(np.ndarray): + pass + # Test system of 0x0 matrices + a = np.arange(8).reshape(2, 2, 2) + b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass) + + expected = linalg.solve(a, b)[:,0:0,:] + result = linalg.solve(a[:,0:0,0:0], b[:,0:0,:]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + + # Test errors for non-square and only b's dimension being 0 + assert_raises(linalg.LinAlgError, linalg.solve, a[:,0:0,0:1], b) + assert_raises(ValueError, linalg.solve, a, b[:,0:0,:]) + + # Test broadcasting error + b = np.arange(6).reshape(1, 3, 2) # broadcasting error + assert_raises(ValueError, linalg.solve, a, b) + assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) + + # Test zero "single equations" with 0x0 matrices. + b = np.arange(2).reshape(1, 2).view(ArraySubclass) + expected = linalg.solve(a, b)[:,0:0] + result = linalg.solve(a[:,0:0,0:0], b[:,0:0]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + + b = np.arange(3).reshape(1, 3) + assert_raises(ValueError, linalg.solve, a, b) + assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) + assert_raises(ValueError, linalg.solve, a[:,0:0,0:0], b) + class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -219,6 +252,21 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0,1,1), dtype=np.int_).view(ArraySubclass) + res = linalg.inv(a) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res.shape) + assert_(isinstance(a, ArraySubclass)) + + a = np.zeros((0,0), dtype=np.complex64).view(ArraySubclass) + res = linalg.inv(a) + assert_(res.dtype.type is np.complex64) + assert_equal(a.shape, res.shape) + class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): |