diff options
author | Tim Hochberg <tim_hochberg@local> | 2006-03-31 16:28:24 +0000 |
---|---|---|
committer | Tim Hochberg <tim_hochberg@local> | 2006-03-31 16:28:24 +0000 |
commit | 827b79a954d3f222bb1f618fd0a14e606dbd7e6e (patch) | |
tree | b367d689b362e9325ba97ebf13239922372c3a91 /numpy/core | |
parent | f0c7ba02e42d71d120cc782395f01acc6ae15db0 (diff) | |
download | numpy-827b79a954d3f222bb1f618fd0a14e606dbd7e6e.tar.gz |
Fixed _wrapit so that it correctly handled non-array return values.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/oldnumeric.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_oldnumeric.py | 18 |
2 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py index 978104469..b8685d69e 100644 --- a/numpy/core/oldnumeric.py +++ b/numpy/core/oldnumeric.py @@ -166,7 +166,7 @@ def _wrapit(obj, method, *args, **kwds): except AttributeError: wrap = None result = getattr(asarray(obj),method)(*args, **kwds) - if wrap: + if wrap and isinstance(result, mu.ndarray): if not isinstance(result, mu.ndarray): result = asarray(result) result = wrap(result) diff --git a/numpy/core/tests/test_oldnumeric.py b/numpy/core/tests/test_oldnumeric.py index 2821aa899..df8d9a3db 100644 --- a/numpy/core/tests/test_oldnumeric.py +++ b/numpy/core/tests/test_oldnumeric.py @@ -1,6 +1,6 @@ from numpy.testing import * -from numpy import array +from numpy import array, ndarray, arange, argmax from numpy.core.oldnumeric import put class test_put(ScipyTestCase): @@ -9,7 +9,21 @@ class test_put(ScipyTestCase): put(a,[1],[1.2]) assert_array_equal(a,[0,1,0]) put(a,[1],array([2.2])) - assert_array_equal(a,[0,2,0]) + assert_array_equal(a,[0,2,0])
+
+class test_wrapit(ScipyTestCase):
+ def check_array_subclass(self, level=1):
+ class subarray(ndarray): + def get_argmax(self):
+ raise AttributeError
+ argmax = property(get_argmax)
+ a = subarray([3], int, arange(3))
+ assert_equal(argmax(a), 2)
+ b = subarray([3, 3], int, arange(9))
+ bmax = argmax(b)
+ assert_array_equal(bmax, [2,2,2])
+ assert_equal(type(bmax), subarray)
+ if __name__ == "__main__": ScipyTest().run() |