diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/cyextension/util.pyx | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/_py_util.py | 59 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 38 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/plugin_base.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 4 |
7 files changed, 137 insertions, 59 deletions
diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx index ac15ff9de..62ca960b3 100644 --- a/lib/sqlalchemy/cyextension/util.pyx +++ b/lib/sqlalchemy/cyextension/util.pyx @@ -41,3 +41,45 @@ def _distill_raw_params(object params): return [params] else: raise exc.ArgumentError("mapping or sequence expected for parameters") + +cdef class prefix_anon_map(dict): + def __missing__(self, str key): + cdef str derived + cdef int anonymous_counter + cdef dict self_dict = self + + derived = key.split(" ", 1)[1] + + anonymous_counter = self_dict.get(derived, 1) + self_dict[derived] = anonymous_counter + 1 + value = f"{derived}_{anonymous_counter}" + self_dict[key] = value + return value + + +cdef class cache_anon_map(dict): + cdef int _index + + def __init__(self): + self._index = 0 + + def get_anon(self, obj): + cdef long idself + cdef str id_ + cdef dict self_dict = self + + idself = id(obj) + if idself in self_dict: + return self_dict[idself], True + else: + id_ = self.__missing__(idself) + return id_, False + + def __missing__(self, key): + cdef str val + cdef dict self_dict = self + + self_dict[key] = val = str(self._index) + self._index += 1 + return val + diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py new file mode 100644 index 000000000..ceb637609 --- /dev/null +++ b/lib/sqlalchemy/sql/_py_util.py @@ -0,0 +1,59 @@ +# sql/_py_util.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form "<ident> <name>" to produce + new symbols "<name>_<index>", where "index" is an incrementing integer + corresponding to <name>. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = f"{derived}_{anonymous_counter}" + self[key] = value + return value + + +class cache_anon_map(dict): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + _index = 0 + + def get_anon(self, object_): + + idself = id(object_) + if idself in self: + return self[idself], True + else: + # inline of __missing__ + self[idself] = id_ = str(self._index) + self._index += 1 + + return id_, False + + def __missing__(self, key): + self[key] = val = str(self._index) + self._index += 1 + return val diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 4165751ca..b5a20830d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -30,6 +30,12 @@ from .. import util from ..util import HasMemoized from ..util import hybridmethod +try: + from sqlalchemy.cyextension.util import prefix_anon_map # noqa +except ImportError: + from ._py_util import prefix_anon_map # noqa + + coercions = None elements = None type_api = None @@ -1012,27 +1018,6 @@ class Executable(roles.StatementRole, Generative): return self._execution_options -class prefix_anon_map(dict): - """A map that creates new keys for missing key access. - - Considers keys of the form "<ident> <name>" to produce - new symbols "<name>_<index>", where "index" is an incrementing integer - corresponding to <name>. - - Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which - is otherwise usually used for this type of operation. - - """ - - def __missing__(self, key): - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.get(derived, 1) - self[derived] = anonymous_counter + 1 - value = derived + "_" + str(anonymous_counter) - self[key] = value - return value - - class SchemaEventTarget: """Base class for elements that are the targets of :class:`.DDLEvents` events. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 00270c9b5..08c993820 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1659,14 +1659,9 @@ class BindParameter(roles.InElementRole, ColumnElement): anon_map[NO_CACHE] = True return None - idself = id(self) - if idself in anon_map: - return (anon_map[idself], self.__class__) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 + id_, found = anon_map.get_anon(self) + if found: + return (id_, self.__class__) if bindparams is not None: bindparams.append(self) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index d58b5c2bb..22398e7c1 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -12,6 +12,12 @@ from .. import util from ..inspection import inspect from ..util import HasMemoized +try: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa +except ImportError: + from ._py_util import cache_anon_map as anon_map # noqa + + SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False COMPARE_SUCCEEDED = True @@ -177,16 +183,11 @@ class HasCacheKey: """ - idself = id(self) cls = self.__class__ - if idself in anon_map: - return (anon_map[idself], cls) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) try: dispatcher = cls.__dict__["_generated_cache_key_traversal"] @@ -1030,27 +1031,6 @@ def _resolve_name_for_compare(element, name, anon_map, **kw): return name -class anon_map(dict): - """A map that creates new keys for missing key access. - - Produces an incrementing sequence given a series of unique keys. - - This is similar to the compiler prefix_anon_map class although simpler. - - Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which - is otherwise usually used for this type of operation. - - """ - - def __init__(self): - self.index = 0 - - def __missing__(self, key): - self[key] = val = str(self.index) - self.index += 1 - return val - - class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): __slots__ = "stack", "cache", "anon_map" diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 7bc88a14b..2a6691fc8 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -16,6 +16,7 @@ is pytest. import abc import configparser import logging +import os import re import sys @@ -370,6 +371,20 @@ def _monkeypatch_cdecimal(options, file_config): @post +def __ensure_cext(opt, file_config): + if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1": + from sqlalchemy.util import has_compiled_ext + + try: + has_compiled_ext(raise_=True) + except ImportError as err: + raise AssertionError( + "REQUIRE_SQLALCHEMY_CEXT is set but can't import the " + "cython extensions" + ) from err + + +@post def _init_symbols(options, file_config): from sqlalchemy.testing import config diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index b759490c5..dc08b0494 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1897,7 +1897,7 @@ def repr_tuple_names(names): return "%s, ..., %s" % (", ".join(res[0:3]), res[-1]) -def has_compiled_ext(): +def has_compiled_ext(raise_=False): try: from sqlalchemy.cyextension import collections # noqa F401 from sqlalchemy.cyextension import immutabledict # noqa F401 @@ -1907,4 +1907,6 @@ def has_compiled_ext(): return True except ImportError: + if raise_: + raise return False |