summaryrefslogtreecommitdiff
path: root/Cython/Compiler/FusedNode.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/Compiler/FusedNode.py')
-rw-r--r--Cython/Compiler/FusedNode.py193
1 files changed, 147 insertions, 46 deletions
diff --git a/Cython/Compiler/FusedNode.py b/Cython/Compiler/FusedNode.py
index 26d6ffd3d..7876916db 100644
--- a/Cython/Compiler/FusedNode.py
+++ b/Cython/Compiler/FusedNode.py
@@ -7,6 +7,7 @@ from . import (ExprNodes, PyrexTypes, MemoryView,
from .ExprNodes import CloneNode, ProxyNode, TupleNode
from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode
from ..Utils import OrderedSet
+from .Errors import error, CannotSpecialize
class FusedCFuncDefNode(StatListNode):
@@ -141,7 +142,14 @@ class FusedCFuncDefNode(StatListNode):
copied_node = copy.deepcopy(self.node)
# Make the types in our CFuncType specific.
- type = copied_node.type.specialize(fused_to_specific)
+ try:
+ type = copied_node.type.specialize(fused_to_specific)
+ except CannotSpecialize:
+ # unlike for the argument types, specializing the return type can fail
+ error(copied_node.pos, "Return type is a fused type that cannot "
+ "be determined from the function arguments")
+ self.py_func = None # this is just to let the compiler exit gracefully
+ return
entry = copied_node.entry
type.specialize_entry(entry, cname)
@@ -220,6 +228,10 @@ class FusedCFuncDefNode(StatListNode):
arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice:
arg.type.validate_memslice_dtype(arg.pos)
+ if arg.annotation:
+ # TODO might be nice if annotations were specialized instead?
+ # (Or might be hard to do reliably)
+ arg.annotation.untyped = True
def create_new_local_scope(self, node, env, f2s):
"""
@@ -264,12 +276,12 @@ class FusedCFuncDefNode(StatListNode):
Returns whether an error was issued and whether we should stop in
in order to prevent a flood of errors.
"""
- num_errors = Errors.num_errors
+ num_errors = Errors.get_errors_count()
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
- if Errors.num_errors > num_errors:
+ if Errors.get_errors_count() > num_errors:
return False
return True
@@ -309,25 +321,21 @@ class FusedCFuncDefNode(StatListNode):
def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
"Setup some common cases to match dtypes against specializations"
- if pyx_code.indenter("if kind in b'iu':"):
+ with pyx_code.indenter("if kind in b'iu':"):
pyx_code.putln("pass")
pyx_code.named_insertion_point("dtype_int")
- pyx_code.dedent()
- if pyx_code.indenter("elif kind == b'f':"):
+ with pyx_code.indenter("elif kind == b'f':"):
pyx_code.putln("pass")
pyx_code.named_insertion_point("dtype_float")
- pyx_code.dedent()
- if pyx_code.indenter("elif kind == b'c':"):
+ with pyx_code.indenter("elif kind == b'c':"):
pyx_code.putln("pass")
pyx_code.named_insertion_point("dtype_complex")
- pyx_code.dedent()
- if pyx_code.indenter("elif kind == b'O':"):
+ with pyx_code.indenter("elif kind == b'O':"):
pyx_code.putln("pass")
pyx_code.named_insertion_point("dtype_object")
- pyx_code.dedent()
match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
no_match = "dest_sig[{{dest_sig_idx}}] = None"
@@ -364,11 +372,10 @@ class FusedCFuncDefNode(StatListNode):
if final_type.is_pythran_expr:
cond += ' and arg_is_pythran_compatible'
- if codewriter.indenter("if %s:" % cond):
+ with codewriter.indenter("if %s:" % cond):
#codewriter.putln("print 'buffer match found based on numpy dtype'")
codewriter.putln(self.match)
codewriter.putln("break")
- codewriter.dedent()
def _buffer_parse_format_string_check(self, pyx_code, decl_code,
specialized_type, env):
@@ -394,15 +401,30 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.context.update(
specialized_type_name=specialized_type.specialization_string,
- sizeof_dtype=self._sizeof_dtype(dtype))
+ sizeof_dtype=self._sizeof_dtype(dtype),
+ ndim_dtype=specialized_type.ndim,
+ dtype_is_struct_obj=int(dtype.is_struct or dtype.is_pyobject))
+ # use the memoryview object to check itemsize and ndim.
+ # In principle it could check more, but these are the easiest to do quickly
pyx_code.put_chunk(
u"""
# try {{dtype}}
- if itemsize == -1 or itemsize == {{sizeof_dtype}}:
- memslice = {{coerce_from_py_func}}(arg, 0)
+ if (((itemsize == -1 and arg_as_memoryview.itemsize == {{sizeof_dtype}})
+ or itemsize == {{sizeof_dtype}})
+ and arg_as_memoryview.ndim == {{ndim_dtype}}):
+ {{if dtype_is_struct_obj}}
+ if __PYX_IS_PYPY2:
+ # I wasn't able to diagnose why, but PyPy2 fails to convert a
+ # memoryview to a Cython memoryview in this case
+ memslice = {{coerce_from_py_func}}(arg, 0)
+ else:
+ {{else}}
+ if True:
+ {{endif}}
+ memslice = {{coerce_from_py_func}}(arg_as_memoryview, 0)
if memslice.memview:
- __PYX_XDEC_MEMVIEW(&memslice, 1)
+ __PYX_XCLEAR_MEMVIEW(&memslice, 1)
# print 'found a match for the buffer through format parsing'
%s
break
@@ -410,7 +432,7 @@ class FusedCFuncDefNode(StatListNode):
__pyx_PyErr_Clear()
""" % self.match)
- def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env):
+ def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, accept_none, env):
"""
Generate Cython code to match objects to buffer specializations.
First try to get a numpy dtype object and match it against the individual
@@ -467,9 +489,35 @@ class FusedCFuncDefNode(StatListNode):
self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
pyx_code.dedent(2)
- for specialized_type in buffer_types:
- self._buffer_parse_format_string_check(
- pyx_code, decl_code, specialized_type, env)
+ if accept_none:
+ # If None is acceptable, then Cython <3.0 matched None with the
+ # first type. This behaviour isn't ideal, but keep it for backwards
+ # compatibility. Better behaviour would be to see if subsequent
+ # arguments give a stronger match.
+ pyx_code.context.update(
+ specialized_type_name=buffer_types[0].specialization_string
+ )
+ pyx_code.put_chunk(
+ """
+ if arg is None:
+ %s
+ break
+ """ % self.match)
+
+ # creating a Cython memoryview from a Python memoryview avoids the
+ # need to get the buffer multiple times, and we can
+ # also use it to check itemsizes etc
+ pyx_code.put_chunk(
+ """
+ try:
+ arg_as_memoryview = memoryview(arg)
+ except (ValueError, TypeError):
+ pass
+ """)
+ with pyx_code.indenter("else:"):
+ for specialized_type in buffer_types:
+ self._buffer_parse_format_string_check(
+ pyx_code, decl_code, specialized_type, env)
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
"""
@@ -481,8 +529,9 @@ class FusedCFuncDefNode(StatListNode):
ctypedef struct {{memviewslice_cname}}:
void *memview
- void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
+ void __PYX_XCLEAR_MEMVIEW({{memviewslice_cname}} *, int have_gil)
bint __pyx_memoryview_check(object)
+ bint __PYX_IS_PYPY2 "(CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION == 2)"
""")
pyx_code.local_variable_declarations.put_chunk(
@@ -507,6 +556,12 @@ class FusedCFuncDefNode(StatListNode):
ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable()
""")
+ pyx_code.imports.put_chunk(
+ u"""
+ cdef memoryview arg_as_memoryview
+ """
+ )
+
seen_typedefs = set()
seen_int_dtypes = set()
for buffer_type in all_buffer_types:
@@ -580,6 +635,26 @@ class FusedCFuncDefNode(StatListNode):
{{endif}}
""")
+ def _fused_signature_index(self, pyx_code):
+ """
+ Generate Cython code for constructing a persistent nested dictionary index of
+ fused type specialization signatures.
+ """
+ pyx_code.put_chunk(
+ u"""
+ if not _fused_sigindex:
+ for sig in <dict>signatures:
+ sigindex_node = _fused_sigindex
+ *sig_series, last_type = sig.strip('()').split('|')
+ for sig_type in sig_series:
+ if sig_type not in sigindex_node:
+ sigindex_node[sig_type] = sigindex_node = {}
+ else:
+ sigindex_node = sigindex_node[sig_type]
+ sigindex_node[last_type] = sig
+ """
+ )
+
def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
This creates the function that is indexable from Python and does
@@ -616,10 +691,14 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.put_chunk(
u"""
- def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
+ def __pyx_fused_cpdef(signatures, args, kwargs, defaults, _fused_sigindex={}):
# FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here!
+ cdef list search_list
+
+ cdef dict sn, sigindex_node
+
dest_sig = [None] * {{n_fused}}
if kwargs is not None and not kwargs:
@@ -630,7 +709,7 @@ class FusedCFuncDefNode(StatListNode):
# instance check body
""")
- pyx_code.indent() # indent following code to function body
+ pyx_code.indent() # indent following code to function body
pyx_code.named_insertion_point("imports")
pyx_code.named_insertion_point("func_defs")
pyx_code.named_insertion_point("local_variable_declarations")
@@ -661,19 +740,20 @@ class FusedCFuncDefNode(StatListNode):
self._unpack_argument(pyx_code)
# 'unrolled' loop, first match breaks out of it
- if pyx_code.indenter("while 1:"):
+ with pyx_code.indenter("while 1:"):
if normal_types:
self._fused_instance_checks(normal_types, pyx_code, env)
if buffer_types or pythran_types:
env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c"))
- self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env)
+ self._buffer_checks(
+ buffer_types, pythran_types, pyx_code, decl_code,
+ arg.accept_none, env)
if has_object_fallback:
pyx_code.context.update(specialized_type_name='object')
pyx_code.putln(self.match)
else:
pyx_code.putln(self.no_match)
pyx_code.putln("break")
- pyx_code.dedent()
fused_index += 1
all_buffer_types.update(buffer_types)
@@ -687,23 +767,36 @@ class FusedCFuncDefNode(StatListNode):
env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
+ self._fused_signature_index(pyx_code)
+
pyx_code.put_chunk(
u"""
- candidates = []
- for sig in <dict>signatures:
- match_found = False
- src_sig = sig.strip('()').split('|')
- for i in range(len(dest_sig)):
- dst_type = dest_sig[i]
- if dst_type is not None:
- if src_sig[i] == dst_type:
- match_found = True
- else:
- match_found = False
- break
+ sigindex_matches = []
+ sigindex_candidates = [_fused_sigindex]
+
+ for dst_type in dest_sig:
+ found_matches = []
+ found_candidates = []
+ # Make two seperate lists: One for signature sub-trees
+ # with at least one definite match, and another for
+ # signature sub-trees with only ambiguous matches
+ # (where `dest_sig[i] is None`).
+ if dst_type is None:
+ for sn in sigindex_matches:
+ found_matches.extend(sn.values())
+ for sn in sigindex_candidates:
+ found_candidates.extend(sn.values())
+ else:
+ for search_list in (sigindex_matches, sigindex_candidates):
+ for sn in search_list:
+ if dst_type in sn:
+ found_matches.append(sn[dst_type])
+ sigindex_matches = found_matches
+ sigindex_candidates = found_candidates
+ if not (found_matches or found_candidates):
+ break
- if match_found:
- candidates.append(sig)
+ candidates = sigindex_matches
if not candidates:
raise TypeError("No matching signature found")
@@ -792,16 +885,18 @@ class FusedCFuncDefNode(StatListNode):
for arg in self.node.args:
if arg.default:
arg.default = arg.default.analyse_expressions(env)
- defaults.append(ProxyNode(arg.default))
+ # coerce the argument to temp since CloneNode really requires a temp
+ defaults.append(ProxyNode(arg.default.coerce_to_temp(env)))
else:
defaults.append(None)
for i, stat in enumerate(self.stats):
stat = self.stats[i] = stat.analyse_expressions(env)
- if isinstance(stat, FuncDefNode):
+ if isinstance(stat, FuncDefNode) and stat is not self.py_func:
+ # the dispatcher specifically doesn't want its defaults overriding
for arg, default in zip(stat.args, defaults):
if default is not None:
- arg.default = CloneNode(default).coerce_to(arg.type, env)
+ arg.default = CloneNode(default).analyse_expressions(env).coerce_to(arg.type, env)
if self.py_func:
args = [CloneNode(default) for default in defaults if default]
@@ -829,6 +924,10 @@ class FusedCFuncDefNode(StatListNode):
else:
nodes = self.nodes
+ # For the moment, fused functions do not support METH_FASTCALL
+ for node in nodes:
+ node.entry.signature.use_fastcall = False
+
signatures = [StringEncoding.EncodedString(node.specialized_signature_string)
for node in nodes]
keys = [ExprNodes.StringNode(node.pos, value=sig)
@@ -847,8 +946,10 @@ class FusedCFuncDefNode(StatListNode):
self.py_func.pymethdef_required = True
self.fused_func_assignment.generate_function_definitions(env, code)
+ from . import Options
for stat in self.stats:
- if isinstance(stat, FuncDefNode) and stat.entry.used:
+ from_pyx = Options.cimport_from_pyx and not stat.entry.visibility == 'extern'
+ if isinstance(stat, FuncDefNode) and (stat.entry.used or from_pyx):
code.mark_pos(stat.pos)
stat.generate_function_definitions(env, code)
@@ -877,7 +978,7 @@ class FusedCFuncDefNode(StatListNode):
"((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
(self.resulting_fused_function.result(),
self.__signatures__.result()))
- code.put_giveref(self.__signatures__.result())
+ self.__signatures__.generate_giveref(code)
self.__signatures__.generate_post_assignment_code(code)
self.__signatures__.free_temps(code)