diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/fromnumeric.py | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 27 |
2 files changed, 32 insertions, 1 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 362c29cb8..a2937c5c5 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -1367,7 +1367,11 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None): (2, 3) """ - return asarray(a).trace(offset, axis1, axis2, dtype, out) + if isinstance(a, np.matrix): + # Get trace of matrix via an array to preserve backward compatibility. + return asarray(a).trace(offset, axis1, axis2, dtype, out) + else: + return asanyarray(a).trace(offset, axis1, axis2, dtype, out) def ravel(a, order='C'): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 0541016d9..c66e49e5f 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2083,6 +2083,33 @@ class TestMethods(TestCase): a.diagonal() assert_(sys.getrefcount(a) < 50) + def test_trace(self): + a = np.arange(12).reshape((3, 4)) + assert_equal(a.trace(), 15) + assert_equal(a.trace(0), 15) + assert_equal(a.trace(1), 18) + assert_equal(a.trace(-1), 13) + + b = np.arange(8).reshape((2, 2, 2)) + assert_equal(b.trace(), [6, 8]) + assert_equal(b.trace(0), [6, 8]) + assert_equal(b.trace(1), [2, 3]) + assert_equal(b.trace(-1), [4, 5]) + assert_equal(b.trace(0, 0, 1), [6, 8]) + assert_equal(b.trace(0, 0, 2), [5, 9]) + assert_equal(b.trace(0, 1, 2), [3, 11]) + assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3]) + + def test_trace_subclass(self): + # The class would need to overwrite trace to ensure single-element + # output also has the right subclass. + class MyArray(np.ndarray): + pass + + b = np.arange(8).reshape((2, 2, 2)).view(MyArray) + t = b.trace() + assert isinstance(t, MyArray) + def test_put(self): icodes = np.typecodes['AllInteger'] fcodes = np.typecodes['AllFloat'] |