summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorda-woods <dw-git@d-woods.co.uk>2020-07-20 19:42:00 +0100
committerGitHub <noreply@github.com>2020-07-20 20:42:00 +0200
commit9cb557c37332ae50bfdbd675409c690cdd5fd908 (patch)
tree9bdfd28c98728649ed6d846fda10e92de13822c4
parentc42ad91755f6c17e26e4d80d79926925bfb76731 (diff)
downloadcython-9cb557c37332ae50bfdbd675409c690cdd5fd908.tar.gz
Handle `for x in cpp_function_call()` (GH-3667)
Fixes https://github.com/cython/cython/issues/3663 This ensures that rvalues here are saved as temps, while keeping the existing behaviour for `for x in deref(vec)`, where the pointer for vec is copied, meaning it doesn't crash if vec is reassigned. The bit of this change liable to have the biggest effect is that I've changed the result type of dereference(x) and x[0] (where x is a c++ type) to a reference rather than value type. I think this is OK because it matches what C++ does. If that isn't a sensible change then I can probably inspect the loop sequence more closely to try to detect this.
-rw-r--r--Cython/Compiler/ExprNodes.py185
-rw-r--r--tests/run/cpp_iterators.pyx37
2 files changed, 146 insertions, 76 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index fcb69130d..3a2c4cb1b 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -2658,7 +2658,6 @@ class IteratorNode(ExprNode):
type = py_object_type
iter_func_ptr = None
counter_cname = None
- cpp_iterator_cname = None
reversed = False # currently only used for list/tuple types (see Optimize.py)
is_async = False
@@ -2671,7 +2670,7 @@ class IteratorNode(ExprNode):
# C array iteration will be transformed later on
self.type = self.sequence.type
elif self.sequence.type.is_cpp_class:
- self.analyse_cpp_types(env)
+ return CppIteratorNode(self.pos, sequence=self.sequence).analyse_types(env)
else:
self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type in (list_type, tuple_type):
@@ -2701,65 +2700,10 @@ class IteratorNode(ExprNode):
return sequence_type
return py_object_type
- def analyse_cpp_types(self, env):
- sequence_type = self.sequence.type
- if sequence_type.is_ptr:
- sequence_type = sequence_type.base_type
- begin = sequence_type.scope.lookup("begin")
- end = sequence_type.scope.lookup("end")
- if (begin is None
- or not begin.type.is_cfunction
- or begin.type.args):
- error(self.pos, "missing begin() on %s" % self.sequence.type)
- self.type = error_type
- return
- if (end is None
- or not end.type.is_cfunction
- or end.type.args):
- error(self.pos, "missing end() on %s" % self.sequence.type)
- self.type = error_type
- return
- iter_type = begin.type.return_type
- if iter_type.is_cpp_class:
- if env.lookup_operator_for_types(
- self.pos,
- "!=",
- [iter_type, end.type.return_type]) is None:
- error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
- self.type = error_type
- return
- if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
- error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
- self.type = error_type
- return
- if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
- error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
- self.type = error_type
- return
- self.type = iter_type
- elif iter_type.is_ptr:
- if not (iter_type == end.type.return_type):
- error(self.pos, "incompatible types for begin() and end()")
- self.type = iter_type
- else:
- error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
- self.type = error_type
- return
-
def generate_result_code(self, code):
sequence_type = self.sequence.type
if sequence_type.is_cpp_class:
- if self.sequence.is_name:
- # safe: C++ won't allow you to reassign to class references
- begin_func = "%s.begin" % self.sequence.result()
- else:
- sequence_type = PyrexTypes.c_ptr_type(sequence_type)
- self.cpp_iterator_cname = code.funcstate.allocate_temp(sequence_type, manage_ref=False)
- code.putln("%s = &%s;" % (self.cpp_iterator_cname, self.sequence.result()))
- begin_func = "%s->begin" % self.cpp_iterator_cname
- # TODO: Limit scope.
- code.putln("%s = %s();" % (self.result(), begin_func))
- return
+ assert False, "Should have been changed to CppIteratorNode"
if sequence_type.is_array or sequence_type.is_ptr:
raise InternalError("for in carray slice not transformed")
@@ -2855,21 +2799,7 @@ class IteratorNode(ExprNode):
sequence_type = self.sequence.type
if self.reversed:
code.putln("if (%s < 0) break;" % self.counter_cname)
- if sequence_type.is_cpp_class:
- if self.cpp_iterator_cname:
- end_func = "%s->end" % self.cpp_iterator_cname
- else:
- end_func = "%s.end" % self.sequence.result()
- # TODO: Cache end() call?
- code.putln("if (!(%s != %s())) break;" % (
- self.result(),
- end_func))
- code.putln("%s = *%s;" % (
- result_name,
- self.result()))
- code.putln("++%s;" % self.result())
- return
- elif sequence_type is list_type:
+ if sequence_type is list_type:
self.generate_next_sequence_item('List', result_name, code)
return
elif sequence_type is tuple_type:
@@ -2908,8 +2838,109 @@ class IteratorNode(ExprNode):
if self.iter_func_ptr:
code.funcstate.release_temp(self.iter_func_ptr)
self.iter_func_ptr = None
- if self.cpp_iterator_cname:
- code.funcstate.release_temp(self.cpp_iterator_cname)
+ ExprNode.free_temps(self, code)
+
+
+class CppIteratorNode(ExprNode):
+ # Iteration over a C++ container.
+ # Created at the analyse_types stage by IteratorNode
+ cpp_sequence_cname = None
+ cpp_attribute_op = "."
+ is_temp = True
+
+ subexprs = ['sequence']
+
+ def analyse_types(self, env):
+ sequence_type = self.sequence.type
+ if sequence_type.is_ptr:
+ sequence_type = sequence_type.base_type
+ begin = sequence_type.scope.lookup("begin")
+ end = sequence_type.scope.lookup("end")
+ if (begin is None
+ or not begin.type.is_cfunction
+ or begin.type.args):
+ error(self.pos, "missing begin() on %s" % self.sequence.type)
+ self.type = error_type
+ return self
+ if (end is None
+ or not end.type.is_cfunction
+ or end.type.args):
+ error(self.pos, "missing end() on %s" % self.sequence.type)
+ self.type = error_type
+ return self
+ iter_type = begin.type.return_type
+ if iter_type.is_cpp_class:
+ if env.lookup_operator_for_types(
+ self.pos,
+ "!=",
+ [iter_type, end.type.return_type]) is None:
+ error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
+ self.type = error_type
+ return self
+ if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
+ error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
+ self.type = error_type
+ return self
+ if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
+ error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
+ self.type = error_type
+ return self
+ self.type = iter_type
+ elif iter_type.is_ptr:
+ if not (iter_type == end.type.return_type):
+ error(self.pos, "incompatible types for begin() and end()")
+ self.type = iter_type
+ else:
+ error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
+ self.type = error_type
+ return self
+
+ def generate_result_code(self, code):
+ sequence_type = self.sequence.type
+ # essentially 3 options:
+ if self.sequence.is_name or self.sequence.is_attribute:
+ # 1) is a name and can be accessed directly;
+ # assigning to it may break the container, but that's the responsibility
+ # of the user
+ code.putln("%s = %s%sbegin();" % (self.result(),
+ self.sequence.result(),
+ self.cpp_attribute_op))
+ else:
+ # (while it'd be nice to limit the scope of the loop temp, it's essentially
+ # impossible to do while supporting generators)
+ temp_type = sequence_type
+ if temp_type.is_reference:
+ # 2) Sequence is a reference (often obtained by dereferencing a pointer);
+ # make the temp a pointer so we are not sensitive to users reassigning
+ # the pointer than it came from
+ temp_type = PyrexTypes.CPtrType(sequence_type.ref_base_type)
+ if temp_type.is_ptr:
+ self.cpp_attribute_op = "->"
+ # 3) (otherwise) sequence comes from a function call or similar, so we must
+ # create a temp to store it in
+ self.cpp_sequence_cname = code.funcstate.allocate_temp(temp_type, manage_ref=False)
+ code.putln("%s = %s%s;" % (self.cpp_sequence_cname,
+ "&" if temp_type.is_ptr else "",
+ self.sequence.move_result_rhs()))
+ code.putln("%s = %s%sbegin();" % (self.result(), self.cpp_sequence_cname,
+ self.cpp_attribute_op))
+
+ def generate_iter_next_result_code(self, result_name, code):
+ # end call isn't cached to support containers that allow adding while iterating
+ # (much as this is usually a bad idea)
+ code.putln("if (!(%s != %s%send())) break;" % (
+ self.result(),
+ self.cpp_sequence_cname or self.sequence.result(),
+ self.cpp_attribute_op))
+ code.putln("%s = *%s;" % (
+ result_name,
+ self.result()))
+ code.putln("++%s;" % self.result())
+
+ def free_temps(self, code):
+ if self.cpp_sequence_cname:
+ code.funcstate.release_temp(self.cpp_sequence_cname)
+ # skip over IteratorNode since we don't use any of the temps it does
ExprNode.free_temps(self, code)
@@ -3793,6 +3824,8 @@ class IndexNode(_IndexingBaseNode):
def analyse_as_c_array(self, env, is_slice):
base_type = self.base.type
self.type = base_type.base_type
+ if self.type.is_cpp_class:
+ self.type = PyrexTypes.CReferenceType(self.type)
if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
@@ -10313,7 +10346,7 @@ class DereferenceNode(CUnopNode):
def analyse_c_operation(self, env):
if self.operand.type.is_ptr:
- self.type = self.operand.type.base_type
+ self.type = PyrexTypes.CReferenceType(self.operand.type.base_type)
else:
self.type_error()
diff --git a/tests/run/cpp_iterators.pyx b/tests/run/cpp_iterators.pyx
index 04cdd6777..36782710f 100644
--- a/tests/run/cpp_iterators.pyx
+++ b/tests/run/cpp_iterators.pyx
@@ -140,3 +140,40 @@ def test_iteration_in_generator_reassigned():
if vint is not orig_vint:
del vint
del orig_vint
+
+cdef extern from *:
+ """
+ std::vector<int> make_vec1() {
+ std::vector<int> vint;
+ vint.push_back(1);
+ vint.push_back(2);
+ return vint;
+ }
+ """
+ cdef vector[int] make_vec1() except +
+
+cdef vector[int] make_vec2() except *:
+ return make_vec1()
+
+cdef vector[int] make_vec3():
+ try:
+ return make_vec1()
+ except:
+ pass
+
+def test_iteration_from_function_call():
+ """
+ >>> test_iteration_from_function_call()
+ 1
+ 2
+ 1
+ 2
+ 1
+ 2
+ """
+ for i in make_vec1():
+ print(i)
+ for i in make_vec2():
+ print(i)
+ for i in make_vec3():
+ print(i)