summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2013-07-10 14:19:42 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2013-08-04 19:14:25 +0200
commitfa55f4c463806599bccf145baf22e13ff79f9a68 (patch)
treea7842184034542b4b9f7c9eb51871f7f41c1d457
parent90ececac5755b0b54d7c8c2c5d71caaeb5c0b45c (diff)
downloadnumpy-fa55f4c463806599bccf145baf22e13ff79f9a68.tar.gz
ENH: inv/solve work with empty inner and others empty outer array
This makes the inverse of a 0x0 array simply be 0x0 again. It also modifies the no-empty array check in favor of a no-empty *inner* array, since the gufuncs seem to handle the other case fine.
-rw-r--r--numpy/linalg/linalg.py50
-rw-r--r--numpy/linalg/tests/test_linalg.py48
2 files changed, 82 insertions, 16 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 760ca31b0..4b0d3d86d 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -19,10 +19,11 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
import warnings
from numpy.core import (
- array, asarray, zeros, empty, transpose, intc, single, double, csingle,
- cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, add,
- multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
- finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax
+ array, asarray, zeros, empty, empty_like, transpose, intc, single, double,
+ csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
+ add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
+ finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product,
+ broadcast
)
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
@@ -215,9 +216,9 @@ def _assertFinite(*arrays):
if not (isfinite(a).all()):
raise LinAlgError("Array must not contain infs or NaNs")
-def _assertNonEmpty(*arrays):
+def _assertNoEmpty2d(*arrays):
for a in arrays:
- if size(a) == 0:
+ if a.size == 0 and product(a.shape[-2:]) == 0:
raise LinAlgError("Arrays cannot be empty")
@@ -350,15 +351,28 @@ def solve(a, b):
"""
a, _ = _makearray(a)
- _assertNonEmpty(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
b, wrap = _makearray(b)
t, result_t = _commonType(a, b)
- if len(b.shape) == len(a.shape) - 1:
+ # We use the b = (..., M,) logic, only if the number of extra dimensions
+ # match exactly
+ if b.ndim == a.ndim - 1:
+ if a.shape[-1] == 0 and b.shape[-1] == 0:
+ # Legal, but the ufunc cannot handle the 0-sized inner dims
+ # let the ufunc handle all wrong cases.
+ a = a.reshape(a.shape[:-1])
+ bc = broadcast(a, b)
+ return wrap(empty(bc.shape, dtype=result_t))
+
gufunc = _umath_linalg.solve1
else:
+ if a.shape[-1] == 0 and b.shape[-2] == 0:
+ a = a.reshape(a.shape[:-1] + (1,))
+ bc = broadcast(a, b)
+ return wrap(empty(bc.shape, dtype=result_t))
+
gufunc = _umath_linalg.solve
signature = 'DD->D' if isComplexType(t) else 'dd->d'
@@ -492,10 +506,14 @@ def inv(a):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
+
+ if a.shape[-1] == 0:
+ # The inner array is 0x0, the ufunc cannot handle this case
+ return wrap(empty_like(a, dtype=result_t))
+
signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
@@ -718,7 +736,7 @@ def qr(a, mode='reduced'):
a, wrap = _makearray(a)
_assertRank2(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
m, n = a.shape
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
@@ -863,7 +881,7 @@ def eigvals(a):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
_assertFinite(a)
@@ -940,7 +958,7 @@ def eigvalsh(a, UPLO='L'):
gufunc = _umath_linalg.eigvalsh_up
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
@@ -1279,7 +1297,7 @@ def svd(a, full_matrices=1, compute_uv=1):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
t, result_t = _commonType(a)
@@ -1556,7 +1574,7 @@ def pinv(a, rcond=1e-15 ):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
a = a.conjugate()
u, s, vt = svd(a, 0)
m = u.shape[0]
@@ -1643,7 +1661,7 @@ def slogdet(a):
"""
a = asarray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
@@ -1697,7 +1715,7 @@ def det(a):
"""
a = asarray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 881311c94..7f102634e 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -204,6 +204,39 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ class ArraySubclass(np.ndarray):
+ pass
+ # Test system of 0x0 matrices
+ a = np.arange(8).reshape(2, 2, 2)
+ b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass)
+
+ expected = linalg.solve(a, b)[:,0:0,:]
+ result = linalg.solve(a[:,0:0,0:0], b[:,0:0,:])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
+ # Test errors for non-square and only b's dimension being 0
+ assert_raises(linalg.LinAlgError, linalg.solve, a[:,0:0,0:1], b)
+ assert_raises(ValueError, linalg.solve, a, b[:,0:0,:])
+
+ # Test broadcasting error
+ b = np.arange(6).reshape(1, 3, 2) # broadcasting error
+ assert_raises(ValueError, linalg.solve, a, b)
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
+
+ # Test zero "single equations" with 0x0 matrices.
+ b = np.arange(2).reshape(1, 2).view(ArraySubclass)
+ expected = linalg.solve(a, b)[:,0:0]
+ result = linalg.solve(a[:,0:0,0:0], b[:,0:0])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
+ b = np.arange(3).reshape(1, 3)
+ assert_raises(ValueError, linalg.solve, a, b)
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
+ assert_raises(ValueError, linalg.solve, a[:,0:0,0:0], b)
+
class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -219,6 +252,21 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0,1,1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.inv(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res.shape)
+ assert_(isinstance(a, ArraySubclass))
+
+ a = np.zeros((0,0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.inv(a)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal(a.shape, res.shape)
+
class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):