summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/fromnumeric.py6
-rw-r--r--numpy/core/tests/test_multiarray.py27
2 files changed, 32 insertions, 1 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 362c29cb8..a2937c5c5 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -1367,7 +1367,11 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
(2, 3)
"""
- return asarray(a).trace(offset, axis1, axis2, dtype, out)
+ if isinstance(a, np.matrix):
+ # Get trace of matrix via an array to preserve backward compatibility.
+ return asarray(a).trace(offset, axis1, axis2, dtype, out)
+ else:
+ return asanyarray(a).trace(offset, axis1, axis2, dtype, out)
def ravel(a, order='C'):
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 0541016d9..c66e49e5f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2083,6 +2083,33 @@ class TestMethods(TestCase):
a.diagonal()
assert_(sys.getrefcount(a) < 50)
+ def test_trace(self):
+ a = np.arange(12).reshape((3, 4))
+ assert_equal(a.trace(), 15)
+ assert_equal(a.trace(0), 15)
+ assert_equal(a.trace(1), 18)
+ assert_equal(a.trace(-1), 13)
+
+ b = np.arange(8).reshape((2, 2, 2))
+ assert_equal(b.trace(), [6, 8])
+ assert_equal(b.trace(0), [6, 8])
+ assert_equal(b.trace(1), [2, 3])
+ assert_equal(b.trace(-1), [4, 5])
+ assert_equal(b.trace(0, 0, 1), [6, 8])
+ assert_equal(b.trace(0, 0, 2), [5, 9])
+ assert_equal(b.trace(0, 1, 2), [3, 11])
+ assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3])
+
+ def test_trace_subclass(self):
+ # The class would need to overwrite trace to ensure single-element
+ # output also has the right subclass.
+ class MyArray(np.ndarray):
+ pass
+
+ b = np.arange(8).reshape((2, 2, 2)).view(MyArray)
+ t = b.trace()
+ assert isinstance(t, MyArray)
+
def test_put(self):
icodes = np.typecodes['AllInteger']
fcodes = np.typecodes['AllFloat']