diff options
author | will-ca <37680486+will-ca@users.noreply.github.com> | 2020-04-01 03:08:37 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-01 12:08:37 +0200 |
commit | 27b5adbb461675ef775aee46d17b0a6d3b2c047e (patch) | |
tree | 170506544e7ac4c7257ea47a3ca1eabdf3440b9b | |
parent | 4fd901849ff30399289ebf6613c81d8075a79cc4 (diff) | |
download | cython-27b5adbb461675ef775aee46d17b0a6d3b2c047e.tar.gz |
Make fused function dispatch O(n) for `cpdef` functions. (GH-3366)
* Rewrote signature matching for fused cpdef function dispatch to use a pre-built tree index in a mutable default argument and be O(n).
* Added test to ensure proper differentiation between ambiguously compatible and definitely compatible arguments.
* Added test to ensure fused cpdef's can be called by the module itself during import.
* Added test to ensure consistent handling of ambiguous fused cpdef signatures.
* Test for explicitly defined fused cpdef method.
* Add .komodoproject to .gitignore.
* Add /cython_debug/ to .gitignore.
Closes #1385.
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | Cython/Compiler/FusedNode.py | 67 | ||||
-rw-r--r-- | tests/run/fused_cpdef.pyx | 104 |
3 files changed, 148 insertions, 26 deletions
diff --git a/.gitignore b/.gitignore index 665ba4743..fd5461a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ Demos/*/*.html /TEST_TMP/ /build/ +/cython_build/ /wheelhouse*/ !tests/build/ /dist/ @@ -52,3 +53,5 @@ MANIFEST /.idea /*.iml +# Komodo EDIT/IDE project files +/*.komodoproject diff --git a/Cython/Compiler/FusedNode.py b/Cython/Compiler/FusedNode.py index dbf3bbeac..42f15eea3 100644 --- a/Cython/Compiler/FusedNode.py +++ b/Cython/Compiler/FusedNode.py @@ -584,6 +584,26 @@ class FusedCFuncDefNode(StatListNode): {{endif}} """) + def _fused_signature_index(self, pyx_code): + """ + Generate Cython code for constructing a persistent nested dictionary index of + fused type specialization signatures. + """ + pyx_code.put_chunk( + u""" + if not _fused_sigindex: + for sig in <dict>signatures: + sigindex_node = _fused_sigindex + sig_series = sig.strip('()').split('|') + for sig_type in sig_series[:-1]: + if sig_type not in sigindex_node: + sigindex_node[sig_type] = sigindex_node = {} + else: + sigindex_node = sigindex_node[sig_type] + sigindex_node[sig_series[-1]] = sig + """ + ) + def make_fused_cpdef(self, orig_py_func, env, is_def): """ This creates the function that is indexable from Python and does @@ -620,10 +640,14 @@ class FusedCFuncDefNode(StatListNode): pyx_code.put_chunk( u""" - def __pyx_fused_cpdef(signatures, args, kwargs, defaults): + def __pyx_fused_cpdef(signatures, args, kwargs, defaults, *, _fused_sigindex={}): # FIXME: use a typed signature - currently fails badly because # default arguments inherit the types we specify here! + cdef list search_list + + cdef dict sn, sigindex_node + dest_sig = [None] * {{n_fused}} if kwargs is not None and not kwargs: @@ -691,23 +715,36 @@ class FusedCFuncDefNode(StatListNode): env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c")) env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c")) + self._fused_signature_index(pyx_code) + pyx_code.put_chunk( u""" - candidates = [] - for sig in <dict>signatures: - match_found = False - src_sig = sig.strip('()').split('|') - for i in range(len(dest_sig)): - dst_type = dest_sig[i] - if dst_type is not None: - if src_sig[i] == dst_type: - match_found = True - else: - match_found = False - break + sigindex_matches = [] + sigindex_candidates = [_fused_sigindex] + + for dst_type in dest_sig: + found_matches = [] + found_candidates = [] + # Make two seperate lists: One for signature sub-trees + # with at least one definite match, and another for + # signature sub-trees with only ambiguous matches + # (where `dest_sig[i] is None`). + if dst_type is None: + for sn in sigindex_matches: + found_matches.extend(sn.values()) + for sn in sigindex_candidates: + found_candidates.extend(sn.values()) + else: + for search_list in (sigindex_matches, sigindex_candidates): + for sn in search_list: + if dst_type in sn: + found_matches.append(sn[dst_type]) + sigindex_matches = found_matches + sigindex_candidates = found_candidates + if not (found_matches or found_candidates): + break - if match_found: - candidates.append(sig) + candidates = sigindex_matches if not candidates: raise TypeError("No matching signature found") diff --git a/tests/run/fused_cpdef.pyx b/tests/run/fused_cpdef.pyx index 0b63c8b98..4a614e0f4 100644 --- a/tests/run/fused_cpdef.pyx +++ b/tests/run/fused_cpdef.pyx @@ -1,13 +1,17 @@ +# cython: language_level=3 +# mode: run + cimport cython +import sys, io cy = __import__("cython") cpdef func1(self, cython.integral x): - print "%s," % (self,), + print(f"{self},", end=' ') if cython.integral is int: - print 'x is int', x, cython.typeof(x) + print('x is int', x, cython.typeof(x)) else: - print 'x is long', x, cython.typeof(x) + print('x is long', x, cython.typeof(x)) class A(object): @@ -16,6 +20,18 @@ class A(object): def __str__(self): return "A" +cdef class B: + cpdef int meth(self, cython.integral x): + print(f"{self},", end=' ') + if cython.integral is int: + print('x is int', x, cython.typeof(x)) + else: + print('x is long', x, cython.typeof(x)) + return 0 + + def __str__(self): + return "B" + pyfunc = func1 def test_fused_cpdef(): @@ -32,23 +48,71 @@ def test_fused_cpdef(): A, x is long 2 long A, x is long 2 long A, x is long 2 long + <BLANKLINE> + B, x is long 2 long """ func1[int](None, 2) func1[long](None, 2) func1(None, 2) - print + print() pyfunc[cy.int](None, 2) pyfunc(None, 2) - print + print() A.meth[cy.int](A(), 2) A.meth(A(), 2) A().meth[cy.long](2) A().meth(2) + print() + + B().meth(2) + + +midimport_run = io.StringIO() +if sys.version_info.major < 3: + # Monkey-patch midimport_run.write to accept non-unicode strings under Python 2. + midimport_run.write = lambda c: io.StringIO.write(midimport_run, unicode(c)) + +realstdout = sys.stdout +sys.stdout = midimport_run + +try: + # Run `test_fused_cpdef()` during import and save the result for + # `test_midimport_run()`. + test_fused_cpdef() +except Exception as e: + midimport_run.write(f"{e!r}\n") +finally: + sys.stdout = realstdout + +def test_midimport_run(): + # At one point, dynamically calling fused cpdef functions during import + # would fail because the type signature-matching indices weren't + # yet initialized. + # (See Compiler.FusedNode.FusedCFuncDefNode._fused_signature_index, + # GH-3366.) + """ + >>> test_midimport_run() + None, x is int 2 int + None, x is long 2 long + None, x is long 2 long + <BLANKLINE> + None, x is int 2 int + None, x is long 2 long + <BLANKLINE> + A, x is int 2 int + A, x is long 2 long + A, x is long 2 long + A, x is long 2 long + <BLANKLINE> + B, x is long 2 long + """ + print(midimport_run.getvalue(), end='') + def assert_raise(func, *args): try: @@ -70,23 +134,31 @@ def test_badcall(): assert_raise(A.meth) assert_raise(A().meth[cy.int]) assert_raise(A.meth[cy.int]) + assert_raise(B().meth, 1, 2, 3) + +def test_nomatch(): + """ + >>> func1(None, ()) + Traceback (most recent call last): + TypeError: No matching signature found + """ ctypedef long double long_double cpdef multiarg(cython.integral x, cython.floating y): if cython.integral is int: - print "x is an int,", + print("x is an int,", end=' ') else: - print "x is a long,", + print("x is a long,", end=' ') if cython.floating is long_double: - print "y is a long double:", + print("y is a long double:", end=' ') elif float is cython.floating: - print "y is a float:", + print("y is a float:", end=' ') else: - print "y is a double:", + print("y is a double:", end=' ') - print x, y + print(x, y) def test_multiarg(): """ @@ -104,3 +176,13 @@ def test_multiarg(): multiarg[int, float](1, 2.0) multiarg[cy.int, cy.float](1, 2.0) multiarg(4, 5.0) + +def test_ambiguousmatch(): + """ + >>> multiarg(5, ()) + Traceback (most recent call last): + TypeError: Function call with ambiguous argument types + >>> multiarg((), 2.0) + Traceback (most recent call last): + TypeError: Function call with ambiguous argument types + """ |