diff options
author | Allan Haldane <ealloc@gmail.com> | 2018-04-23 12:11:03 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-23 12:11:03 -0400 |
commit | f2888dbfc440cc3023b751fb7a5d91b9817fc271 (patch) | |
tree | 67dc23a230c34c4881181c64661c6896b6bbf04a /numpy/core | |
parent | b5c1bcf1e8ef6e9c11bb4138a15286e648fcbce0 (diff) | |
parent | ac7d543f52ab50c878b64a13662dce198c6fcb64 (diff) | |
download | numpy-f2888dbfc440cc3023b751fb7a5d91b9817fc271.tar.gz |
Merge pull request #10951 from mattip/nditer-close-fixes
BUG: it.close() disallows access to iterator, fixes #10950
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/nditer_pywrap.c | 38 | ||||
-rw-r--r-- | numpy/core/tests/test_nditer.py | 13 |
2 files changed, 29 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/nditer_pywrap.c b/numpy/core/src/multiarray/nditer_pywrap.c index d36be61f5..4505e645b 100644 --- a/numpy/core/src/multiarray/nditer_pywrap.c +++ b/numpy/core/src/multiarray/nditer_pywrap.c @@ -20,16 +20,14 @@ typedef struct NewNpyArrayIterObject_tag NewNpyArrayIterObject; -enum NPYITER_CONTEXT {CONTEXT_NOTENTERED, CONTEXT_INSIDE, CONTEXT_EXITED}; - struct NewNpyArrayIterObject_tag { PyObject_HEAD /* The iterator */ NpyIter *iter; /* Flag indicating iteration started/stopped */ char started, finished; - /* iter must used as a context manager if writebackifcopy semantics used */ - char managed; + /* iter operands cannot be referenced if iter is closed */ + npy_bool is_closed; /* Child to update for nested iteration */ NewNpyArrayIterObject *nested_child; /* Cached values from the iterator */ @@ -89,7 +87,7 @@ npyiter_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds) if (self != NULL) { self->iter = NULL; self->nested_child = NULL; - self->managed = CONTEXT_NOTENTERED; + self->is_closed = 0; } return (PyObject *)self; @@ -1419,7 +1417,7 @@ static PyObject *npyiter_value_get(NewNpyArrayIterObject *self) ret = npyiter_seq_item(self, 0); } else { - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -1454,7 +1452,7 @@ static PyObject *npyiter_operands_get(NewNpyArrayIterObject *self) "Iterator is invalid"); return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -1489,7 +1487,7 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self) return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -1517,7 +1515,8 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self) static PyObject * npyiter_next(NewNpyArrayIterObject *self) { - if (self->iter == NULL || self->iternext == NULL || self->finished) { + if (self->iter == NULL || self->iternext == NULL || + self->finished || self->is_closed) { return NULL; } @@ -1912,7 +1911,7 @@ static PyObject *npyiter_dtypes_get(NewNpyArrayIterObject *self) return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -2014,7 +2013,7 @@ npyiter_seq_item(NewNpyArrayIterObject *self, Py_ssize_t i) return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -2104,7 +2103,7 @@ npyiter_seq_slice(NewNpyArrayIterObject *self, return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -2170,7 +2169,7 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v) return -1; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return -1; @@ -2250,7 +2249,7 @@ npyiter_seq_ass_slice(NewNpyArrayIterObject *self, Py_ssize_t ilow, return -1; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return -1; @@ -2307,7 +2306,7 @@ npyiter_subscript(NewNpyArrayIterObject *self, PyObject *op) return NULL; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return NULL; @@ -2362,7 +2361,7 @@ npyiter_ass_subscript(NewNpyArrayIterObject *self, PyObject *op, return -1; } - if (self->managed == CONTEXT_EXITED) { + if (self->is_closed) { PyErr_SetString(PyExc_ValueError, "Iterator is closed"); return -1; @@ -2402,11 +2401,10 @@ npyiter_enter(NewNpyArrayIterObject *self) PyErr_SetString(PyExc_RuntimeError, "operation on non-initialized iterator"); return NULL; } - if (self->managed == CONTEXT_EXITED) { - PyErr_SetString(PyExc_ValueError, "cannot reuse iterator after exit"); + if (self->is_closed) { + PyErr_SetString(PyExc_ValueError, "cannot reuse closed iterator"); return NULL; } - self->managed = CONTEXT_INSIDE; Py_INCREF(self); return (PyObject *)self; } @@ -2420,6 +2418,7 @@ npyiter_close(NewNpyArrayIterObject *self) Py_RETURN_NONE; } ret = NpyIter_Close(iter); + self->is_closed = 1; if (ret < 0) { return NULL; } @@ -2429,7 +2428,6 @@ npyiter_close(NewNpyArrayIterObject *self) static PyObject * npyiter_exit(NewNpyArrayIterObject *self, PyObject *args) { - self->managed = CONTEXT_EXITED; /* even if called via exception handling, writeback any data */ return npyiter_close(self); } diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py index bc9456536..77c26eacf 100644 --- a/numpy/core/tests/test_nditer.py +++ b/numpy/core/tests/test_nditer.py @@ -2847,7 +2847,7 @@ def test_writebacks(): enter = it.__enter__ assert_raises(ValueError, enter) -def test_close(): +def test_close_equivalent(): ''' using a context amanger and using nditer.close are equivalent ''' def add_close(x, y, out=None): @@ -2856,8 +2856,10 @@ def test_close(): [['readonly'], ['readonly'], ['writeonly','allocate']]) for (a, b, c) in it: addop(a, b, out=c) + ret = it.operands[2] it.close() - return it.operands[2] + return ret + def add_context(x, y, out=None): addop = np.add it = np.nditer([x, y, out], [], @@ -2871,6 +2873,13 @@ def test_close(): z = add_context(range(5), range(5)) assert_equal(z, range(0, 10, 2)) +def test_close_raises(): + it = np.nditer(np.arange(3)) + assert_equal (next(it), 0) + it.close() + assert_raises(StopIteration, next, it) + assert_raises(ValueError, getattr, it, 'operands') + def test_warn_noclose(): a = np.arange(6, dtype='f4') au = a.byteswap().newbyteorder() |