diff options
author | will-ca <37680486+will-ca@users.noreply.github.com> | 2020-03-26 04:06:12 -0700 |
---|---|---|
committer | Stefan Behnel <stefan_ml@behnel.de> | 2020-03-31 11:33:02 +0200 |
commit | 75aeda62c1a2bcdb7a3edc91766f0d3666e4ffd0 (patch) | |
tree | 3556ede6a4bd7e4fd96f20b411fee2c7b8697466 | |
parent | b2833f2ebe7c9567883a0140d68deb2e40d4e09d (diff) | |
download | cython-75aeda62c1a2bcdb7a3edc91766f0d3666e4ffd0.tar.gz |
Make `Shadow.inline()` caching account for language version and compilation environment. (GH-3440)
Closes https://github.com/cython/cython/issues/3419
-rw-r--r-- | Cython/Build/Inline.py | 23 | ||||
-rw-r--r-- | Cython/Build/Tests/TestInline.py | 12 |
2 files changed, 27 insertions, 8 deletions
diff --git a/Cython/Build/Inline.py b/Cython/Build/Inline.py index 47be53427..fdd38e21d 100644 --- a/Cython/Build/Inline.py +++ b/Cython/Build/Inline.py @@ -144,6 +144,10 @@ def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None): else: print("Couldn't find %r" % symbol) +def _inline_key(orig_code, arg_sigs, language_level): + key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__ + return hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest() + def cython_inline(code, get_type=unsafe_type, lib_dir=os.path.join(get_cython_cache_dir(), 'inline'), cython_include_dirs=None, cython_compiler_directives=None, @@ -153,13 +157,20 @@ def cython_inline(code, get_type=unsafe_type, get_type = lambda x: 'object' ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context + cython_compiler_directives = dict(cython_compiler_directives or {}) + if language_level is None and 'language_level' not in cython_compiler_directives: + language_level = '3str' + if language_level is not None: + cython_compiler_directives['language_level'] = language_level + # Fast path if this has been called in this session. _unbound_symbols = _cython_inline_cache.get(code) if _unbound_symbols is not None: _populate_unbound(kwds, _unbound_symbols, locals, globals) args = sorted(kwds.items()) arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args]) - invoke = _cython_inline_cache.get((code, arg_sigs)) + key_hash = _inline_key(code, arg_sigs, language_level) + invoke = _cython_inline_cache.get((code, arg_sigs, key_hash)) if invoke is not None: arg_list = [arg[1] for arg in args] return invoke(*arg_list) @@ -180,10 +191,6 @@ def cython_inline(code, get_type=unsafe_type, # Parsing from strings not fully supported (e.g. cimports). print("Could not parse code as a string (to extract unbound symbols).") - cython_compiler_directives = dict(cython_compiler_directives or {}) - if language_level is not None: - cython_compiler_directives['language_level'] = language_level - cimports = [] for name, arg in list(kwds.items()): if arg is cython_module: @@ -191,8 +198,8 @@ def cython_inline(code, get_type=unsafe_type, del kwds[name] arg_names = sorted(kwds) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) - key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__ - module_name = "_cython_inline_" + hashlib.md5(_unicode(key).encode('utf-8')).hexdigest() + key_hash = _inline_key(orig_code, arg_sigs, language_level) + module_name = "_cython_inline_" + key_hash if module_name in sys.modules: module = sys.modules[module_name] @@ -259,7 +266,7 @@ def __invoke(%(params)s): module = load_dynamic(module_name, module_path) - _cython_inline_cache[orig_code, arg_sigs] = module.__invoke + _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke arg_list = [kwds[arg] for arg in arg_names] return module.__invoke(*arg_list) diff --git a/Cython/Build/Tests/TestInline.py b/Cython/Build/Tests/TestInline.py index 0a40e0de5..5ef9fec4e 100644 --- a/Cython/Build/Tests/TestInline.py +++ b/Cython/Build/Tests/TestInline.py @@ -74,6 +74,18 @@ class TestInline(CythonTest): 6 ) + def test_lang_version(self): + # GH-3419. Caching for inline code didn't always respect compiler directives. + inline_divcode = "def f(int a, int b): return a/b" + self.assertEqual( + inline(inline_divcode, language_level=2)['f'](5,2), + 2 + ) + self.assertEqual( + inline(inline_divcode, language_level=3)['f'](5,2), + 2.5 + ) + if has_numpy: def test_numpy(self): |