diff options
author | da-woods <dw-git@d-woods.co.uk> | 2020-06-30 10:23:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-30 11:23:21 +0200 |
commit | ab1d7284f79794d075e79c87df671a87ae8b2b4b (patch) | |
tree | 6056b993f5b9ea4138383e0a0c38b5c58c427d37 | |
parent | 8b228a71038eb56e32fb5e27efae91520a9ba05f (diff) | |
download | cython-ab1d7284f79794d075e79c87df671a87ae8b2b4b.tar.gz |
Implement generic optimized loop iterator with indexing and type inference for memoryviews (GH-3617)
* Adds optimized bytearray iteration, which was not previously optimized because a bytearray can change length during iteration.
* Always set `entry.init` for entries inferred as memoryview slices, since memoryview slices crash if they are not initialized.
-rw-r--r-- | Cython/Compiler/ExprNodes.py | 22 | ||||
-rw-r--r-- | Cython/Compiler/Optimize.py | 92 | ||||
-rw-r--r-- | Cython/Compiler/Options.py | 10 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 7 | ||||
-rw-r--r-- | Cython/Compiler/PyrexTypes.py | 5 | ||||
-rw-r--r-- | Cython/Compiler/TypeInference.py | 7 | ||||
-rw-r--r-- | Cython/Compiler/UtilNodes.py | 3 | ||||
-rw-r--r-- | tests/memoryview/memoryview.pyx | 89 | ||||
-rw-r--r-- | tests/memoryview/memslice.pyx | 2 | ||||
-rw-r--r-- | tests/memoryview/numpy_memoryview.pyx | 4 | ||||
-rw-r--r-- | tests/run/bytearray_iter.py | 90 |
11 files changed, 294 insertions, 37 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 9f825af09..f8579d48c 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3564,6 +3564,8 @@ class IndexNode(_IndexingBaseNode): bytearray_type, list_type, tuple_type): # slicing these returns the same type return base_type + elif base_type.is_memoryviewslice: + return base_type else: # TODO: Handle buffers (hopefully without too much redundancy). return py_object_type @@ -3606,6 +3608,23 @@ class IndexNode(_IndexingBaseNode): index += base_type.size if 0 <= index < base_type.size: return base_type.components[index] + elif base_type.is_memoryviewslice: + if base_type.ndim == 0: + pass # probably an error, but definitely don't know what to do - return pyobject for now + if base_type.ndim == 1: + return base_type.dtype + else: + return PyrexTypes.MemoryViewSliceType(base_type.dtype, base_type.axes[1:]) + + if self.index.is_sequence_constructor and base_type.is_memoryviewslice: + inferred_type = base_type + for a in self.index.args: + if not inferred_type.is_memoryviewslice: + break # something's gone wrong + inferred_type = IndexNode(self.pos, base=ExprNode(self.base.pos, type=inferred_type), + index=a).infer_type(env) + else: + return inferred_type if base_type.is_cpp_class: class FakeOperand: @@ -13466,6 +13485,9 @@ class CoerceToTempNode(CoercionNode): # The arg is always already analysed return self + def may_be_none(self): + return self.arg.may_be_none() + def coerce_to_boolean(self, env): self.arg = self.arg.coerce_to_boolean(env) if self.arg.is_simple(): diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index ab8ec9e2c..96b2076ef 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -228,6 +228,12 @@ class IterationTransform(Visitor.EnvTransform): return self._transform_bytes_iteration(node, iterable, reversed=reversed) if iterable.type is Builtin.unicode_type: return self._transform_unicode_iteration(node, 
iterable, reversed=reversed) + # in principle _transform_indexable_iteration would work on most of the above, and + # also tuple and list. However, it probably isn't quite as optimized + if iterable.type is Builtin.bytearray_type: + return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed) + if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice: + return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed) # the rest is based on function calls if not isinstance(iterable, ExprNodes.SimpleCallNode): @@ -333,6 +339,92 @@ class IterationTransform(Visitor.EnvTransform): PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) ]) + def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False): + """In principle can handle any iterable that Cython has a len() for and knows how to index""" + unpack_temp_node = UtilNodes.LetRefNode( + slice_node.as_none_safe_node("'NoneType' is not iterable"), + may_hold_none=False, is_temp=True + ) + + start_node = ExprNodes.IntNode( + node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) + def make_length_call(): + # helper function since we need to create this node for a couple of places + builtin_len = ExprNodes.NameNode(node.pos, name="len", + entry=Builtin.builtin_scope.lookup("len")) + return ExprNodes.SimpleCallNode(node.pos, + function=builtin_len, + args=[unpack_temp_node] + ) + length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True) + end_node = length_temp + + if reversed: + relation1, relation2 = '>', '>=' + start_node, end_node = end_node, start_node + else: + relation1, relation2 = '<=', '<' + + counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type) + + target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node, + index=counter_ref) + + target_assign = 
Nodes.SingleAssignmentNode( + pos = node.target.pos, + lhs = node.target, + rhs = target_value) + + # analyse with boundscheck and wraparound + # off (because we're confident we know the size) + env = self.current_env() + new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False) + target_assign = Nodes.CompilerDirectivesNode( + target_assign.pos, + directives=new_directives, + body=target_assign, + ) + + body = Nodes.StatListNode( + node.pos, + stats = [target_assign]) # exclude node.body for now to not reanalyse it + if is_mutable: + # We need to be slightly careful here that we are actually modifying the loop + # bounds and not a temp copy of it. Setting is_temp=True on length_temp seems + # to ensure this. + # If this starts to fail then we could insert an "if out_of_bounds: break" instead + loop_length_reassign = Nodes.SingleAssignmentNode(node.pos, + lhs = length_temp, + rhs = make_length_call()) + body.stats.append(loop_length_reassign) + + loop_node = Nodes.ForFromStatNode( + node.pos, + bound1=start_node, relation1=relation1, + target=counter_ref, + relation2=relation2, bound2=end_node, + step=None, body=body, + else_clause=node.else_clause, + from_range=True) + + ret = UtilNodes.LetNode( + unpack_temp_node, + UtilNodes.LetNode( + length_temp, + # TempResultFromStatNode provides the framework where the "counter_ref" + # temp is set up and can be assigned to. However, we don't need the + # result it returns so wrap it in an ExprStatNode. 
+ Nodes.ExprStatNode(node.pos, + expr=UtilNodes.TempResultFromStatNode( + counter_ref, + loop_node + ) + ) + ) + ).analyse_expressions(env) + body.stats.insert(1, node.body) + return ret + def _transform_bytes_iteration(self, node, slice_node, reversed=False): target_type = node.target.type if not target_type.is_int and target_type is not Builtin.bytes_type: diff --git a/Cython/Compiler/Options.py b/Cython/Compiler/Options.py index bfc9d1f80..f20239321 100644 --- a/Cython/Compiler/Options.py +++ b/Cython/Compiler/Options.py @@ -166,6 +166,16 @@ def get_directive_defaults(): _directive_defaults[old_option.directive_name] = value return _directive_defaults +def copy_inherited_directives(outer_directives, **new_directives): + # A few directives are not copied downwards and this function removes them. + # For example, test_assert_path_exists and test_fail_if_path_exists should not be inherited + # otherwise they can produce very misleading test failures + new_directives_out = dict(outer_directives) + for name in ('test_assert_path_exists', 'test_fail_if_path_exists'): + new_directives_out.pop(name, None) + new_directives_out.update(new_directives) + return new_directives_out + # Declare compiler directives _directive_defaults = { 'binding': True, # was False before 3.0 diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 7ccee4fb7..c31fa9f65 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -992,12 +992,7 @@ class InterpretCompilerDirectives(CythonTransform): return self.visit_Node(node) old_directives = self.directives - new_directives = dict(old_directives) - # test_assert_path_exists and test_fail_if_path_exists should not be inherited - # otherwise they can produce very misleading test failures - new_directives.pop('test_assert_path_exists', None) - new_directives.pop('test_fail_if_path_exists', None) - new_directives.update(directives) + new_directives = 
Options.copy_inherited_directives(old_directives, **directives) if new_directives == old_directives: return self.visit_Node(node) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index dc63675c3..3a7a8da5a 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -672,6 +672,10 @@ class MemoryViewSliceType(PyrexType): else: return False + def __ne__(self, other): + # TODO drop when Python2 is dropped + return not (self == other) + def same_as_resolved_type(self, other_type): return ((other_type.is_memoryviewslice and #self.writable_needed == other_type.writable_needed and # FIXME: should be only uni-directional @@ -2516,6 +2520,7 @@ class CPointerBaseType(CType): if self.is_string: assert isinstance(value, str) return '"%s"' % StringEncoding.escape_byte_string(value) + return str(value) class CArrayType(CPointerBaseType): diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index 461829a85..1e46050d1 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -140,7 +140,6 @@ class MarkParallelAssignments(EnvTransform): '+', sequence.args[0], sequence.args[2])) - if not is_special: # A for-loop basically translates to subsequent calls to # __getitem__(), so using an IndexNode here allows us to @@ -360,9 +359,11 @@ class SimpleAssignmentTypeInferer(object): applies to nested scopes in top-down order. """ def set_entry_type(self, entry, entry_type): - entry.type = entry_type for e in entry.all_entries(): e.type = entry_type + if e.type.is_memoryviewslice: + # memoryview slices crash if they don't get initialized + e.init = e.type.default_value def infer_types(self, scope): enabled = scope.directives['infer_types'] @@ -577,6 +578,8 @@ def safe_spanning_type(types, might_overflow, pos, scope): # used, won't arise in pure Python, and there shouldn't be side # effects, so I'm declaring this safe. 
return result_type + elif result_type.is_memoryviewslice: + return result_type # TODO: double complex should be OK as well, but we need # to make sure everything is supported. elif (result_type.is_int or result_type.is_enum) and not might_overflow: diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index 156b1f8af..0d5db9d33 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -360,3 +360,6 @@ class TempResultFromStatNode(ExprNodes.ExprNode): def generate_result_code(self, code): self.result_ref.result_code = self.result() self.body.generate_execution_code(code) + + def generate_function_definitions(self, env, code): + self.body.generate_function_definitions(env, code) diff --git a/tests/memoryview/memoryview.pyx b/tests/memoryview/memoryview.pyx index 25cc0916d..70a3a0412 100644 --- a/tests/memoryview/memoryview.pyx +++ b/tests/memoryview/memoryview.pyx @@ -247,7 +247,7 @@ def basic_struct(MyStruct[:] mslice): >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="ccqii")) [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] """ - buf = mslice + cdef object buf = mslice print sorted([(k, int(v)) for k, v in buf[0].items()]) def nested_struct(NestedStruct[:] mslice): @@ -259,7 +259,7 @@ def nested_struct(NestedStruct[:] mslice): >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i")) 1 2 3 4 5 """ - buf = mslice + cdef object buf = mslice d = buf[0] print d['x']['a'], d['x']['b'], d['y']['a'], d['y']['b'], d['z'] @@ -275,7 +275,7 @@ def packed_struct(PackedStruct[:] mslice): 1 2 """ - buf = mslice + cdef object buf = mslice print buf[0]['a'], buf[0]['b'] def nested_packed_struct(NestedPackedStruct[:] mslice): @@ -289,7 +289,7 @@ def nested_packed_struct(NestedPackedStruct[:] mslice): >>> nested_packed_struct(NestedPackedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="^c@i^ci@i")) 1 2 3 4 5 """ - buf = mslice + cdef object buf = mslice d = buf[0] print d['a'], d['b'], 
d['sub']['a'], d['sub']['b'], d['c'] @@ -299,7 +299,7 @@ def complex_dtype(long double complex[:] mslice): >>> complex_dtype(LongComplexMockBuffer(None, [(0, -1)])) -1j """ - buf = mslice + cdef object buf = mslice print buf[0] def complex_inplace(long double complex[:] mslice): @@ -307,7 +307,7 @@ def complex_inplace(long double complex[:] mslice): >>> complex_inplace(LongComplexMockBuffer(None, [(0, -1)])) (1+1j) """ - buf = mslice + cdef object buf = mslice buf[0] = buf[0] + 1 + 2j print buf[0] @@ -318,7 +318,7 @@ def complex_struct_dtype(LongComplex[:] mslice): >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)])) 0.0 -1.0 """ - buf = mslice + cdef object buf = mslice print buf[0]['real'], buf[0]['imag'] # @@ -356,7 +356,7 @@ def get_int_2d(int[:, :] mslice, int i, int j): ... IndexError: Out of bounds on buffer access (axis 1) """ - buf = mslice + cdef object buf = mslice return buf[i, j] def set_int_2d(int[:, :] mslice, int i, int j, int value): @@ -409,11 +409,48 @@ def set_int_2d(int[:, :] mslice, int i, int j, int value): IndexError: Out of bounds on buffer access (axis 1) """ - buf = mslice + cdef object buf = mslice buf[i, j] = value # +# auto type inference +# (note that for most numeric types "might_overflow" stops the type inference from working well) +# +def type_infer(double[:, :] arg): + """ + >>> type_infer(DoubleMockBuffer(None, range(6), (2,3))) + double + double[:] + double[:] + double[:, :] + """ + a = arg[0,0] + print(cython.typeof(a)) + b = arg[0] + print(cython.typeof(b)) + c = arg[0,:] + print(cython.typeof(c)) + d = arg[:,:] + print(cython.typeof(d)) + +# +# Loop optimization +# +@cython.test_fail_if_path_exists("//CoerceToPyTypeNode") +def memview_iter(double[:, :] arg): + """ + memview_iter(DoubleMockBuffer("C", range(6), (2,3))) + True + """ + cdef double total = 0 + for mview1d in arg: + for val in mview1d: + total += val + if total == 15: + return True + +# # Test all kinds of indexing and flags # @@ -426,7 +463,7 @@ def 
writable(unsigned short int[:, :, :] mslice): >>> [str(x) for x in R.received_flags] # Py2/3 ['FORMAT', 'ND', 'STRIDES', 'WRITABLE'] """ - buf = mslice + cdef object buf = mslice buf[2, 2, 1] = 23 def strided(int[:] mslice): @@ -441,7 +478,7 @@ def strided(int[:] mslice): >>> A.release_ok True """ - buf = mslice + cdef object buf = mslice return buf[2] def c_contig(int[::1] mslice): @@ -450,7 +487,7 @@ def c_contig(int[::1] mslice): >>> c_contig(A) 2 """ - buf = mslice + cdef object buf = mslice return buf[2] def c_contig_2d(int[:, ::1] mslice): @@ -461,7 +498,7 @@ def c_contig_2d(int[:, ::1] mslice): >>> c_contig_2d(A) 7 """ - buf = mslice + cdef object buf = mslice return buf[1, 3] def f_contig(int[::1, :] mslice): @@ -470,7 +507,7 @@ def f_contig(int[::1, :] mslice): >>> f_contig(A) 2 """ - buf = mslice + cdef object buf = mslice return buf[0, 1] def f_contig_2d(int[::1, :] mslice): @@ -481,7 +518,7 @@ def f_contig_2d(int[::1, :] mslice): >>> f_contig_2d(A) 7 """ - buf = mslice + cdef object buf = mslice return buf[3, 1] def generic(int[::view.generic, ::view.generic] mslice1, @@ -552,7 +589,7 @@ def printbuf_td_cy_int(td_cy_int[:] mslice, shape): ... ValueError: Buffer dtype mismatch, expected 'td_cy_int' but got 'short' """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print buf[i], @@ -567,7 +604,7 @@ def printbuf_td_h_short(td_h_short[:] mslice, shape): ... ValueError: Buffer dtype mismatch, expected 'td_h_short' but got 'int' """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print buf[i], @@ -582,7 +619,7 @@ def printbuf_td_h_cy_short(td_h_cy_short[:] mslice, shape): ... ValueError: Buffer dtype mismatch, expected 'td_h_cy_short' but got 'int' """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print buf[i], @@ -597,7 +634,7 @@ def printbuf_td_h_ushort(td_h_ushort[:] mslice, shape): ... 
ValueError: Buffer dtype mismatch, expected 'td_h_ushort' but got 'short' """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print buf[i], @@ -612,7 +649,7 @@ def printbuf_td_h_double(td_h_double[:] mslice, shape): ... ValueError: Buffer dtype mismatch, expected 'td_h_double' but got 'float' """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print buf[i], @@ -649,7 +686,7 @@ def printbuf_object(object[:] mslice, shape): {4: 23} 2 [34, 3] 2 """ - buf = mslice + cdef object buf = mslice cdef int i for i in range(shape[0]): print repr(buf[i]), (<PyObject*>buf[i]).ob_refcnt @@ -670,7 +707,7 @@ def assign_to_object(object[:] mslice, int idx, obj): (2, 3) >>> decref(b) """ - buf = mslice + cdef object buf = mslice buf[idx] = obj def assign_temporary_to_object(object[:] mslice): @@ -697,7 +734,7 @@ def assign_temporary_to_object(object[:] mslice): >>> assign_to_object(A, 1, a) >>> decref(a) """ - buf = mslice + cdef object buf = mslice buf[1] = {3-2: 2+(2*4)-2} @@ -745,7 +782,7 @@ def test_generic_slicing(arg, indirect=False): """ cdef int[::view.generic, ::view.generic, :] _a = arg - a = _a + cdef object a = _a b = a[2:8:2, -4:1:-1, 1:3] print b.shape @@ -828,7 +865,7 @@ def test_direct_slicing(arg): released A """ cdef int[:, :, :] _a = arg - a = _a + cdef object a = _a b = a[2:8:2, -4:1:-1, 1:3] print b.shape @@ -856,7 +893,7 @@ def test_slicing_and_indexing(arg): released A """ cdef int[:, :, :] _a = arg - a = _a + cdef object a = _a b = a[-5:, 1, 1::2] c = b[4:1:-1, ::-1] d = c[2, 1:2] diff --git a/tests/memoryview/memslice.pyx b/tests/memoryview/memslice.pyx index 26dd802ef..d566d9a1e 100644 --- a/tests/memoryview/memslice.pyx +++ b/tests/memoryview/memslice.pyx @@ -1525,7 +1525,7 @@ def test_index_slicing_away_direct_indirect(): All dimensions preceding dimension 1 must be indexed and not sliced """ cdef int[:, ::view.indirect, :] a = TestIndexSlicingDirectIndirectDims() - a_obj = a + cdef object a_obj 
= a print a[1][2][3] print a[1, 2, 3] diff --git a/tests/memoryview/numpy_memoryview.pyx b/tests/memoryview/numpy_memoryview.pyx index 9b18be615..0d6f9e12f 100644 --- a/tests/memoryview/numpy_memoryview.pyx +++ b/tests/memoryview/numpy_memoryview.pyx @@ -186,7 +186,7 @@ def test_transpose(): numpy_obj = np.arange(4 * 3, dtype=np.int32).reshape(4, 3) a = numpy_obj - a_obj = a + cdef object a_obj = a cdef dtype_t[:, :] b = a.T print a.T.shape[0], a.T.shape[1] @@ -244,7 +244,7 @@ def test_copy_and_contig_attributes(a): >>> test_copy_and_contig_attributes(a) """ cdef np.int32_t[:, :] mslice = a - m = mslice + cdef object m = mslice # object copy # Test object copy attributes assert np.all(a == np.array(m.copy())) diff --git a/tests/run/bytearray_iter.py b/tests/run/bytearray_iter.py new file mode 100644 index 000000000..1865f057b --- /dev/null +++ b/tests/run/bytearray_iter.py @@ -0,0 +1,90 @@ +# mode: run +# tag: pure3, pure2 + +import cython + +@cython.test_assert_path_exists("//ForFromStatNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +@cython.locals(x=bytearray) +def basic_bytearray_iter(x): + """ + >>> basic_bytearray_iter(bytearray(b"hello")) + h + e + l + l + o + """ + for a in x: + print(chr(a)) + +@cython.test_assert_path_exists("//ForFromStatNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +@cython.locals(x=bytearray) +def reversed_bytearray_iter(x): + """ + >>> reversed_bytearray_iter(bytearray(b"hello")) + o + l + l + e + h + """ + for a in reversed(x): + print(chr(a)) + +@cython.test_assert_path_exists("//ForFromStatNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +@cython.locals(x=bytearray) +def modifying_bytearray_iter1(x): + """ + >>> modifying_bytearray_iter1(bytearray(b"abcdef")) + a + b + c + 3 + """ + count = 0 + for a in x: + print(chr(a)) + del x[-1] + count += 1 + print(count) + +@cython.test_assert_path_exists("//ForFromStatNode") +@cython.test_fail_if_path_exists("//ForInStatNode") 
+@cython.locals(x=bytearray) +def modifying_bytearray_iter2(x): + """ + >>> modifying_bytearray_iter2(bytearray(b"abcdef")) + a + c + e + 3 + """ + count = 0 + for a in x: + print(chr(a)) + del x[0] + count += 1 + print(count) + +@cython.test_assert_path_exists("//ForFromStatNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +@cython.locals(x=bytearray) +def modifying_reversed_bytearray_iter(x): + """ + NOTE - I'm not 100% sure how well-defined this behaviour is in Python. + However, for the moment Python and Cython seem to do the same thing. + Testing that it doesn't crash is probably more important than the exact output! + >>> modifying_reversed_bytearray_iter(bytearray(b"abcdef")) + f + f + f + f + f + f + """ + for a in reversed(x): + print(chr(a)) + del x[0] |