summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorAllan Haldane <ealloc@gmail.com>2018-04-23 12:11:03 -0400
committerGitHub <noreply@github.com>2018-04-23 12:11:03 -0400
commitf2888dbfc440cc3023b751fb7a5d91b9817fc271 (patch)
tree67dc23a230c34c4881181c64661c6896b6bbf04a /numpy/core
parentb5c1bcf1e8ef6e9c11bb4138a15286e648fcbce0 (diff)
parentac7d543f52ab50c878b64a13662dce198c6fcb64 (diff)
downloadnumpy-f2888dbfc440cc3023b751fb7a5d91b9817fc271.tar.gz
Merge pull request #10951 from mattip/nditer-close-fixes
BUG: it.close() disallows access to iterator, fixes #10950
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/nditer_pywrap.c38
-rw-r--r--numpy/core/tests/test_nditer.py13
2 files changed, 29 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/nditer_pywrap.c b/numpy/core/src/multiarray/nditer_pywrap.c
index d36be61f5..4505e645b 100644
--- a/numpy/core/src/multiarray/nditer_pywrap.c
+++ b/numpy/core/src/multiarray/nditer_pywrap.c
@@ -20,16 +20,14 @@
typedef struct NewNpyArrayIterObject_tag NewNpyArrayIterObject;
-enum NPYITER_CONTEXT {CONTEXT_NOTENTERED, CONTEXT_INSIDE, CONTEXT_EXITED};
-
struct NewNpyArrayIterObject_tag {
PyObject_HEAD
/* The iterator */
NpyIter *iter;
/* Flag indicating iteration started/stopped */
char started, finished;
- /* iter must used as a context manager if writebackifcopy semantics used */
- char managed;
+ /* iter operands cannot be referenced if iter is closed */
+ npy_bool is_closed;
/* Child to update for nested iteration */
NewNpyArrayIterObject *nested_child;
/* Cached values from the iterator */
@@ -89,7 +87,7 @@ npyiter_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds)
if (self != NULL) {
self->iter = NULL;
self->nested_child = NULL;
- self->managed = CONTEXT_NOTENTERED;
+ self->is_closed = 0;
}
return (PyObject *)self;
@@ -1419,7 +1417,7 @@ static PyObject *npyiter_value_get(NewNpyArrayIterObject *self)
ret = npyiter_seq_item(self, 0);
}
else {
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -1454,7 +1452,7 @@ static PyObject *npyiter_operands_get(NewNpyArrayIterObject *self)
"Iterator is invalid");
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -1489,7 +1487,7 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self)
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -1517,7 +1515,8 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self)
static PyObject *
npyiter_next(NewNpyArrayIterObject *self)
{
- if (self->iter == NULL || self->iternext == NULL || self->finished) {
+ if (self->iter == NULL || self->iternext == NULL ||
+ self->finished || self->is_closed) {
return NULL;
}
@@ -1912,7 +1911,7 @@ static PyObject *npyiter_dtypes_get(NewNpyArrayIterObject *self)
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -2014,7 +2013,7 @@ npyiter_seq_item(NewNpyArrayIterObject *self, Py_ssize_t i)
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -2104,7 +2103,7 @@ npyiter_seq_slice(NewNpyArrayIterObject *self,
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -2170,7 +2169,7 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v)
return -1;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return -1;
@@ -2250,7 +2249,7 @@ npyiter_seq_ass_slice(NewNpyArrayIterObject *self, Py_ssize_t ilow,
return -1;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return -1;
@@ -2307,7 +2306,7 @@ npyiter_subscript(NewNpyArrayIterObject *self, PyObject *op)
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return NULL;
@@ -2362,7 +2361,7 @@ npyiter_ass_subscript(NewNpyArrayIterObject *self, PyObject *op,
return -1;
}
- if (self->managed == CONTEXT_EXITED) {
+ if (self->is_closed) {
PyErr_SetString(PyExc_ValueError,
"Iterator is closed");
return -1;
@@ -2402,11 +2401,10 @@ npyiter_enter(NewNpyArrayIterObject *self)
PyErr_SetString(PyExc_RuntimeError, "operation on non-initialized iterator");
return NULL;
}
- if (self->managed == CONTEXT_EXITED) {
- PyErr_SetString(PyExc_ValueError, "cannot reuse iterator after exit");
+ if (self->is_closed) {
+ PyErr_SetString(PyExc_ValueError, "cannot reuse closed iterator");
return NULL;
}
- self->managed = CONTEXT_INSIDE;
Py_INCREF(self);
return (PyObject *)self;
}
@@ -2420,6 +2418,7 @@ npyiter_close(NewNpyArrayIterObject *self)
Py_RETURN_NONE;
}
ret = NpyIter_Close(iter);
+ self->is_closed = 1;
if (ret < 0) {
return NULL;
}
@@ -2429,7 +2428,6 @@ npyiter_close(NewNpyArrayIterObject *self)
static PyObject *
npyiter_exit(NewNpyArrayIterObject *self, PyObject *args)
{
- self->managed = CONTEXT_EXITED;
/* even if called via exception handling, writeback any data */
return npyiter_close(self);
}
diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py
index bc9456536..77c26eacf 100644
--- a/numpy/core/tests/test_nditer.py
+++ b/numpy/core/tests/test_nditer.py
@@ -2847,7 +2847,7 @@ def test_writebacks():
enter = it.__enter__
assert_raises(ValueError, enter)
-def test_close():
+def test_close_equivalent():
''' using a context amanger and using nditer.close are equivalent
'''
def add_close(x, y, out=None):
@@ -2856,8 +2856,10 @@ def test_close():
[['readonly'], ['readonly'], ['writeonly','allocate']])
for (a, b, c) in it:
addop(a, b, out=c)
+ ret = it.operands[2]
it.close()
- return it.operands[2]
+ return ret
+
def add_context(x, y, out=None):
addop = np.add
it = np.nditer([x, y, out], [],
@@ -2871,6 +2873,13 @@ def test_close():
z = add_context(range(5), range(5))
assert_equal(z, range(0, 10, 2))
+def test_close_raises():
+ it = np.nditer(np.arange(3))
+ assert_equal (next(it), 0)
+ it.close()
+ assert_raises(StopIteration, next, it)
+ assert_raises(ValueError, getattr, it, 'operands')
+
def test_warn_noclose():
a = np.arange(6, dtype='f4')
au = a.byteswap().newbyteorder()