diff options
Diffstat (limited to 'Cython/Compiler/Optimize.py')
-rw-r--r-- | Cython/Compiler/Optimize.py | 585 |
1 files changed, 469 insertions, 116 deletions
diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 7e9435ba0..fb6dc5dae 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -41,7 +41,7 @@ except ImportError: try: from __builtin__ import basestring except ImportError: - basestring = str # Python 3 + basestring = str # Python 3 def load_c_utility(name): @@ -192,19 +192,9 @@ class IterationTransform(Visitor.EnvTransform): def _optimise_for_loop(self, node, iterable, reversed=False): annotation_type = None if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation: - annotation = iterable.entry.annotation + annotation = iterable.entry.annotation.expr if annotation.is_subscript: annotation = annotation.base # container base type - # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names? - if annotation.is_name: - if annotation.entry and annotation.entry.qualified_name == 'typing.Dict': - annotation_type = Builtin.dict_type - elif annotation.name == 'Dict': - annotation_type = Builtin.dict_type - if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'): - annotation_type = Builtin.set_type - elif annotation.name in ('Set', 'FrozenSet'): - annotation_type = Builtin.set_type if Builtin.dict_type in (iterable.type, annotation_type): # like iterating over dict.keys() @@ -228,6 +218,12 @@ class IterationTransform(Visitor.EnvTransform): return self._transform_bytes_iteration(node, iterable, reversed=reversed) if iterable.type is Builtin.unicode_type: return self._transform_unicode_iteration(node, iterable, reversed=reversed) + # in principle _transform_indexable_iteration would work on most of the above, and + # also tuple and list. However, it probably isn't quite as optimized + if iterable.type is Builtin.bytearray_type: + return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed) + if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice: + return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed) # the rest is based on function calls if not isinstance(iterable, ExprNodes.SimpleCallNode): @@ -323,6 +319,92 @@ class IterationTransform(Visitor.EnvTransform): return self._optimise_for_loop(node, arg, reversed=True) + def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False): + """In principle can handle any iterable that Cython has a len() for and knows how to index""" + unpack_temp_node = UtilNodes.LetRefNode( + slice_node.as_none_safe_node("'NoneType' is not iterable"), + may_hold_none=False, is_temp=True + ) + + start_node = ExprNodes.IntNode( + node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) + def make_length_call(): + # helper function since we need to create this node for a couple of places + builtin_len = ExprNodes.NameNode(node.pos, name="len", + entry=Builtin.builtin_scope.lookup("len")) + return ExprNodes.SimpleCallNode(node.pos, + function=builtin_len, + args=[unpack_temp_node] + ) + length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True) + end_node = length_temp + + if reversed: + relation1, relation2 = '>', '>=' + start_node, end_node = end_node, start_node + else: + relation1, relation2 = '<=', '<' + + counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type) + + target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node, + index=counter_ref) + + target_assign = Nodes.SingleAssignmentNode( + pos = node.target.pos, + lhs = node.target, + rhs = target_value) + + # analyse with boundscheck and wraparound + # off (because we're confident we know the size) + env = self.current_env() + new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False) + target_assign = Nodes.CompilerDirectivesNode( + target_assign.pos, + directives=new_directives, + body=target_assign, + ) + + body = Nodes.StatListNode( + node.pos, + stats = [target_assign]) # exclude node.body for now to not reanalyse it + if is_mutable: + # We need to be slightly careful here that we are actually modifying the loop + # bounds and not a temp copy of it. Setting is_temp=True on length_temp seems + # to ensure this. + # If this starts to fail then we could insert an "if out_of_bounds: break" instead + loop_length_reassign = Nodes.SingleAssignmentNode(node.pos, + lhs = length_temp, + rhs = make_length_call()) + body.stats.append(loop_length_reassign) + + loop_node = Nodes.ForFromStatNode( + node.pos, + bound1=start_node, relation1=relation1, + target=counter_ref, + relation2=relation2, bound2=end_node, + step=None, body=body, + else_clause=node.else_clause, + from_range=True) + + ret = UtilNodes.LetNode( + unpack_temp_node, + UtilNodes.LetNode( + length_temp, + # TempResultFromStatNode provides the framework where the "counter_ref" + # temp is set up and can be assigned to. However, we don't need the + # result it returns so wrap it in an ExprStatNode. + Nodes.ExprStatNode(node.pos, + expr=UtilNodes.TempResultFromStatNode( + counter_ref, + loop_node + ) + ) + ) + ).analyse_expressions(env) + body.stats.insert(1, node.body) + return ret + PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( PyrexTypes.c_char_ptr_type, [ PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) @@ -1144,7 +1226,7 @@ class SwitchTransform(Visitor.EnvTransform): # integers on iteration, whereas Py2 returns 1-char byte # strings characters = string_literal.value - characters = list(set([ characters[i:i+1] for i in range(len(characters)) ])) + characters = list({ characters[i:i+1] for i in range(len(characters)) }) characters.sort() return [ ExprNodes.CharNode(string_literal.pos, value=charval, constant_result=charval) @@ -1156,7 +1238,8 @@ class SwitchTransform(Visitor.EnvTransform): return self.NO_MATCH elif common_var is not None and not is_common_value(var, common_var): return self.NO_MATCH - elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]): + elif not (var.type.is_int or var.type.is_enum) or any( + [not (cond.type.is_int or cond.type.is_enum) for cond in conditions]): return self.NO_MATCH return not_in, var, conditions @@ -1573,7 +1656,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): utility_code = utility_code) def _error_wrong_arg_count(self, function_name, node, args, expected=None): - if not expected: # None or 0 + if not expected: # None or 0 arg_str = '' elif isinstance(expected, basestring) or expected > 1: arg_str = '...' @@ -1727,7 +1810,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): arg = pos_args[0] if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type: - list_node = pos_args[0] + list_node = arg loop_node = list_node.loop elif isinstance(arg, ExprNodes.GeneratorExpressionNode): @@ -1757,7 +1840,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): # Interestingly, PySequence_List works on a lot of non-sequence # things as well. list_node = loop_node = ExprNodes.PythonCapiCallNode( - node.pos, "PySequence_List", self.PySequence_List_func_type, + node.pos, + "__Pyx_PySequence_ListKeepNew" + if arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type) + else "PySequence_List", + self.PySequence_List_func_type, args=pos_args, is_temp=True) result_node = UtilNodes.ResultRefNode( @@ -1803,7 +1890,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if not yield_expression.is_literal or not yield_expression.type.is_int: return node except AttributeError: - return node # in case we don't have a type yet + return node # in case we don't have a type yet # special case: old Py2 backwards compatible "sum([int_const for ...])" # can safely be unpacked into a genexpr @@ -2018,7 +2105,8 @@ class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform): return node inlined = ExprNodes.InlinedDefNodeCallNode( node.pos, function_name=function_name, - function=function, args=node.args) + function=function, args=node.args, + generator_arg_tag=node.generator_arg_tag) if inlined.can_be_inlined(): return self.replace(node, inlined) return node @@ -2097,12 +2185,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, func_arg = arg.args[0] if func_arg.type is Builtin.float_type: return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'") - elif func_arg.type.is_pyobject: + elif func_arg.type.is_pyobject and arg.function.cname == "__Pyx_PyObject_AsDouble": return ExprNodes.PythonCapiCallNode( node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type, args=[func_arg], py_name='float', is_temp=node.is_temp, + utility_code = UtilityCode.load_cached("pynumber_float", "TypeConversion.c"), result_is_used=node.result_is_used, ).coerce_to(node.type, self.current_env()) return node @@ -2210,6 +2299,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if func_arg.type.is_int or node.type.is_int: if func_arg.type == node.type: return func_arg + elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type): + # need to parse (<Py_UCS4>'1') as digit 1 + return self._pyucs4_to_number(node, function.name, func_arg) elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type) elif func_arg.type.is_float and node.type.is_numeric: @@ -2230,13 +2322,40 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if func_arg.type.is_float or node.type.is_float: if func_arg.type == node.type: return func_arg + elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type): + # need to parse (<Py_UCS4>'1') as digit 1 + return self._pyucs4_to_number(node, function.name, func_arg) elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: return ExprNodes.TypecastNode( node.pos, operand=func_arg, type=node.type) return node + pyucs4_int_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_int_type, [ + PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None) + ], + exception_value="-1") + + pyucs4_double_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_double_type, [ + PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None) + ], + exception_value="-1.0") + + def _pyucs4_to_number(self, node, py_type_name, func_arg): + assert py_type_name in ("int", "float") + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_int_from_UCS4" if py_type_name == "int" else "__Pyx_double_from_UCS4", + func_type=self.pyucs4_int_func_type if py_type_name == "int" else self.pyucs4_double_func_type, + args=[func_arg], + py_name=py_type_name, + is_temp=node.is_temp, + result_is_used=node.result_is_used, + utility_code=UtilityCode.load_cached("int_pyucs4" if py_type_name == "int" else "float_pyucs4", "Builtins.c"), + ).coerce_to(node.type, self.current_env()) + def _error_wrong_arg_count(self, function_name, node, args, expected=None): - if not expected: # None or 0 + if not expected: # None or 0 arg_str = '' elif isinstance(expected, basestring) or expected > 1: arg_str = '...' @@ -2316,6 +2435,38 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return ExprNodes.CachedBuiltinMethodCallNode( node, function.obj, attr_name, arg_list) + PyObject_String_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ # Change this to Builtin.str_type when removing Py2 support. + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) + ]) + + def _handle_simple_function_str(self, node, function, pos_args): + """Optimize single argument calls to str(). + """ + if len(pos_args) != 1: + if len(pos_args) == 0: + return ExprNodes.StringNode(node.pos, value=EncodedString(), constant_result='') + return node + arg = pos_args[0] + + if arg.type is Builtin.str_type: + if not arg.may_be_none(): + return arg + + cname = "__Pyx_PyStr_Str" + utility_code = UtilityCode.load_cached('PyStr_Str', 'StringTools.c') + else: + cname = '__Pyx_PyObject_Str' + utility_code = UtilityCode.load_cached('PyObject_Str', 'StringTools.c') + + return ExprNodes.PythonCapiCallNode( + node.pos, cname, self.PyObject_String_func_type, + args=pos_args, + is_temp=node.is_temp, + utility_code=utility_code, + py_name="str" + ) + PyObject_Unicode_func_type = PyrexTypes.CFuncType( Builtin.unicode_type, [ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) @@ -2387,8 +2538,14 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return node arg = pos_args[0] return ExprNodes.PythonCapiCallNode( - node.pos, "PySequence_List", self.PySequence_List_func_type, - args=pos_args, is_temp=node.is_temp) + node.pos, + "__Pyx_PySequence_ListKeepNew" + if node.is_temp and arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type) + else "PySequence_List", + self.PySequence_List_func_type, + args=pos_args, + is_temp=node.is_temp, + ) PyList_AsTuple_func_type = PyrexTypes.CFuncType( Builtin.tuple_type, [ @@ -2489,20 +2646,49 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, elif len(pos_args) != 1: self._error_wrong_arg_count('float', node, pos_args, '0 or 1') return node + func_arg = pos_args[0] if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): func_arg = func_arg.arg if func_arg.type is PyrexTypes.c_double_type: return func_arg + elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type): + # need to parse (<Py_UCS4>'1') as digit 1 + return self._pyucs4_to_number(node, function.name, func_arg) elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric: return ExprNodes.TypecastNode( node.pos, operand=func_arg, type=node.type) + + arg = None + if func_arg.type is Builtin.bytes_type: + cfunc_name = "__Pyx_PyBytes_AsDouble" + utility_code_name = 'pybytes_as_double' + elif func_arg.type is Builtin.bytearray_type: + cfunc_name = "__Pyx_PyByteArray_AsDouble" + utility_code_name = 'pybytes_as_double' + elif func_arg.type is Builtin.unicode_type: + cfunc_name = "__Pyx_PyUnicode_AsDouble" + utility_code_name = 'pyunicode_as_double' + elif func_arg.type is Builtin.str_type: + cfunc_name = "__Pyx_PyString_AsDouble" + utility_code_name = 'pystring_as_double' + elif func_arg.type is Builtin.long_type: + cfunc_name = "PyLong_AsDouble" + else: + arg = func_arg # no need for an additional None check + cfunc_name = "__Pyx_PyObject_AsDouble" + utility_code_name = 'pyobject_as_double' + + if arg is None: + arg = func_arg.as_none_safe_node( + "float() argument must be a string or a number, not 'NoneType'") + return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyObject_AsDouble", + node.pos, cfunc_name, self.PyObject_AsDouble_func_type, - args = pos_args, + args = [arg], is_temp = node.is_temp, - utility_code = load_c_utility('pyobject_as_double'), + utility_code = load_c_utility(utility_code_name) if utility_code_name else None, py_name = "float") PyNumber_Int_func_type = PyrexTypes.CFuncType( @@ -2556,17 +2742,59 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, # coerce back to Python object as that's the result we are expecting return operand.coerce_to_pyobject(self.current_env()) + PyMemoryView_FromObject_func_type = PyrexTypes.CFuncType( + Builtin.memoryview_type, [ + PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None) + ]) + + PyMemoryView_FromBuffer_func_type = PyrexTypes.CFuncType( + Builtin.memoryview_type, [ + PyrexTypes.CFuncTypeArg("value", Builtin.py_buffer_type, None) + ]) + + def _handle_simple_function_memoryview(self, node, function, pos_args): + if len(pos_args) != 1: + self._error_wrong_arg_count('memoryview', node, pos_args, '1') + return node + else: + if pos_args[0].type.is_pyobject: + return ExprNodes.PythonCapiCallNode( + node.pos, "PyMemoryView_FromObject", + self.PyMemoryView_FromObject_func_type, + args = [pos_args[0]], + is_temp = node.is_temp, + py_name = "memoryview") + elif pos_args[0].type.is_ptr and pos_args[0].base_type is Builtin.py_buffer_type: + # TODO - this currently doesn't work because the buffer fails a + # "can coerce to python object" test earlier. But it'd be nice to support + return ExprNodes.PythonCapiCallNode( + node.pos, "PyMemoryView_FromBuffer", + self.PyMemoryView_FromBuffer_func_type, + args = [pos_args[0]], + is_temp = node.is_temp, + py_name = "memoryview") + return node + + ### builtin functions Pyx_strlen_func_type = PyrexTypes.CFuncType( PyrexTypes.c_size_t_type, [ PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None) - ]) + ], + nogil=True) + + Pyx_ssize_strlen_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None) + ], + exception_value="-1") Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_size_t_type, [ + PyrexTypes.c_py_ssize_t_type, [ PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None) - ]) + ], + exception_value="-1") PyObject_Size_func_type = PyrexTypes.CFuncType( PyrexTypes.c_py_ssize_t_type, [ @@ -2585,7 +2813,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, Builtin.dict_type: "PyDict_Size", }.get - _ext_types_with_pysize = set(["cpython.array.array"]) + _ext_types_with_pysize = {"cpython.array.array"} def _handle_simple_function_len(self, node, function, pos_args): """Replace len(char*) by the equivalent call to strlen(), @@ -2600,18 +2828,19 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, arg = arg.arg if arg.type.is_string: new_node = ExprNodes.PythonCapiCallNode( - node.pos, "strlen", self.Pyx_strlen_func_type, + node.pos, "__Pyx_ssize_strlen", self.Pyx_ssize_strlen_func_type, args = [arg], is_temp = node.is_temp, - utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c")) + utility_code = UtilityCode.load_cached("ssize_strlen", "StringTools.c")) elif arg.type.is_pyunicode_ptr: new_node = ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type, + node.pos, "__Pyx_Py_UNICODE_ssize_strlen", self.Pyx_Py_UNICODE_strlen_func_type, args = [arg], - is_temp = node.is_temp) + is_temp = node.is_temp, + utility_code = UtilityCode.load_cached("ssize_pyunicode_strlen", "StringTools.c")) elif arg.type.is_memoryviewslice: func_type = PyrexTypes.CFuncType( - PyrexTypes.c_size_t_type, [ + PyrexTypes.c_py_ssize_t_type, [ PyrexTypes.CFuncTypeArg("memoryviewslice", arg.type, None) ], nogil=True) new_node = ExprNodes.PythonCapiCallNode( @@ -2622,7 +2851,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if cfunc_name is None: arg_type = arg.type if ((arg_type.is_extension_type or arg_type.is_builtin_type) - and arg_type.entry.qualified_name in self._ext_types_with_pysize): + and arg_type.entry.qualified_name in self._ext_types_with_pysize): cfunc_name = 'Py_SIZE' else: return node @@ -2698,6 +2927,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, builtin_type = None if builtin_type is not None: type_check_function = entry.type.type_check_function(exact=False) + if type_check_function == '__Pyx_Py3Int_Check' and builtin_type is Builtin.int_type: + # isinstance(x, int) should really test for 'int' in Py2, not 'int | long' + type_check_function = "PyInt_Check" if type_check_function in tests: continue tests.append(type_check_function) @@ -2781,11 +3013,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return node type_arg = args[0] if not obj.is_name or not type_arg.is_name: - # play safe - return node + return node # not a simple case if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type: - # not a known type, play safe - return node + return node # not a known type if not type_arg.type_entry or not obj.type_entry: if obj.name != type_arg.name: return node @@ -2847,6 +3077,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, is_temp=node.is_temp ) + def _handle_any_slot__class__(self, node, function, args, + is_unbound_method, kwargs=None): + # The purpose of this function is to handle calls to instance.__class__() so that + # it doesn't get handled by the __Pyx_CallUnboundCMethod0 mechanism. + # TODO: optimizations of the instance.__class__() call might be possible in future. + return node + ### methods of builtin types PyObject_Append_func_type = PyrexTypes.CFuncType( @@ -3182,6 +3419,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method): return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) + def _handle_simple_method_object___mul__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Multiply', node, function, args, is_unbound_method) + def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method): return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) @@ -3261,6 +3501,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, """ Optimise math operators for (likely) float or small integer operations. """ + if getattr(node, "special_bool_cmp_function", None): + return node # already optimized + if len(args) != 2: return node @@ -3271,66 +3514,15 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, else: return node - # When adding IntNode/FloatNode to something else, assume other operand is also numeric. - # Prefer constants on RHS as they allows better size control for some operators. - num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode) - if isinstance(args[1], num_nodes): - if args[0].type is not PyrexTypes.py_object_type: - return node - numval = args[1] - arg_order = 'ObjC' - elif isinstance(args[0], num_nodes): - if args[1].type is not PyrexTypes.py_object_type: - return node - numval = args[0] - arg_order = 'CObj' - else: - return node - - if not numval.has_constant_result(): - return node - - is_float = isinstance(numval, ExprNodes.FloatNode) - num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type - if is_float: - if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'): - return node - elif operator == 'Divide': - # mixed old-/new-style division is not currently optimised for integers - return node - elif abs(numval.constant_result) > 2**30: - # Cut off at an integer border that is still safe for all operations. + result = optimise_numeric_binop(operator, node, ret_type, args[0], args[1]) + if not result: return node - - if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'): - if args[1].constant_result == 0: - # Don't optimise division by 0. :) - return node - - args = list(args) - args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)( - numval.pos, value=numval.value, constant_result=numval.constant_result, - type=num_type)) - inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False - args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) - if is_float or operator not in ('Eq', 'Ne'): - # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument. - zerodivision_check = arg_order == 'CObj' and ( - not node.cdivision if isinstance(node, ExprNodes.DivNode) else False) - args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check)) - - utility_code = TempitaUtilityCode.load_cached( - "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop", - "Optimize.c", - context=dict(op=operator, order=arg_order, ret_type=ret_type)) + func_cname, utility_code, extra_args, num_type = result + args = list(args)+extra_args call_node = self._substitute_method_call( node, function, - "__Pyx_Py%s_%s%s%s" % ( - 'Float' if is_float else 'Int', - '' if ret_type.is_pyobject else 'Bool', - operator, - arg_order), + func_cname, self.Pyx_BinopInt_func_types[(num_type, ret_type)], '__%s__' % operator[:3].lower(), is_unbound_method, args, may_return_none=True, @@ -3389,6 +3581,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), ]) + # DISABLED: Return value can only be one character, which is not correct. + ''' def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method): if is_unbound_method or len(args) != 1: return node @@ -3407,9 +3601,10 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, func_call = func_call.coerce_to_pyobject(self.current_env) return func_call - _handle_simple_method_unicode_lower = _inject_unicode_character_conversion - _handle_simple_method_unicode_upper = _inject_unicode_character_conversion - _handle_simple_method_unicode_title = _inject_unicode_character_conversion + #_handle_simple_method_unicode_lower = _inject_unicode_character_conversion + #_handle_simple_method_unicode_upper = _inject_unicode_character_conversion + #_handle_simple_method_unicode_title = _inject_unicode_character_conversion + ''' PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType( Builtin.list_type, [ @@ -3448,6 +3643,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return node if len(args) < 2: args.append(ExprNodes.NullNode(node.pos)) + else: + self._inject_null_for_none(args, 1) self._inject_int_default_argument( node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") @@ -3788,7 +3985,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if not stop: # use strlen() to find the string length, just as CPython would if not string_node.is_name: - string_node = UtilNodes.LetRefNode(string_node) # used twice + string_node = UtilNodes.LetRefNode(string_node) # used twice temps.append(string_node) stop = ExprNodes.PythonCapiCallNode( string_node.pos, "strlen", self.Pyx_strlen_func_type, @@ -3963,13 +4160,35 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, format_args=[attr_name]) return self_arg + obj_to_obj_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) + ]) + + def _inject_null_for_none(self, args, index): + if len(args) <= index: + return + arg = args[index] + args[index] = ExprNodes.NullNode(arg.pos) if arg.is_none else ExprNodes.PythonCapiCallNode( + arg.pos, "__Pyx_NoneAsNull", + self.obj_to_obj_func_type, + args=[arg.coerce_to_simple(self.current_env())], + is_temp=0, + ) + def _inject_int_default_argument(self, node, args, arg_index, type, default_value): + # Python usually allows passing None for range bounds, + # so we treat that as requesting the default. assert len(args) >= arg_index - if len(args) == arg_index: + if len(args) == arg_index or args[arg_index].is_none: args.append(ExprNodes.IntNode(node.pos, value=str(default_value), type=type, constant_result=default_value)) else: - args[arg_index] = args[arg_index].coerce_to(type, self.current_env()) + arg = args[arg_index].coerce_to(type, self.current_env()) + if isinstance(arg, ExprNodes.CoerceFromPyTypeNode): + # Add a runtime check for None and map it to the default value. + arg.special_none_cvalue = str(default_value) + args[arg_index] = arg def _inject_bint_default_argument(self, node, args, arg_index, default_value): assert len(args) >= arg_index @@ -3981,6 +4200,75 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env()) +def optimise_numeric_binop(operator, node, ret_type, arg0, arg1): + """ + Optimise math operators for (likely) float or small integer operations. + """ + # When adding IntNode/FloatNode to something else, assume other operand is also numeric. + # Prefer constants on RHS as they allows better size control for some operators. + num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode) + if isinstance(arg1, num_nodes): + if arg0.type is not PyrexTypes.py_object_type: + return None + numval = arg1 + arg_order = 'ObjC' + elif isinstance(arg0, num_nodes): + if arg1.type is not PyrexTypes.py_object_type: + return None + numval = arg0 + arg_order = 'CObj' + else: + return None + + if not numval.has_constant_result(): + return None + + # is_float is an instance check rather that numval.type.is_float because + # it will often be a Python float type rather than a C float type + is_float = isinstance(numval, ExprNodes.FloatNode) + num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type + if is_float: + if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'): + return None + elif operator == 'Divide': + # mixed old-/new-style division is not currently optimised for integers + return None + elif abs(numval.constant_result) > 2**30: + # Cut off at an integer border that is still safe for all operations. + return None + + if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'): + if arg1.constant_result == 0: + # Don't optimise division by 0. :) + return None + + extra_args = [] + + extra_args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)( + numval.pos, value=numval.value, constant_result=numval.constant_result, + type=num_type)) + inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False + extra_args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) + if is_float or operator not in ('Eq', 'Ne'): + # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument. + zerodivision_check = arg_order == 'CObj' and ( + not node.cdivision if isinstance(node, ExprNodes.DivNode) else False) + extra_args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check)) + + utility_code = TempitaUtilityCode.load_cached( + "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop", + "Optimize.c", + context=dict(op=operator, order=arg_order, ret_type=ret_type)) + + func_cname = "__Pyx_Py%s_%s%s%s" % ( + 'Float' if is_float else 'Int', + '' if ret_type.is_pyobject else 'Bool', + operator, + arg_order) + + return func_cname, utility_code, extra_args, num_type + + unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c') bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c') str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c') @@ -4439,25 +4727,25 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): args = [] items = [] - def add(arg): + def add(parent, arg): if arg.is_dict_literal: - if items: - items[0].key_value_pairs.extend(arg.key_value_pairs) + if items and items[-1].reject_duplicates == arg.reject_duplicates: + items[-1].key_value_pairs.extend(arg.key_value_pairs) else: items.append(arg) - elif isinstance(arg, ExprNodes.MergedDictNode): + elif isinstance(arg, ExprNodes.MergedDictNode) and parent.reject_duplicates == arg.reject_duplicates: for child_arg in arg.keyword_args: - add(child_arg) + add(arg, child_arg) else: if items: - args.append(items[0]) + args.extend(items) del items[:] args.append(arg) for arg in node.keyword_args: - add(arg) + add(node, arg) if items: - args.append(items[0]) + args.extend(items) if len(args) == 1: arg = args[0] @@ -4546,22 +4834,20 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): cascades = [[node.operand1]] final_false_result = [] - def split_cascades(cmp_node): + cmp_node = node + while cmp_node is not None: if cmp_node.has_constant_result(): if not cmp_node.constant_result: # False => short-circuit final_false_result.append(self._bool_node(cmp_node, False)) - return + break else: # True => discard and start new cascade cascades.append([cmp_node.operand2]) else: # not constant => append to current cascade cascades[-1].append(cmp_node) - if cmp_node.cascade: - split_cascades(cmp_node.cascade) - - split_cascades(node) + cmp_node = cmp_node.cascade cmp_nodes = [] for cascade in cascades: @@ -4707,6 +4993,30 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): return None return node + def visit_GILStatNode(self, node): + self.visitchildren(node) + if node.condition is None: + return node + + if node.condition.has_constant_result(): + # Condition is True - Modify node to be a normal + # GILStatNode with condition=None + if node.condition.constant_result: + node.condition = None + + # Condition is False - the body of the GILStatNode + # should run without changing the state of the gil + # return the body of the GILStatNode + else: + return node.body + + # If condition is not constant we keep the GILStatNode as it is. + # Either it will later become constant (e.g. a `numeric is int` + # expression in a fused type function) and then when ConstantFolding + # runs again it will be handled or a later transform (i.e. GilCheck) + # will raise an error + return node + # in the future, other nodes can have their own handler method here # that can replace them with a constant result node @@ -4723,6 +5033,7 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin): - isinstance -> typecheck for cdef types - eliminate checks for None and/or types that became redundant after tree changes - eliminate useless string formatting steps + - inject branch hints for unlikely if-cases that only raise exceptions - replace Python function calls that look like method calls by a faster PyMethodCallNode """ in_loop = False @@ -4821,6 +5132,48 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin): self.in_loop = old_val return node + def visit_IfStatNode(self, node): + """Assign 'unlikely' branch hints to if-clauses that only raise exceptions. + """ + self.visitchildren(node) + last_non_unlikely_clause = None + for i, if_clause in enumerate(node.if_clauses): + self._set_ifclause_branch_hint(if_clause, if_clause.body) + if not if_clause.branch_hint: + last_non_unlikely_clause = if_clause + if node.else_clause and last_non_unlikely_clause: + # If the 'else' clause is 'unlikely', then set the preceding 'if' clause to 'likely' to reflect that. + self._set_ifclause_branch_hint(last_non_unlikely_clause, node.else_clause, inverse=True) + return node + + def _set_ifclause_branch_hint(self, clause, statements_node, inverse=False): + """Inject a branch hint if the if-clause unconditionally leads to a 'raise' statement. + """ + if not statements_node.is_terminator: + return + # Allow simple statements, but no conditions, loops, etc. + non_branch_nodes = ( + Nodes.ExprStatNode, + Nodes.AssignmentNode, + Nodes.AssertStatNode, + Nodes.DelStatNode, + Nodes.GlobalNode, + Nodes.NonlocalNode, + ) + statements = [statements_node] + for next_node_pos, node in enumerate(statements, 1): + if isinstance(node, Nodes.GILStatNode): + statements.insert(next_node_pos, node.body) + continue + if isinstance(node, Nodes.StatListNode): + statements[next_node_pos:next_node_pos] = node.stats + continue + if not isinstance(node, non_branch_nodes): + if next_node_pos == len(statements) and isinstance(node, (Nodes.RaiseStatNode, Nodes.ReraiseStatNode)): + # Anything that unconditionally raises exceptions at the end should be considered unlikely. + clause.branch_hint = 'likely' if inverse else 'unlikely' + break + class ConsolidateOverflowCheck(Visitor.CythonTransform): """ |