diff options
Diffstat (limited to 'Cython/Compiler/FusedNode.py')
-rw-r--r-- | Cython/Compiler/FusedNode.py | 193 |
1 files changed, 147 insertions, 46 deletions
diff --git a/Cython/Compiler/FusedNode.py b/Cython/Compiler/FusedNode.py index 26d6ffd3d..7876916db 100644 --- a/Cython/Compiler/FusedNode.py +++ b/Cython/Compiler/FusedNode.py @@ -7,6 +7,7 @@ from . import (ExprNodes, PyrexTypes, MemoryView, from .ExprNodes import CloneNode, ProxyNode, TupleNode from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode from ..Utils import OrderedSet +from .Errors import error, CannotSpecialize class FusedCFuncDefNode(StatListNode): @@ -141,7 +142,14 @@ class FusedCFuncDefNode(StatListNode): copied_node = copy.deepcopy(self.node) # Make the types in our CFuncType specific. - type = copied_node.type.specialize(fused_to_specific) + try: + type = copied_node.type.specialize(fused_to_specific) + except CannotSpecialize: + # unlike for the argument types, specializing the return type can fail + error(copied_node.pos, "Return type is a fused type that cannot " + "be determined from the function arguments") + self.py_func = None # this is just to let the compiler exit gracefully + return entry = copied_node.entry type.specialize_entry(entry, cname) @@ -220,6 +228,10 @@ class FusedCFuncDefNode(StatListNode): arg.type = arg.type.specialize(fused_to_specific) if arg.type.is_memoryviewslice: arg.type.validate_memslice_dtype(arg.pos) + if arg.annotation: + # TODO might be nice if annotations were specialized instead? + # (Or might be hard to do reliably) + arg.annotation.untyped = True def create_new_local_scope(self, node, env, f2s): """ @@ -264,12 +276,12 @@ class FusedCFuncDefNode(StatListNode): Returns whether an error was issued and whether we should stop in in order to prevent a flood of errors. """ - num_errors = Errors.num_errors + num_errors = Errors.get_errors_count() transform = ParseTreeTransforms.ReplaceFusedTypeChecks( copied_node.local_scope) transform(copied_node) - if Errors.num_errors > num_errors: + if Errors.get_errors_count() > num_errors: return False return True @@ -309,25 +321,21 @@ class FusedCFuncDefNode(StatListNode): def _buffer_check_numpy_dtype_setup_cases(self, pyx_code): "Setup some common cases to match dtypes against specializations" - if pyx_code.indenter("if kind in b'iu':"): + with pyx_code.indenter("if kind in b'iu':"): pyx_code.putln("pass") pyx_code.named_insertion_point("dtype_int") - pyx_code.dedent() - if pyx_code.indenter("elif kind == b'f':"): + with pyx_code.indenter("elif kind == b'f':"): pyx_code.putln("pass") pyx_code.named_insertion_point("dtype_float") - pyx_code.dedent() - if pyx_code.indenter("elif kind == b'c':"): + with pyx_code.indenter("elif kind == b'c':"): pyx_code.putln("pass") pyx_code.named_insertion_point("dtype_complex") - pyx_code.dedent() - if pyx_code.indenter("elif kind == b'O':"): + with pyx_code.indenter("elif kind == b'O':"): pyx_code.putln("pass") pyx_code.named_insertion_point("dtype_object") - pyx_code.dedent() match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'" no_match = "dest_sig[{{dest_sig_idx}}] = None" @@ -364,11 +372,10 @@ class FusedCFuncDefNode(StatListNode): if final_type.is_pythran_expr: cond += ' and arg_is_pythran_compatible' - if codewriter.indenter("if %s:" % cond): + with codewriter.indenter("if %s:" % cond): #codewriter.putln("print 'buffer match found based on numpy dtype'") codewriter.putln(self.match) codewriter.putln("break") - codewriter.dedent() def _buffer_parse_format_string_check(self, pyx_code, decl_code, specialized_type, env): @@ -394,15 +401,30 @@ class FusedCFuncDefNode(StatListNode): pyx_code.context.update( specialized_type_name=specialized_type.specialization_string, - sizeof_dtype=self._sizeof_dtype(dtype)) + sizeof_dtype=self._sizeof_dtype(dtype), + ndim_dtype=specialized_type.ndim, + dtype_is_struct_obj=int(dtype.is_struct or dtype.is_pyobject)) + # use the memoryview object to check itemsize and ndim. + # In principle it could check more, but these are the easiest to do quickly pyx_code.put_chunk( u""" # try {{dtype}} - if itemsize == -1 or itemsize == {{sizeof_dtype}}: - memslice = {{coerce_from_py_func}}(arg, 0) + if (((itemsize == -1 and arg_as_memoryview.itemsize == {{sizeof_dtype}}) + or itemsize == {{sizeof_dtype}}) + and arg_as_memoryview.ndim == {{ndim_dtype}}): + {{if dtype_is_struct_obj}} + if __PYX_IS_PYPY2: + # I wasn't able to diagnose why, but PyPy2 fails to convert a + # memoryview to a Cython memoryview in this case + memslice = {{coerce_from_py_func}}(arg, 0) + else: + {{else}} + if True: + {{endif}} + memslice = {{coerce_from_py_func}}(arg_as_memoryview, 0) if memslice.memview: - __PYX_XDEC_MEMVIEW(&memslice, 1) + __PYX_XCLEAR_MEMVIEW(&memslice, 1) # print 'found a match for the buffer through format parsing' %s break @@ -410,7 +432,7 @@ class FusedCFuncDefNode(StatListNode): __pyx_PyErr_Clear() """ % self.match) - def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env): + def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, accept_none, env): """ Generate Cython code to match objects to buffer specializations. First try to get a numpy dtype object and match it against the individual @@ -467,9 +489,35 @@ class FusedCFuncDefNode(StatListNode): self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types) pyx_code.dedent(2) - for specialized_type in buffer_types: - self._buffer_parse_format_string_check( - pyx_code, decl_code, specialized_type, env) + if accept_none: + # If None is acceptable, then Cython <3.0 matched None with the + # first type. This behaviour isn't ideal, but keep it for backwards + # compatibility. Better behaviour would be to see if subsequent + # arguments give a stronger match. + pyx_code.context.update( + specialized_type_name=buffer_types[0].specialization_string + ) + pyx_code.put_chunk( + """ + if arg is None: + %s + break + """ % self.match) + + # creating a Cython memoryview from a Python memoryview avoids the + # need to get the buffer multiple times, and we can + # also use it to check itemsizes etc + pyx_code.put_chunk( + """ + try: + arg_as_memoryview = memoryview(arg) + except (ValueError, TypeError): + pass + """) + with pyx_code.indenter("else:"): + for specialized_type in buffer_types: + self._buffer_parse_format_string_check( + pyx_code, decl_code, specialized_type, env) def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types): """ @@ -481,8 +529,9 @@ class FusedCFuncDefNode(StatListNode): ctypedef struct {{memviewslice_cname}}: void *memview - void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil) + void __PYX_XCLEAR_MEMVIEW({{memviewslice_cname}} *, int have_gil) bint __pyx_memoryview_check(object) + bint __PYX_IS_PYPY2 "(CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION == 2)" """) pyx_code.local_variable_declarations.put_chunk( @@ -507,6 +556,12 @@ class FusedCFuncDefNode(StatListNode): ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable() """) + pyx_code.imports.put_chunk( + u""" + cdef memoryview arg_as_memoryview + """ + ) + seen_typedefs = set() seen_int_dtypes = set() for buffer_type in all_buffer_types: @@ -580,6 +635,26 @@ class FusedCFuncDefNode(StatListNode): {{endif}} """) + def _fused_signature_index(self, pyx_code): + """ + Generate Cython code for constructing a persistent nested dictionary index of + fused type specialization signatures. + """ + pyx_code.put_chunk( + u""" + if not _fused_sigindex: + for sig in <dict>signatures: + sigindex_node = _fused_sigindex + *sig_series, last_type = sig.strip('()').split('|') + for sig_type in sig_series: + if sig_type not in sigindex_node: + sigindex_node[sig_type] = sigindex_node = {} + else: + sigindex_node = sigindex_node[sig_type] + sigindex_node[last_type] = sig + """ + ) + def make_fused_cpdef(self, orig_py_func, env, is_def): """ This creates the function that is indexable from Python and does @@ -616,10 +691,14 @@ class FusedCFuncDefNode(StatListNode): pyx_code.put_chunk( u""" - def __pyx_fused_cpdef(signatures, args, kwargs, defaults): + def __pyx_fused_cpdef(signatures, args, kwargs, defaults, _fused_sigindex={}): # FIXME: use a typed signature - currently fails badly because # default arguments inherit the types we specify here! + cdef list search_list + + cdef dict sn, sigindex_node + dest_sig = [None] * {{n_fused}} if kwargs is not None and not kwargs: @@ -630,7 +709,7 @@ class FusedCFuncDefNode(StatListNode): # instance check body """) - pyx_code.indent() # indent following code to function body + pyx_code.indent() # indent following code to function body pyx_code.named_insertion_point("imports") pyx_code.named_insertion_point("func_defs") pyx_code.named_insertion_point("local_variable_declarations") @@ -661,19 +740,20 @@ class FusedCFuncDefNode(StatListNode): self._unpack_argument(pyx_code) # 'unrolled' loop, first match breaks out of it - if pyx_code.indenter("while 1:"): + with pyx_code.indenter("while 1:"): if normal_types: self._fused_instance_checks(normal_types, pyx_code, env) if buffer_types or pythran_types: env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c")) - self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env) + self._buffer_checks( + buffer_types, pythran_types, pyx_code, decl_code, + arg.accept_none, env) if has_object_fallback: pyx_code.context.update(specialized_type_name='object') pyx_code.putln(self.match) else: pyx_code.putln(self.no_match) pyx_code.putln("break") - pyx_code.dedent() fused_index += 1 all_buffer_types.update(buffer_types) @@ -687,23 +767,36 @@ class FusedCFuncDefNode(StatListNode): env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c")) env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c")) + self._fused_signature_index(pyx_code) + pyx_code.put_chunk( u""" - candidates = [] - for sig in <dict>signatures: - match_found = False - src_sig = sig.strip('()').split('|') - for i in range(len(dest_sig)): - dst_type = dest_sig[i] - if dst_type is not None: - if src_sig[i] == dst_type: - match_found = True - else: - match_found = False - break + sigindex_matches = [] + sigindex_candidates = [_fused_sigindex] + + for dst_type in dest_sig: + found_matches = [] + found_candidates = [] + # Make two seperate lists: One for signature sub-trees + # with at least one definite match, and another for + # signature sub-trees with only ambiguous matches + # (where `dest_sig[i] is None`). + if dst_type is None: + for sn in sigindex_matches: + found_matches.extend(sn.values()) + for sn in sigindex_candidates: + found_candidates.extend(sn.values()) + else: + for search_list in (sigindex_matches, sigindex_candidates): + for sn in search_list: + if dst_type in sn: + found_matches.append(sn[dst_type]) + sigindex_matches = found_matches + sigindex_candidates = found_candidates + if not (found_matches or found_candidates): + break - if match_found: - candidates.append(sig) + candidates = sigindex_matches if not candidates: raise TypeError("No matching signature found") @@ -792,16 +885,18 @@ class FusedCFuncDefNode(StatListNode): for arg in self.node.args: if arg.default: arg.default = arg.default.analyse_expressions(env) - defaults.append(ProxyNode(arg.default)) + # coerce the argument to temp since CloneNode really requires a temp + defaults.append(ProxyNode(arg.default.coerce_to_temp(env))) else: defaults.append(None) for i, stat in enumerate(self.stats): stat = self.stats[i] = stat.analyse_expressions(env) - if isinstance(stat, FuncDefNode): + if isinstance(stat, FuncDefNode) and stat is not self.py_func: + # the dispatcher specifically doesn't want its defaults overriding for arg, default in zip(stat.args, defaults): if default is not None: - arg.default = CloneNode(default).coerce_to(arg.type, env) + arg.default = CloneNode(default).analyse_expressions(env).coerce_to(arg.type, env) if self.py_func: args = [CloneNode(default) for default in defaults if default] @@ -829,6 +924,10 @@ class FusedCFuncDefNode(StatListNode): else: nodes = self.nodes + # For the moment, fused functions do not support METH_FASTCALL + for node in nodes: + node.entry.signature.use_fastcall = False + signatures = [StringEncoding.EncodedString(node.specialized_signature_string) for node in nodes] keys = [ExprNodes.StringNode(node.pos, value=sig) @@ -847,8 +946,10 @@ class FusedCFuncDefNode(StatListNode): self.py_func.pymethdef_required = True self.fused_func_assignment.generate_function_definitions(env, code) + from . import Options for stat in self.stats: - if isinstance(stat, FuncDefNode) and stat.entry.used: + from_pyx = Options.cimport_from_pyx and not stat.entry.visibility == 'extern' + if isinstance(stat, FuncDefNode) and (stat.entry.used or from_pyx): code.mark_pos(stat.pos) stat.generate_function_definitions(env, code) @@ -877,7 +978,7 @@ class FusedCFuncDefNode(StatListNode): "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" % (self.resulting_fused_function.result(), self.__signatures__.result())) - code.put_giveref(self.__signatures__.result()) + self.__signatures__.generate_giveref(code) self.__signatures__.generate_post_assignment_code(code) self.__signatures__.free_temps(code) |