summaryrefslogtreecommitdiff
path: root/Cython/Compiler/Optimize.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/Optimize.py')
-rw-r--r--Cython/Compiler/Optimize.py585
1 files changed, 469 insertions, 116 deletions
diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py
index 7e9435ba0..fb6dc5dae 100644
--- a/Cython/Compiler/Optimize.py
+++ b/Cython/Compiler/Optimize.py
@@ -41,7 +41,7 @@ except ImportError:
try:
from __builtin__ import basestring
except ImportError:
- basestring = str # Python 3
+ basestring = str # Python 3
def load_c_utility(name):
@@ -192,19 +192,9 @@ class IterationTransform(Visitor.EnvTransform):
def _optimise_for_loop(self, node, iterable, reversed=False):
annotation_type = None
if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
- annotation = iterable.entry.annotation
+ annotation = iterable.entry.annotation.expr
if annotation.is_subscript:
annotation = annotation.base # container base type
- # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
- if annotation.is_name:
- if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
- annotation_type = Builtin.dict_type
- elif annotation.name == 'Dict':
- annotation_type = Builtin.dict_type
- if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
- annotation_type = Builtin.set_type
- elif annotation.name in ('Set', 'FrozenSet'):
- annotation_type = Builtin.set_type
if Builtin.dict_type in (iterable.type, annotation_type):
# like iterating over dict.keys()
@@ -228,6 +218,12 @@ class IterationTransform(Visitor.EnvTransform):
return self._transform_bytes_iteration(node, iterable, reversed=reversed)
if iterable.type is Builtin.unicode_type:
return self._transform_unicode_iteration(node, iterable, reversed=reversed)
+ # in principle _transform_indexable_iteration would work on most of the above, and
+ # also tuple and list. However, it probably isn't quite as optimized
+ if iterable.type is Builtin.bytearray_type:
+ return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed)
+ if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice:
+ return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed)
# the rest is based on function calls
if not isinstance(iterable, ExprNodes.SimpleCallNode):
@@ -323,6 +319,92 @@ class IterationTransform(Visitor.EnvTransform):
return self._optimise_for_loop(node, arg, reversed=True)
+ def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False):
+ """In principle can handle any iterable that Cython has a len() for and knows how to index"""
+ unpack_temp_node = UtilNodes.LetRefNode(
+ slice_node.as_none_safe_node("'NoneType' is not iterable"),
+ may_hold_none=False, is_temp=True
+ )
+
+ start_node = ExprNodes.IntNode(
+ node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
+ def make_length_call():
+ # helper function since we need to create this node for a couple of places
+ builtin_len = ExprNodes.NameNode(node.pos, name="len",
+ entry=Builtin.builtin_scope.lookup("len"))
+ return ExprNodes.SimpleCallNode(node.pos,
+ function=builtin_len,
+ args=[unpack_temp_node]
+ )
+ length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True)
+ end_node = length_temp
+
+ if reversed:
+ relation1, relation2 = '>', '>='
+ start_node, end_node = end_node, start_node
+ else:
+ relation1, relation2 = '<=', '<'
+
+ counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type)
+
+ target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node,
+ index=counter_ref)
+
+ target_assign = Nodes.SingleAssignmentNode(
+ pos = node.target.pos,
+ lhs = node.target,
+ rhs = target_value)
+
+ # analyse with boundscheck and wraparound
+ # off (because we're confident we know the size)
+ env = self.current_env()
+ new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False)
+ target_assign = Nodes.CompilerDirectivesNode(
+ target_assign.pos,
+ directives=new_directives,
+ body=target_assign,
+ )
+
+ body = Nodes.StatListNode(
+ node.pos,
+ stats = [target_assign]) # exclude node.body for now to not reanalyse it
+ if is_mutable:
+ # We need to be slightly careful here that we are actually modifying the loop
+ # bounds and not a temp copy of it. Setting is_temp=True on length_temp seems
+ # to ensure this.
+ # If this starts to fail then we could insert an "if out_of_bounds: break" instead
+ loop_length_reassign = Nodes.SingleAssignmentNode(node.pos,
+ lhs = length_temp,
+ rhs = make_length_call())
+ body.stats.append(loop_length_reassign)
+
+ loop_node = Nodes.ForFromStatNode(
+ node.pos,
+ bound1=start_node, relation1=relation1,
+ target=counter_ref,
+ relation2=relation2, bound2=end_node,
+ step=None, body=body,
+ else_clause=node.else_clause,
+ from_range=True)
+
+ ret = UtilNodes.LetNode(
+ unpack_temp_node,
+ UtilNodes.LetNode(
+ length_temp,
+ # TempResultFromStatNode provides the framework where the "counter_ref"
+ # temp is set up and can be assigned to. However, we don't need the
+ # result it returns so wrap it in an ExprStatNode.
+ Nodes.ExprStatNode(node.pos,
+ expr=UtilNodes.TempResultFromStatNode(
+ counter_ref,
+ loop_node
+ )
+ )
+ )
+ ).analyse_expressions(env)
+ body.stats.insert(1, node.body)
+ return ret
+
PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type, [
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
@@ -1144,7 +1226,7 @@ class SwitchTransform(Visitor.EnvTransform):
# integers on iteration, whereas Py2 returns 1-char byte
# strings
characters = string_literal.value
- characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
+ characters = list({ characters[i:i+1] for i in range(len(characters)) })
characters.sort()
return [ ExprNodes.CharNode(string_literal.pos, value=charval,
constant_result=charval)
@@ -1156,7 +1238,8 @@ class SwitchTransform(Visitor.EnvTransform):
return self.NO_MATCH
elif common_var is not None and not is_common_value(var, common_var):
return self.NO_MATCH
- elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
+ elif not (var.type.is_int or var.type.is_enum) or any(
+ [not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
return self.NO_MATCH
return not_in, var, conditions
@@ -1573,7 +1656,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
utility_code = utility_code)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
- if not expected: # None or 0
+ if not expected: # None or 0
arg_str = ''
elif isinstance(expected, basestring) or expected > 1:
arg_str = '...'
@@ -1727,7 +1810,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
arg = pos_args[0]
if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
- list_node = pos_args[0]
+ list_node = arg
loop_node = list_node.loop
elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
@@ -1757,7 +1840,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# Interestingly, PySequence_List works on a lot of non-sequence
# things as well.
list_node = loop_node = ExprNodes.PythonCapiCallNode(
- node.pos, "PySequence_List", self.PySequence_List_func_type,
+ node.pos,
+ "__Pyx_PySequence_ListKeepNew"
+ if arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type)
+ else "PySequence_List",
+ self.PySequence_List_func_type,
args=pos_args, is_temp=True)
result_node = UtilNodes.ResultRefNode(
@@ -1803,7 +1890,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if not yield_expression.is_literal or not yield_expression.type.is_int:
return node
except AttributeError:
- return node # in case we don't have a type yet
+ return node # in case we don't have a type yet
# special case: old Py2 backwards compatible "sum([int_const for ...])"
# can safely be unpacked into a genexpr
@@ -2018,7 +2105,8 @@ class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
return node
inlined = ExprNodes.InlinedDefNodeCallNode(
node.pos, function_name=function_name,
- function=function, args=node.args)
+ function=function, args=node.args,
+ generator_arg_tag=node.generator_arg_tag)
if inlined.can_be_inlined():
return self.replace(node, inlined)
return node
@@ -2097,12 +2185,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
func_arg = arg.args[0]
if func_arg.type is Builtin.float_type:
return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
- elif func_arg.type.is_pyobject:
+ elif func_arg.type.is_pyobject and arg.function.cname == "__Pyx_PyObject_AsDouble":
return ExprNodes.PythonCapiCallNode(
node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
args=[func_arg],
py_name='float',
is_temp=node.is_temp,
+ utility_code = UtilityCode.load_cached("pynumber_float", "TypeConversion.c"),
result_is_used=node.result_is_used,
).coerce_to(node.type, self.current_env())
return node
@@ -2210,6 +2299,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if func_arg.type.is_int or node.type.is_int:
if func_arg.type == node.type:
return func_arg
+ elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type):
+ # need to parse (<Py_UCS4>'1') as digit 1
+ return self._pyucs4_to_number(node, function.name, func_arg)
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type)
elif func_arg.type.is_float and node.type.is_numeric:
@@ -2230,13 +2322,40 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if func_arg.type.is_float or node.type.is_float:
if func_arg.type == node.type:
return func_arg
+ elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type):
+ # need to parse (<Py_UCS4>'1') as digit 1
+ return self._pyucs4_to_number(node, function.name, func_arg)
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
return ExprNodes.TypecastNode(
node.pos, operand=func_arg, type=node.type)
return node
+ pyucs4_int_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_int_type, [
+ PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None)
+ ],
+ exception_value="-1")
+
+ pyucs4_double_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_double_type, [
+ PyrexTypes.CFuncTypeArg("arg", PyrexTypes.c_py_ucs4_type, None)
+ ],
+ exception_value="-1.0")
+
+ def _pyucs4_to_number(self, node, py_type_name, func_arg):
+ assert py_type_name in ("int", "float")
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_int_from_UCS4" if py_type_name == "int" else "__Pyx_double_from_UCS4",
+ func_type=self.pyucs4_int_func_type if py_type_name == "int" else self.pyucs4_double_func_type,
+ args=[func_arg],
+ py_name=py_type_name,
+ is_temp=node.is_temp,
+ result_is_used=node.result_is_used,
+ utility_code=UtilityCode.load_cached("int_pyucs4" if py_type_name == "int" else "float_pyucs4", "Builtins.c"),
+ ).coerce_to(node.type, self.current_env())
+
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
- if not expected: # None or 0
+ if not expected: # None or 0
arg_str = ''
elif isinstance(expected, basestring) or expected > 1:
arg_str = '...'
@@ -2316,6 +2435,38 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return ExprNodes.CachedBuiltinMethodCallNode(
node, function.obj, attr_name, arg_list)
+ PyObject_String_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [ # Change this to Builtin.str_type when removing Py2 support.
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_str(self, node, function, pos_args):
+ """Optimize single argument calls to str().
+ """
+ if len(pos_args) != 1:
+ if len(pos_args) == 0:
+ return ExprNodes.StringNode(node.pos, value=EncodedString(), constant_result='')
+ return node
+ arg = pos_args[0]
+
+ if arg.type is Builtin.str_type:
+ if not arg.may_be_none():
+ return arg
+
+ cname = "__Pyx_PyStr_Str"
+ utility_code = UtilityCode.load_cached('PyStr_Str', 'StringTools.c')
+ else:
+ cname = '__Pyx_PyObject_Str'
+ utility_code = UtilityCode.load_cached('PyObject_Str', 'StringTools.c')
+
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, cname, self.PyObject_String_func_type,
+ args=pos_args,
+ is_temp=node.is_temp,
+ utility_code=utility_code,
+ py_name="str"
+ )
+
PyObject_Unicode_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
@@ -2387,8 +2538,14 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node
arg = pos_args[0]
return ExprNodes.PythonCapiCallNode(
- node.pos, "PySequence_List", self.PySequence_List_func_type,
- args=pos_args, is_temp=node.is_temp)
+ node.pos,
+ "__Pyx_PySequence_ListKeepNew"
+ if node.is_temp and arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type)
+ else "PySequence_List",
+ self.PySequence_List_func_type,
+ args=pos_args,
+ is_temp=node.is_temp,
+ )
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
Builtin.tuple_type, [
@@ -2489,20 +2646,49 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
elif len(pos_args) != 1:
self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
return node
+
func_arg = pos_args[0]
if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
func_arg = func_arg.arg
if func_arg.type is PyrexTypes.c_double_type:
return func_arg
+ elif func_arg.type in (PyrexTypes.c_py_ucs4_type, PyrexTypes.c_py_unicode_type):
+ # need to parse (<Py_UCS4>'1') as digit 1
+ return self._pyucs4_to_number(node, function.name, func_arg)
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
return ExprNodes.TypecastNode(
node.pos, operand=func_arg, type=node.type)
+
+ arg = None
+ if func_arg.type is Builtin.bytes_type:
+ cfunc_name = "__Pyx_PyBytes_AsDouble"
+ utility_code_name = 'pybytes_as_double'
+ elif func_arg.type is Builtin.bytearray_type:
+ cfunc_name = "__Pyx_PyByteArray_AsDouble"
+ utility_code_name = 'pybytes_as_double'
+ elif func_arg.type is Builtin.unicode_type:
+ cfunc_name = "__Pyx_PyUnicode_AsDouble"
+ utility_code_name = 'pyunicode_as_double'
+ elif func_arg.type is Builtin.str_type:
+ cfunc_name = "__Pyx_PyString_AsDouble"
+ utility_code_name = 'pystring_as_double'
+ elif func_arg.type is Builtin.long_type:
+ cfunc_name = "PyLong_AsDouble"
+ else:
+ arg = func_arg # no need for an additional None check
+ cfunc_name = "__Pyx_PyObject_AsDouble"
+ utility_code_name = 'pyobject_as_double'
+
+ if arg is None:
+ arg = func_arg.as_none_safe_node(
+ "float() argument must be a string or a number, not 'NoneType'")
+
return ExprNodes.PythonCapiCallNode(
- node.pos, "__Pyx_PyObject_AsDouble",
+ node.pos, cfunc_name,
self.PyObject_AsDouble_func_type,
- args = pos_args,
+ args = [arg],
is_temp = node.is_temp,
- utility_code = load_c_utility('pyobject_as_double'),
+ utility_code = load_c_utility(utility_code_name) if utility_code_name else None,
py_name = "float")
PyNumber_Int_func_type = PyrexTypes.CFuncType(
@@ -2556,17 +2742,59 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
# coerce back to Python object as that's the result we are expecting
return operand.coerce_to_pyobject(self.current_env())
+ PyMemoryView_FromObject_func_type = PyrexTypes.CFuncType(
+ Builtin.memoryview_type, [
+ PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None)
+ ])
+
+ PyMemoryView_FromBuffer_func_type = PyrexTypes.CFuncType(
+ Builtin.memoryview_type, [
+ PyrexTypes.CFuncTypeArg("value", Builtin.py_buffer_type, None)
+ ])
+
+ def _handle_simple_function_memoryview(self, node, function, pos_args):
+ if len(pos_args) != 1:
+ self._error_wrong_arg_count('memoryview', node, pos_args, '1')
+ return node
+ else:
+ if pos_args[0].type.is_pyobject:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PyMemoryView_FromObject",
+ self.PyMemoryView_FromObject_func_type,
+ args = [pos_args[0]],
+ is_temp = node.is_temp,
+ py_name = "memoryview")
+ elif pos_args[0].type.is_ptr and pos_args[0].base_type is Builtin.py_buffer_type:
+ # TODO - this currently doesn't work because the buffer fails a
+ # "can coerce to python object" test earlier. But it'd be nice to support
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PyMemoryView_FromBuffer",
+ self.PyMemoryView_FromBuffer_func_type,
+ args = [pos_args[0]],
+ is_temp = node.is_temp,
+ py_name = "memoryview")
+ return node
+
+
### builtin functions
Pyx_strlen_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_size_t_type, [
PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)
- ])
+ ],
+ nogil=True)
+
+ Pyx_ssize_strlen_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ssize_t_type, [
+ PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)
+ ],
+ exception_value="-1")
Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
- PyrexTypes.c_size_t_type, [
+ PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)
- ])
+ ],
+ exception_value="-1")
PyObject_Size_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_ssize_t_type, [
@@ -2585,7 +2813,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
Builtin.dict_type: "PyDict_Size",
}.get
- _ext_types_with_pysize = set(["cpython.array.array"])
+ _ext_types_with_pysize = {"cpython.array.array"}
def _handle_simple_function_len(self, node, function, pos_args):
"""Replace len(char*) by the equivalent call to strlen(),
@@ -2600,18 +2828,19 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
arg = arg.arg
if arg.type.is_string:
new_node = ExprNodes.PythonCapiCallNode(
- node.pos, "strlen", self.Pyx_strlen_func_type,
+ node.pos, "__Pyx_ssize_strlen", self.Pyx_ssize_strlen_func_type,
args = [arg],
is_temp = node.is_temp,
- utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
+ utility_code = UtilityCode.load_cached("ssize_strlen", "StringTools.c"))
elif arg.type.is_pyunicode_ptr:
new_node = ExprNodes.PythonCapiCallNode(
- node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
+ node.pos, "__Pyx_Py_UNICODE_ssize_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
args = [arg],
- is_temp = node.is_temp)
+ is_temp = node.is_temp,
+ utility_code = UtilityCode.load_cached("ssize_pyunicode_strlen", "StringTools.c"))
elif arg.type.is_memoryviewslice:
func_type = PyrexTypes.CFuncType(
- PyrexTypes.c_size_t_type, [
+ PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("memoryviewslice", arg.type, None)
], nogil=True)
new_node = ExprNodes.PythonCapiCallNode(
@@ -2622,7 +2851,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if cfunc_name is None:
arg_type = arg.type
if ((arg_type.is_extension_type or arg_type.is_builtin_type)
- and arg_type.entry.qualified_name in self._ext_types_with_pysize):
+ and arg_type.entry.qualified_name in self._ext_types_with_pysize):
cfunc_name = 'Py_SIZE'
else:
return node
@@ -2698,6 +2927,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
builtin_type = None
if builtin_type is not None:
type_check_function = entry.type.type_check_function(exact=False)
+ if type_check_function == '__Pyx_Py3Int_Check' and builtin_type is Builtin.int_type:
+ # isinstance(x, int) should really test for 'int' in Py2, not 'int | long'
+ type_check_function = "PyInt_Check"
if type_check_function in tests:
continue
tests.append(type_check_function)
@@ -2781,11 +3013,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node
type_arg = args[0]
if not obj.is_name or not type_arg.is_name:
- # play safe
- return node
+ return node # not a simple case
if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
- # not a known type, play safe
- return node
+ return node # not a known type
if not type_arg.type_entry or not obj.type_entry:
if obj.name != type_arg.name:
return node
@@ -2847,6 +3077,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
is_temp=node.is_temp
)
+ def _handle_any_slot__class__(self, node, function, args,
+ is_unbound_method, kwargs=None):
+ # The purpose of this function is to handle calls to instance.__class__() so that
+ # it doesn't get handled by the __Pyx_CallUnboundCMethod0 mechanism.
+ # TODO: optimizations of the instance.__class__() call might be possible in future.
+ return node
+
### methods of builtin types
PyObject_Append_func_type = PyrexTypes.CFuncType(
@@ -3182,6 +3419,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
+ def _handle_simple_method_object___mul__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Multiply', node, function, args, is_unbound_method)
+
def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
@@ -3261,6 +3501,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
"""
Optimise math operators for (likely) float or small integer operations.
"""
+ if getattr(node, "special_bool_cmp_function", None):
+ return node # already optimized
+
if len(args) != 2:
return node
@@ -3271,66 +3514,15 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
else:
return node
- # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
- # Prefer constants on RHS as they allows better size control for some operators.
- num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
- if isinstance(args[1], num_nodes):
- if args[0].type is not PyrexTypes.py_object_type:
- return node
- numval = args[1]
- arg_order = 'ObjC'
- elif isinstance(args[0], num_nodes):
- if args[1].type is not PyrexTypes.py_object_type:
- return node
- numval = args[0]
- arg_order = 'CObj'
- else:
- return node
-
- if not numval.has_constant_result():
- return node
-
- is_float = isinstance(numval, ExprNodes.FloatNode)
- num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
- if is_float:
- if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
- return node
- elif operator == 'Divide':
- # mixed old-/new-style division is not currently optimised for integers
- return node
- elif abs(numval.constant_result) > 2**30:
- # Cut off at an integer border that is still safe for all operations.
+ result = optimise_numeric_binop(operator, node, ret_type, args[0], args[1])
+ if not result:
return node
-
- if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'):
- if args[1].constant_result == 0:
- # Don't optimise division by 0. :)
- return node
-
- args = list(args)
- args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
- numval.pos, value=numval.value, constant_result=numval.constant_result,
- type=num_type))
- inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
- args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
- if is_float or operator not in ('Eq', 'Ne'):
- # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument.
- zerodivision_check = arg_order == 'CObj' and (
- not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
- args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check))
-
- utility_code = TempitaUtilityCode.load_cached(
- "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
- "Optimize.c",
- context=dict(op=operator, order=arg_order, ret_type=ret_type))
+ func_cname, utility_code, extra_args, num_type = result
+ args = list(args)+extra_args
call_node = self._substitute_method_call(
node, function,
- "__Pyx_Py%s_%s%s%s" % (
- 'Float' if is_float else 'Int',
- '' if ret_type.is_pyobject else 'Bool',
- operator,
- arg_order),
+ func_cname,
self.Pyx_BinopInt_func_types[(num_type, ret_type)],
'__%s__' % operator[:3].lower(), is_unbound_method, args,
may_return_none=True,
@@ -3389,6 +3581,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
])
+ # DISABLED: Return value can only be one character, which is not correct.
+ '''
def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
if is_unbound_method or len(args) != 1:
return node
@@ -3407,9 +3601,10 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
func_call = func_call.coerce_to_pyobject(self.current_env)
return func_call
- _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
- _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
- _handle_simple_method_unicode_title = _inject_unicode_character_conversion
+ #_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
+ #_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
+ #_handle_simple_method_unicode_title = _inject_unicode_character_conversion
+ '''
PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
Builtin.list_type, [
@@ -3448,6 +3643,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node
if len(args) < 2:
args.append(ExprNodes.NullNode(node.pos))
+ else:
+ self._inject_null_for_none(args, 1)
self._inject_int_default_argument(
node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
@@ -3788,7 +3985,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if not stop:
# use strlen() to find the string length, just as CPython would
if not string_node.is_name:
- string_node = UtilNodes.LetRefNode(string_node) # used twice
+ string_node = UtilNodes.LetRefNode(string_node) # used twice
temps.append(string_node)
stop = ExprNodes.PythonCapiCallNode(
string_node.pos, "strlen", self.Pyx_strlen_func_type,
@@ -3963,13 +4160,35 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
format_args=[attr_name])
return self_arg
+ obj_to_obj_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
+ ])
+
+ def _inject_null_for_none(self, args, index):
+ if len(args) <= index:
+ return
+ arg = args[index]
+ args[index] = ExprNodes.NullNode(arg.pos) if arg.is_none else ExprNodes.PythonCapiCallNode(
+ arg.pos, "__Pyx_NoneAsNull",
+ self.obj_to_obj_func_type,
+ args=[arg.coerce_to_simple(self.current_env())],
+ is_temp=0,
+ )
+
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
+ # Python usually allows passing None for range bounds,
+ # so we treat that as requesting the default.
assert len(args) >= arg_index
- if len(args) == arg_index:
+ if len(args) == arg_index or args[arg_index].is_none:
args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
type=type, constant_result=default_value))
else:
- args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
+ arg = args[arg_index].coerce_to(type, self.current_env())
+ if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
+ # Add a runtime check for None and map it to the default value.
+ arg.special_none_cvalue = str(default_value)
+ args[arg_index] = arg
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
assert len(args) >= arg_index
@@ -3981,6 +4200,75 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
+def optimise_numeric_binop(operator, node, ret_type, arg0, arg1):
+ """
+ Optimise math operators for (likely) float or small integer operations.
+ """
+ # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
+ # Prefer constants on RHS as they allows better size control for some operators.
+ num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
+ if isinstance(arg1, num_nodes):
+ if arg0.type is not PyrexTypes.py_object_type:
+ return None
+ numval = arg1
+ arg_order = 'ObjC'
+ elif isinstance(arg0, num_nodes):
+ if arg1.type is not PyrexTypes.py_object_type:
+ return None
+ numval = arg0
+ arg_order = 'CObj'
+ else:
+ return None
+
+ if not numval.has_constant_result():
+ return None
+
+ # is_float is an instance check rather that numval.type.is_float because
+ # it will often be a Python float type rather than a C float type
+ is_float = isinstance(numval, ExprNodes.FloatNode)
+ num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
+ if is_float:
+ if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
+ return None
+ elif operator == 'Divide':
+ # mixed old-/new-style division is not currently optimised for integers
+ return None
+ elif abs(numval.constant_result) > 2**30:
+ # Cut off at an integer border that is still safe for all operations.
+ return None
+
+ if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'):
+ if arg1.constant_result == 0:
+ # Don't optimise division by 0. :)
+ return None
+
+ extra_args = []
+
+ extra_args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
+ numval.pos, value=numval.value, constant_result=numval.constant_result,
+ type=num_type))
+ inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
+ extra_args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
+ if is_float or operator not in ('Eq', 'Ne'):
+ # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument.
+ zerodivision_check = arg_order == 'CObj' and (
+ not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
+ extra_args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check))
+
+ utility_code = TempitaUtilityCode.load_cached(
+ "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
+ "Optimize.c",
+ context=dict(op=operator, order=arg_order, ret_type=ret_type))
+
+ func_cname = "__Pyx_Py%s_%s%s%s" % (
+ 'Float' if is_float else 'Int',
+ '' if ret_type.is_pyobject else 'Bool',
+ operator,
+ arg_order)
+
+ return func_cname, utility_code, extra_args, num_type
+
+
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
@@ -4439,25 +4727,25 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
args = []
items = []
- def add(arg):
+ def add(parent, arg):
if arg.is_dict_literal:
- if items:
- items[0].key_value_pairs.extend(arg.key_value_pairs)
+ if items and items[-1].reject_duplicates == arg.reject_duplicates:
+ items[-1].key_value_pairs.extend(arg.key_value_pairs)
else:
items.append(arg)
- elif isinstance(arg, ExprNodes.MergedDictNode):
+ elif isinstance(arg, ExprNodes.MergedDictNode) and parent.reject_duplicates == arg.reject_duplicates:
for child_arg in arg.keyword_args:
- add(child_arg)
+ add(arg, child_arg)
else:
if items:
- args.append(items[0])
+ args.extend(items)
del items[:]
args.append(arg)
for arg in node.keyword_args:
- add(arg)
+ add(node, arg)
if items:
- args.append(items[0])
+ args.extend(items)
if len(args) == 1:
arg = args[0]
@@ -4546,22 +4834,20 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
cascades = [[node.operand1]]
final_false_result = []
- def split_cascades(cmp_node):
+ cmp_node = node
+ while cmp_node is not None:
if cmp_node.has_constant_result():
if not cmp_node.constant_result:
# False => short-circuit
final_false_result.append(self._bool_node(cmp_node, False))
- return
+ break
else:
# True => discard and start new cascade
cascades.append([cmp_node.operand2])
else:
# not constant => append to current cascade
cascades[-1].append(cmp_node)
- if cmp_node.cascade:
- split_cascades(cmp_node.cascade)
-
- split_cascades(node)
+ cmp_node = cmp_node.cascade
cmp_nodes = []
for cascade in cascades:
@@ -4707,6 +4993,30 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return None
return node
+ def visit_GILStatNode(self, node):
+ self.visitchildren(node)
+ if node.condition is None:
+ return node
+
+ if node.condition.has_constant_result():
+ # Condition is True - Modify node to be a normal
+ # GILStatNode with condition=None
+ if node.condition.constant_result:
+ node.condition = None
+
+ # Condition is False - the body of the GILStatNode
+ # should run without changing the state of the gil
+ # return the body of the GILStatNode
+ else:
+ return node.body
+
+ # If condition is not constant we keep the GILStatNode as it is.
+ # Either it will later become constant (e.g. a `numeric is int`
+ # expression in a fused type function) and then when ConstantFolding
+ # runs again it will be handled or a later transform (i.e. GilCheck)
+ # will raise an error
+ return node
+
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
@@ -4723,6 +5033,7 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
- isinstance -> typecheck for cdef types
- eliminate checks for None and/or types that became redundant after tree changes
- eliminate useless string formatting steps
+ - inject branch hints for unlikely if-cases that only raise exceptions
- replace Python function calls that look like method calls by a faster PyMethodCallNode
"""
in_loop = False
@@ -4821,6 +5132,48 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
self.in_loop = old_val
return node
+ def visit_IfStatNode(self, node):
+ """Assign 'unlikely' branch hints to if-clauses that only raise exceptions.
+ """
+ self.visitchildren(node)
+ last_non_unlikely_clause = None
+ for i, if_clause in enumerate(node.if_clauses):
+ self._set_ifclause_branch_hint(if_clause, if_clause.body)
+ if not if_clause.branch_hint:
+ last_non_unlikely_clause = if_clause
+ if node.else_clause and last_non_unlikely_clause:
+ # If the 'else' clause is 'unlikely', then set the preceding 'if' clause to 'likely' to reflect that.
+ self._set_ifclause_branch_hint(last_non_unlikely_clause, node.else_clause, inverse=True)
+ return node
+
+ def _set_ifclause_branch_hint(self, clause, statements_node, inverse=False):
+ """Inject a branch hint if the if-clause unconditionally leads to a 'raise' statement.
+ """
+ if not statements_node.is_terminator:
+ return
+ # Allow simple statements, but no conditions, loops, etc.
+ non_branch_nodes = (
+ Nodes.ExprStatNode,
+ Nodes.AssignmentNode,
+ Nodes.AssertStatNode,
+ Nodes.DelStatNode,
+ Nodes.GlobalNode,
+ Nodes.NonlocalNode,
+ )
+ statements = [statements_node]
+ for next_node_pos, node in enumerate(statements, 1):
+ if isinstance(node, Nodes.GILStatNode):
+ statements.insert(next_node_pos, node.body)
+ continue
+ if isinstance(node, Nodes.StatListNode):
+ statements[next_node_pos:next_node_pos] = node.stats
+ continue
+ if not isinstance(node, non_branch_nodes):
+ if next_node_pos == len(statements) and isinstance(node, (Nodes.RaiseStatNode, Nodes.ReraiseStatNode)):
+ # Anything that unconditionally raises exceptions at the end should be considered unlikely.
+ clause.branch_hint = 'likely' if inverse else 'unlikely'
+ break
+
class ConsolidateOverflowCheck(Visitor.CythonTransform):
"""