summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py131
1 files changed, 74 insertions, 57 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 0df673884..30db5f2fe 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -924,77 +924,94 @@ class TestLstsq(LstsqCases):
assert_(len(w) == 1)
+@pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO'])
class TestMatrixPower(object):
- R90 = array([[0, 1], [-1, 0]])
- Arb22 = array([[4, -7], [-2, 10]])
+
+ rshft_0 = np.eye(4)
+ rshft_1 = rshft_0[[3, 0, 1, 2]]
+ rshft_2 = rshft_0[[2, 3, 0, 1]]
+ rshft_3 = rshft_0[[1, 2, 3, 0]]
+ rshft_all = [rshft_0, rshft_1, rshft_2, rshft_3]
noninv = array([[1, 0], [0, 0]])
- arbfloat = array([[[0.1, 3.2], [1.2, 0.7]],
- [[0.2, 6.4], [2.4, 1.4]]])
+ stacked = np.block([[[rshft_0]]]*2)
+ #FIXME the 'e' dtype might work in future
+ dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')]
- large = identity(10)
- t = large[1, :].copy()
- large[1, :] = large[0, :]
- large[0, :] = t
- def test_large_power(self):
+ def test_large_power(self, dt):
+ power = matrix_power
+ rshft = self.rshft_1.astype(dt)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5 + 1), self.R90)
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 1), self.R90)
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 + 1), -self.R90)
-
- def test_large_power_trailing_zero(self):
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2)
assert_equal(
- matrix_power(self.R90, 2 ** 100 + 2 ** 10 + 2 ** 5), identity(2))
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3)
- def testip_zero(self):
+ def test_power_is_zero(self, dt):
def tz(M):
mz = matrix_power(M, 0)
assert_equal(mz, identity_like_generalized(M))
assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_one(self):
- def tz(M):
- mz = matrix_power(M, 1)
- assert_equal(mz, M)
- assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_two(self):
- def tz(M):
- mz = matrix_power(M, 2)
- assert_equal(mz, matmul(M, M))
- assert_equal(mz.dtype, M.dtype)
- for M in [self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def testip_invert(self):
- def tz(M):
- mz = matrix_power(M, -1)
- assert_almost_equal(matmul(mz, M), identity_like_generalized(M))
- for M in [self.R90, self.Arb22, self.arbfloat, self.large]:
- tz(M)
-
- def test_invert_noninvertible(self):
- assert_raises(LinAlgError, matrix_power, self.noninv, -1)
-
- def test_invalid(self):
- assert_raises(TypeError, matrix_power, self.R90, 1.5)
- assert_raises(TypeError, matrix_power, self.R90, [1])
- assert_raises(LinAlgError, matrix_power, np.array([1]), 1)
- assert_raises(LinAlgError, matrix_power, np.array([[1], [2]]), 1)
- assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2)), 1)
-
-
-class TestBoolPower(object):
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_one(self, dt):
+ def tz(mat):
+ mz = matrix_power(mat, 1)
+ assert_equal(mz, mat)
+ assert_equal(mz.dtype, mat.dtype)
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_two(self, dt):
+ def tz(mat):
+ mz = matrix_power(mat, 2)
+ mmul = matmul if mat.dtype != object else dot
+ assert_equal(mz, mmul(mat, mat))
+ assert_equal(mz.dtype, mat.dtype)
+
+ for mat in self.rshft_all:
+ tz(mat.astype(dt))
+ if dt != object:
+ tz(self.stacked.astype(dt))
+
+ def test_power_is_minus_one(self, dt):
+ def tz(mat):
+ invmat = matrix_power(mat, -1)
+ mmul = matmul if mat.dtype != object else dot
+ assert_almost_equal(
+ mmul(invmat, mat), identity_like_generalized(mat))
+
+ for mat in self.rshft_all:
+ if dt not in self.dtnoinv:
+ tz(mat.astype(dt))
+
+ def test_exceptions_bad_power(self, dt):
+ mat = self.rshft_0.astype(dt)
+ assert_raises(TypeError, matrix_power, mat, 1.5)
+ assert_raises(TypeError, matrix_power, mat, [1])
+
+
+ def test_exceptions_non_square(self, dt):
+ assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1)
+ assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1)
+ assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1)
+
+ def test_exceptions_not_invertible(self, dt):
+ if dt in self.dtnoinv:
+ return
+ mat = self.noninv.astype(dt)
+ assert_raises(LinAlgError, matrix_power, mat, -1)
- def test_square(self):
- A = array([[True, False], [True, True]])
- assert_equal(matrix_power(A, 2), A)
class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase):