From 9cb557c37332ae50bfdbd675409c690cdd5fd908 Mon Sep 17 00:00:00 2001 From: da-woods Date: Mon, 20 Jul 2020 19:42:00 +0100 Subject: Handle `for x in cpp_function_call()` (GH-3667) Fixes https://github.com/cython/cython/issues/3663 This ensures that rvalues here are saved as temps, while keeping the existing behaviour for `for x in deref(vec)`, where the pointer for vec is copied, meaning it doesn't crash if vec is reassigned. The bit of this change liable to have the biggest effect is that I've changed the result type of dereference(x) and x[0] (where x is a c++ type) to a reference rather than value type. I think this is OK because it matches what C++ does. If that isn't a sensible change then I can probably inspect the loop sequence more closely to try to detect this. --- Cython/Compiler/ExprNodes.py | 185 +++++++++++++++++++++++++------------------ tests/run/cpp_iterators.pyx | 37 +++++++++ 2 files changed, 146 insertions(+), 76 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index fcb69130d..3a2c4cb1b 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2658,7 +2658,6 @@ class IteratorNode(ExprNode): type = py_object_type iter_func_ptr = None counter_cname = None - cpp_iterator_cname = None reversed = False # currently only used for list/tuple types (see Optimize.py) is_async = False @@ -2671,7 +2670,7 @@ class IteratorNode(ExprNode): # C array iteration will be transformed later on self.type = self.sequence.type elif self.sequence.type.is_cpp_class: - self.analyse_cpp_types(env) + return CppIteratorNode(self.pos, sequence=self.sequence).analyse_types(env) else: self.sequence = self.sequence.coerce_to_pyobject(env) if self.sequence.type in (list_type, tuple_type): @@ -2701,65 +2700,10 @@ class IteratorNode(ExprNode): return sequence_type return py_object_type - def analyse_cpp_types(self, env): - sequence_type = self.sequence.type - if sequence_type.is_ptr: - sequence_type = 
sequence_type.base_type - begin = sequence_type.scope.lookup("begin") - end = sequence_type.scope.lookup("end") - if (begin is None - or not begin.type.is_cfunction - or begin.type.args): - error(self.pos, "missing begin() on %s" % self.sequence.type) - self.type = error_type - return - if (end is None - or not end.type.is_cfunction - or end.type.args): - error(self.pos, "missing end() on %s" % self.sequence.type) - self.type = error_type - return - iter_type = begin.type.return_type - if iter_type.is_cpp_class: - if env.lookup_operator_for_types( - self.pos, - "!=", - [iter_type, end.type.return_type]) is None: - error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type) - self.type = error_type - return - if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None: - error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type) - self.type = error_type - return - if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None: - error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type) - self.type = error_type - return - self.type = iter_type - elif iter_type.is_ptr: - if not (iter_type == end.type.return_type): - error(self.pos, "incompatible types for begin() and end()") - self.type = iter_type - else: - error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type) - self.type = error_type - return - def generate_result_code(self, code): sequence_type = self.sequence.type if sequence_type.is_cpp_class: - if self.sequence.is_name: - # safe: C++ won't allow you to reassign to class references - begin_func = "%s.begin" % self.sequence.result() - else: - sequence_type = PyrexTypes.c_ptr_type(sequence_type) - self.cpp_iterator_cname = code.funcstate.allocate_temp(sequence_type, manage_ref=False) - code.putln("%s = &%s;" % (self.cpp_iterator_cname, self.sequence.result())) - begin_func = "%s->begin" % self.cpp_iterator_cname - # TODO: 
Limit scope. - code.putln("%s = %s();" % (self.result(), begin_func)) - return + assert False, "Should have been changed to CppIteratorNode" if sequence_type.is_array or sequence_type.is_ptr: raise InternalError("for in carray slice not transformed") @@ -2855,21 +2799,7 @@ class IteratorNode(ExprNode): sequence_type = self.sequence.type if self.reversed: code.putln("if (%s < 0) break;" % self.counter_cname) - if sequence_type.is_cpp_class: - if self.cpp_iterator_cname: - end_func = "%s->end" % self.cpp_iterator_cname - else: - end_func = "%s.end" % self.sequence.result() - # TODO: Cache end() call? - code.putln("if (!(%s != %s())) break;" % ( - self.result(), - end_func)) - code.putln("%s = *%s;" % ( - result_name, - self.result())) - code.putln("++%s;" % self.result()) - return - elif sequence_type is list_type: + if sequence_type is list_type: self.generate_next_sequence_item('List', result_name, code) return elif sequence_type is tuple_type: @@ -2908,8 +2838,109 @@ class IteratorNode(ExprNode): if self.iter_func_ptr: code.funcstate.release_temp(self.iter_func_ptr) self.iter_func_ptr = None - if self.cpp_iterator_cname: - code.funcstate.release_temp(self.cpp_iterator_cname) + ExprNode.free_temps(self, code) + + +class CppIteratorNode(ExprNode): + # Iteration over a C++ container. + # Created at the analyse_types stage by IteratorNode + cpp_sequence_cname = None + cpp_attribute_op = "." 
+ is_temp = True + + subexprs = ['sequence'] + + def analyse_types(self, env): + sequence_type = self.sequence.type + if sequence_type.is_ptr: + sequence_type = sequence_type.base_type + begin = sequence_type.scope.lookup("begin") + end = sequence_type.scope.lookup("end") + if (begin is None + or not begin.type.is_cfunction + or begin.type.args): + error(self.pos, "missing begin() on %s" % self.sequence.type) + self.type = error_type + return self + if (end is None + or not end.type.is_cfunction + or end.type.args): + error(self.pos, "missing end() on %s" % self.sequence.type) + self.type = error_type + return self + iter_type = begin.type.return_type + if iter_type.is_cpp_class: + if env.lookup_operator_for_types( + self.pos, + "!=", + [iter_type, end.type.return_type]) is None: + error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type) + self.type = error_type + return self + if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None: + error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type) + self.type = error_type + return self + if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None: + error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type) + self.type = error_type + return self + self.type = iter_type + elif iter_type.is_ptr: + if not (iter_type == end.type.return_type): + error(self.pos, "incompatible types for begin() and end()") + self.type = iter_type + else: + error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type) + self.type = error_type + return self + + def generate_result_code(self, code): + sequence_type = self.sequence.type + # essentially 3 options: + if self.sequence.is_name or self.sequence.is_attribute: + # 1) is a name and can be accessed directly; + # assigning to it may break the container, but that's the responsibility + # of the user + code.putln("%s = %s%sbegin();" % (self.result(), 
+ self.sequence.result(), + self.cpp_attribute_op)) + else: + # (while it'd be nice to limit the scope of the loop temp, it's essentially + # impossible to do while supporting generators) + temp_type = sequence_type + if temp_type.is_reference: + # 2) Sequence is a reference (often obtained by dereferencing a pointer); + # make the temp a pointer so we are not sensitive to users reassigning + # the pointer that it came from + temp_type = PyrexTypes.CPtrType(sequence_type.ref_base_type) + if temp_type.is_ptr: + self.cpp_attribute_op = "->" + # 3) (otherwise) sequence comes from a function call or similar, so we must + # create a temp to store it in + self.cpp_sequence_cname = code.funcstate.allocate_temp(temp_type, manage_ref=False) + code.putln("%s = %s%s;" % (self.cpp_sequence_cname, + "&" if temp_type.is_ptr else "", + self.sequence.move_result_rhs())) + code.putln("%s = %s%sbegin();" % (self.result(), self.cpp_sequence_cname, + self.cpp_attribute_op)) + + def generate_iter_next_result_code(self, result_name, code): + # end call isn't cached to support containers that allow adding while iterating + # (much as this is usually a bad idea) + code.putln("if (!(%s != %s%send())) break;" % ( + self.result(), + self.cpp_sequence_cname or self.sequence.result(), + self.cpp_attribute_op)) + code.putln("%s = *%s;" % ( + result_name, + self.result())) + code.putln("++%s;" % self.result()) + + def free_temps(self, code): + if self.cpp_sequence_cname: + code.funcstate.release_temp(self.cpp_sequence_cname) + # skip over IteratorNode since we don't use any of the temps it does ExprNode.free_temps(self, code) @@ -3793,6 +3824,8 @@ class IndexNode(_IndexingBaseNode): def analyse_as_c_array(self, env, is_slice): base_type = self.base.type self.type = base_type.base_type + if self.type.is_cpp_class: + self.type = PyrexTypes.CReferenceType(self.type) if is_slice: self.type = base_type elif self.index.type.is_pyobject: @@ -10313,7 +10346,7 @@ class DereferenceNode(CUnopNode): def 
analyse_c_operation(self, env): if self.operand.type.is_ptr: - self.type = self.operand.type.base_type + self.type = PyrexTypes.CReferenceType(self.operand.type.base_type) else: self.type_error() diff --git a/tests/run/cpp_iterators.pyx b/tests/run/cpp_iterators.pyx index 04cdd6777..36782710f 100644 --- a/tests/run/cpp_iterators.pyx +++ b/tests/run/cpp_iterators.pyx @@ -140,3 +140,40 @@ def test_iteration_in_generator_reassigned(): if vint is not orig_vint: del vint del orig_vint + +cdef extern from *: + """ + std::vector<int> make_vec1() { + std::vector<int> vint; + vint.push_back(1); + vint.push_back(2); + return vint; + } + """ + cdef vector[int] make_vec1() except + + +cdef vector[int] make_vec2() except *: + return make_vec1() + +cdef vector[int] make_vec3(): + try: + return make_vec1() + except: + pass + +def test_iteration_from_function_call(): + """ + >>> test_iteration_from_function_call() + 1 + 2 + 1 + 2 + 1 + 2 + """ + for i in make_vec1(): + print(i) + for i in make_vec2(): + print(i) + for i in make_vec3(): + print(i) -- cgit v1.2.1