summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTim Hochberg <tim_hochberg@local>2006-03-31 16:28:24 +0000
committerTim Hochberg <tim_hochberg@local>2006-03-31 16:28:24 +0000
commit827b79a954d3f222bb1f618fd0a14e606dbd7e6e (patch)
treeb367d689b362e9325ba97ebf13239922372c3a91 /numpy/core
parentf0c7ba02e42d71d120cc782395f01acc6ae15db0 (diff)
downloadnumpy-827b79a954d3f222bb1f618fd0a14e606dbd7e6e.tar.gz
Fixed _wrapit so that it correctly handled non-array return values.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/oldnumeric.py2
-rw-r--r--numpy/core/tests/test_oldnumeric.py18
2 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py
index 978104469..b8685d69e 100644
--- a/numpy/core/oldnumeric.py
+++ b/numpy/core/oldnumeric.py
@@ -166,7 +166,7 @@ def _wrapit(obj, method, *args, **kwds):
except AttributeError:
wrap = None
result = getattr(asarray(obj),method)(*args, **kwds)
- if wrap:
+ if wrap and isinstance(result, mu.ndarray):
if not isinstance(result, mu.ndarray):
result = asarray(result)
result = wrap(result)
diff --git a/numpy/core/tests/test_oldnumeric.py b/numpy/core/tests/test_oldnumeric.py
index 2821aa899..df8d9a3db 100644
--- a/numpy/core/tests/test_oldnumeric.py
+++ b/numpy/core/tests/test_oldnumeric.py
@@ -1,6 +1,6 @@
from numpy.testing import *
-from numpy import array
+from numpy import array, ndarray, arange, argmax
from numpy.core.oldnumeric import put
class test_put(ScipyTestCase):
@@ -9,7 +9,21 @@ class test_put(ScipyTestCase):
put(a,[1],[1.2])
assert_array_equal(a,[0,1,0])
put(a,[1],array([2.2]))
- assert_array_equal(a,[0,2,0])
+ assert_array_equal(a,[0,2,0])
+
+class test_wrapit(ScipyTestCase):
+ def check_array_subclass(self, level=1):
+ class subarray(ndarray):
+ def get_argmax(self):
+ raise AttributeError
+ argmax = property(get_argmax)
+ a = subarray([3], int, arange(3))
+ assert_equal(argmax(a), 2)
+ b = subarray([3, 3], int, arange(9))
+ bmax = argmax(b)
+ assert_array_equal(bmax, [2,2,2])
+ assert_equal(type(bmax), subarray)
+
if __name__ == "__main__":
ScipyTest().run()