diff options
author | Travis Oliphant <oliphant@enthought.com> | 2008-10-02 20:27:17 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2008-10-02 20:27:17 +0000 |
commit | 6d5283a811f5f727bc7486e596d8048f5315a6d2 (patch) | |
tree | f7613d3d99249d891877933a90e84a120b9e264d /numpy | |
parent | dd96fbda20855997eb6d234a345d45c462c4e74b (diff) | |
download | numpy-6d5283a811f5f727bc7486e596d8048f5315a6d2.tar.gz |
Fix problem with subclasses of object arrays.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/ufuncobject.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 10 |
2 files changed, 15 insertions, 2 deletions
diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index fd14936a8..ab4ecef75 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -1427,12 +1427,15 @@ construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps, * FAIL with NotImplemented if the other object has * the __r<op>__ method and has __array_priority__ as * an attribute (signalling it can handle ndarray's) - * and is not already an ndarray + * and is not already an ndarray or a subtype of the same type. */ if ((arg_types[1] == PyArray_OBJECT) && \ (loop->ufunc->nin==2) && (loop->ufunc->nout == 1)) { PyObject *_obj = PyTuple_GET_ITEM(args, 1); - if (!PyArray_CheckExact(_obj) && \ + if (!PyArray_CheckExact(_obj) && + /* If both are same subtype of object arrays, then proceed */ + !(_obj->ob_type == (PyTuple_GET_ITEM(args, 0))->ob_type) && \ + PyObject_HasAttrString(_obj, "__array_priority__") && \ _has_reflected_op(_obj, loop->ufunc->name)) { loop->notimplemented = 1; diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index b9c2675fd..ee2893f18 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -281,6 +281,16 @@ class TestAttributes(TestCase): assert_equal(add.nout, 1) assert_equal(add.identity, 0) +class TestSubclass(TestCase): + def test_subclass_op(self): + class simple(np.ndarray): + def __new__(subtype, shape): + self = np.ndarray.__new__(subtype, shape, dtype=object) + self.fill(0) + return self + a = simple((3,4)) + assert_equal(a+a, a) + def _check_branch_cut(f, x0, dx, re_sign=1, im_sign=-1, sig_zero_ok=False, dtype=np.complex): """ |