diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 20 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 131 |
2 files changed, 90 insertions, 61 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index c3b76ada7..ac3183e16 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -542,6 +542,8 @@ def matrix_power(a, n): of the same shape as M is returned. If ``n < 0``, the inverse is computed and then raised to the ``abs(n)``. + .. note:: Stacks of object matrices are not currently supported. + Parameters ---------- a : (..., M, M) array_like @@ -604,6 +606,16 @@ def matrix_power(a, n): except TypeError: raise TypeError("exponent must be an integer") + # Fall back on dot for object arrays. Object arrays are not supported by + # the current implementation of matmul using einsum + if a.dtype != object: + fmatmul = matmul + elif a.ndim == 2: + fmatmul = dot + else: + raise NotImplementedError( + "matrix_power not supported for stacks of object arrays") + if n == 0: a = empty_like(a) a[...] = eye(a.shape[-2], dtype=a.dtype) @@ -618,20 +630,20 @@ def matrix_power(a, n): return a elif n == 2: - return matmul(a, a) + return fmatmul(a, a) elif n == 3: - return matmul(matmul(a, a), a) + return fmatmul(fmatmul(a, a), a) # Use binary decomposition to reduce the number of matrix multiplications. # Here, we iterate over the bits of n, from LSB to MSB, raise `a` to # increasing powers of 2, and multiply into the result as needed. z = result = None while n > 0: - z = a if z is None else matmul(z, z) + z = a if z is None else fmatmul(z, z) n, bit = divmod(n, 2) if bit: - result = z if result is None else matmul(result, z) + result = z if result is None else fmatmul(result, z) return result diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 0df673884..30db5f2fe 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -924,77 +924,94 @@ class TestLstsq(LstsqCases): assert_(len(w) == 1) +@pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO']) class TestMatrixPower(object): - R90 = array([[0, 1], [-1, 0]]) - Arb22 = array([[4, -7], [-2, 10]]) + + rshft_0 = np.eye(4) + rshft_1 = rshft_0[[3, 0, 1, 2]] + rshft_2 = rshft_0[[2, 3, 0, 1]] + rshft_3 = rshft_0[[1, 2, 3, 0]] + rshft_all = [rshft_0, rshft_1, rshft_2, rshft_3] noninv = array([[1, 0], [0, 0]]) - arbfloat = array([[[0.1, 3.2], [1.2, 0.7]], - [[0.2, 6.4], [2.4, 1.4]]]) + stacked = np.block([[[rshft_0]]]*2) + #FIXME the 'e' dtype might work in future + dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')] - large = identity(10) - t = large[1, :].copy() - large[1, :] = large[0, :] - large[0, :] = t - def test_large_power(self): + def test_large_power(self, dt): + power = matrix_power + rshft = self.rshft_1.astype(dt) assert_equal( - matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5 + 1), self.R90) + matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0) assert_equal( - matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 1), self.R90) + matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1) assert_equal( - matrix_power(self.R90, 2 ** 100 + 2 + 1), -self.R90) - - def test_large_power_trailing_zero(self): + matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2) assert_equal( - matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5), identity(2)) + matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3) - def testip_zero(self): + def test_power_is_zero(self, dt): def tz(M): mz = matrix_power(M, 0) assert_equal(mz, identity_like_generalized(M)) assert_equal(mz.dtype, M.dtype) - for M in [self.Arb22, self.arbfloat, self.large]: - tz(M) - - def testip_one(self): - def tz(M): - mz = matrix_power(M, 1) - assert_equal(mz, M) - assert_equal(mz.dtype, M.dtype) - for M in [self.Arb22, self.arbfloat, self.large]: - tz(M) - - def testip_two(self): - def tz(M): - mz = matrix_power(M, 2) - assert_equal(mz, matmul(M, M)) - assert_equal(mz.dtype, M.dtype) - for M in [self.Arb22, self.arbfloat, self.large]: - tz(M) - - def testip_invert(self): - def tz(M): - mz = matrix_power(M, -1) - assert_almost_equal(matmul(mz, M), identity_like_generalized(M)) - for M in [self.R90, self.Arb22, self.arbfloat, self.large]: - tz(M) - - def test_invert_noninvertible(self): - assert_raises(LinAlgError, matrix_power, self.noninv, -1) - - def test_invalid(self): - assert_raises(TypeError, matrix_power, self.R90, 1.5) - assert_raises(TypeError, matrix_power, self.R90, [1]) - assert_raises(LinAlgError, matrix_power, np.array([1]), 1) - assert_raises(LinAlgError, matrix_power, np.array([[1], [2]]), 1) - assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2)), 1) - - -class TestBoolPower(object): + + for mat in self.rshft_all: + tz(mat.astype(dt)) + if dt != object: + tz(self.stacked.astype(dt)) + + def test_power_is_one(self, dt): + def tz(mat): + mz = matrix_power(mat, 1) + assert_equal(mz, mat) + assert_equal(mz.dtype, mat.dtype) + + for mat in self.rshft_all: + tz(mat.astype(dt)) + if dt != object: + tz(self.stacked.astype(dt)) + + def test_power_is_two(self, dt): + def tz(mat): + mz = matrix_power(mat, 2) + mmul = matmul if mat.dtype != object else dot + assert_equal(mz, mmul(mat, mat)) + assert_equal(mz.dtype, mat.dtype) + + for mat in self.rshft_all: + tz(mat.astype(dt)) + if dt != object: + tz(self.stacked.astype(dt)) + + def test_power_is_minus_one(self, dt): + def tz(mat): + invmat = matrix_power(mat, -1) + mmul = matmul if mat.dtype != object else dot + assert_almost_equal( + mmul(invmat, mat), identity_like_generalized(mat)) + + for mat in self.rshft_all: + if dt not in self.dtnoinv: + tz(mat.astype(dt)) + + def test_exceptions_bad_power(self, dt): + mat = self.rshft_0.astype(dt) + assert_raises(TypeError, matrix_power, mat, 1.5) + assert_raises(TypeError, matrix_power, mat, [1]) + + + def test_exceptions_non_square(self, dt): + assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1) + assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1) + assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1) + + def test_exceptions_not_invertible(self, dt): + if dt in self.dtnoinv: + return + mat = self.noninv.astype(dt) + assert_raises(LinAlgError, matrix_power, mat, -1) - def test_square(self): - A = array([[True, False], [True, True]]) - assert_equal(matrix_power(A, 2), A) class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase): |