summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py20
-rw-r--r--numpy/linalg/tests/test_linalg.py131
2 files changed, 90 insertions, 61 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index c3b76ada7..ac3183e16 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -542,6 +542,8 @@ def matrix_power(a, n):
of the same shape as M is returned. If ``n < 0``, the inverse
is computed and then raised to the ``abs(n)``.
+ .. note:: Stacks of object matrices are not currently supported.
+
Parameters
----------
a : (..., M, M) array_like
@@ -604,6 +606,16 @@ def matrix_power(a, n):
except TypeError:
raise TypeError("exponent must be an integer")
+ # Fall back on dot for object arrays. Object arrays are not supported by
+ # the current implementation of matmul using einsum
+ if a.dtype != object:
+ fmatmul = matmul
+ elif a.ndim == 2:
+ fmatmul = dot
+ else:
+ raise NotImplementedError(
+ "matrix_power not supported for stacks of object arrays")
+
if n == 0:
a = empty_like(a)
a[...] = eye(a.shape[-2], dtype=a.dtype)
@@ -618,20 +630,20 @@ def matrix_power(a, n):
return a
elif n == 2:
- return matmul(a, a)
+ return fmatmul(a, a)
elif n == 3:
- return matmul(matmul(a, a), a)
+ return fmatmul(fmatmul(a, a), a)
# Use binary decomposition to reduce the number of matrix multiplications.
# Here, we iterate over the bits of n, from LSB to MSB, raise `a` to
# increasing powers of 2, and multiply into the result as needed.
z = result = None
while n > 0:
- z = a if z is None else matmul(z, z)
+ z = a if z is None else fmatmul(z, z)
n, bit = divmod(n, 2)
if bit:
- result = z if result is None else matmul(result, z)
+ result = z if result is None else fmatmul(result, z)
return result
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 0df673884..30db5f2fe 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -924,77 +924,94 @@ class TestLstsq(LstsqCases):
assert_(len(w) == 1)
+@pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO'])
class TestMatrixPower(object):
- R90 = array([[0, 1], [-1, 0]])
- Arb22 = array([[4, -7], [-2, 10]])
+
+ rshft_0 = np.eye(4)
+ rshft_1 = rshft_0[[3, 0, 1, 2]]
+ rshft_2 = rshft_0[[2, 3, 0, 1]]
+ rshft_3 = rshft_0[[1, 2, 3, 0]]
+ rshft_all = [rshft_0, rshft_1, rshft_2, rshft_3]
noninv = array([[1, 0], [0, 0]])
- arbfloat = array([[[0.1, 3.2], [1.2, 0.7]],
- [[0.2, 6.4], [2.4, 1.4]]])
+ stacked = np.block([[[rshft_0]]]*2)
+ #FIXME the 'e' dtype might work in future
+ dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')]
- large = identity(10)
- t = large[1, :].copy()
- large[1, :] = large[0, :]
- large[0, :] = t
- def test_large_power(self):
+ def test_large_power(self, dt):
+ power = matrix_power
+ rshft = self.rshft_1.astype(dt)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5 + 1), self.R90)
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 1), self.R90)
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 + 1), -self.R90)
-
- def test_large_power_trailing_zero(self):
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5), identity(2))
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3)
- def testip_zero(self):
+ def test_power_is_zero(self, dt):
def tz(M):
mz = matrix_power(M, 0)
assert_equal(mz, identity_like_generalized(M))
assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_one(self):
- def tz(M):
- mz = matrix_power(M, 1)
- assert_equal(mz, M)
- assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_two(self):
- def tz(M):
- mz = matrix_power(M, 2)
- assert_equal(mz, matmul(M, M))
- assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_invert(self):
- def tz(M):
- mz = matrix_power(M, -1)
- assert_almost_equal(matmul(mz, M), identity_like_generalized(M))
- for M in [self.R90, self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def test_invert_noninvertible(self):
- assert_raises(LinAlgError, matrix_power, self.noninv, -1)
-
- def test_invalid(self):
- assert_raises(TypeError, matrix_power, self.R90, 1.5)
- assert_raises(TypeError, matrix_power, self.R90, [1])
- assert_raises(LinAlgError, matrix_power, np.array([1]), 1)
- assert_raises(LinAlgError, matrix_power, np.array([[1], [2]]), 1)
- assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2)), 1)
-
-
-class TestBoolPower(object):
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_one(self, dt):
+ def tz(mat):
+ mz = matrix_power(mat, 1)
+ assert_equal(mz, mat)
+ assert_equal(mz.dtype, mat.dtype)
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_two(self, dt):
+ def tz(mat):
+ mz = matrix_power(mat, 2)
+ mmul = matmul if mat.dtype != object else dot
+ assert_equal(mz, mmul(mat, mat))
+ assert_equal(mz.dtype, mat.dtype)
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_minus_one(self, dt):
+ def tz(mat):
+ invmat = matrix_power(mat, -1)
+ mmul = matmul if mat.dtype != object else dot
+ assert_almost_equal(
+ mmul(invmat, mat), identity_like_generalized(mat))
+
+ for mat in self.rshft_all:
+ if dt not in self.dtnoinv:
+ tz(mat.astype(dt))
+
+ def test_exceptions_bad_power(self, dt):
+ mat = self.rshft_0.astype(dt)
+ assert_raises(TypeError, matrix_power, mat, 1.5)
+ assert_raises(TypeError, matrix_power, mat, [1])
+
+
+ def test_exceptions_non_square(self, dt):
+ assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1)
+ assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1)
+ assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1)
+
+ def test_exceptions_not_invertible(self, dt):
+ if dt in self.dtnoinv:
+ return
+ mat = self.noninv.astype(dt)
+ assert_raises(LinAlgError, matrix_power, mat, -1)
- def test_square(self):
- A = array([[True, False], [True, True]])
- assert_equal(matrix_power(A, 2), A)
class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase):