summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/cyextension/util.pyx42
-rw-r--r--lib/sqlalchemy/sql/_py_util.py59
-rw-r--r--lib/sqlalchemy/sql/base.py27
-rw-r--r--lib/sqlalchemy/sql/elements.py11
-rw-r--r--lib/sqlalchemy/sql/traversals.py38
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py15
-rw-r--r--lib/sqlalchemy/util/langhelpers.py4
7 files changed, 137 insertions, 59 deletions
diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx
index ac15ff9de..62ca960b3 100644
--- a/lib/sqlalchemy/cyextension/util.pyx
+++ b/lib/sqlalchemy/cyextension/util.pyx
@@ -41,3 +41,45 @@ def _distill_raw_params(object params):
return [params]
else:
raise exc.ArgumentError("mapping or sequence expected for parameters")
+
+cdef class prefix_anon_map(dict):
+ def __missing__(self, str key):
+ cdef str derived
+ cdef int anonymous_counter
+ cdef dict self_dict = self
+
+ derived = key.split(" ", 1)[1]
+
+ anonymous_counter = self_dict.get(derived, 1)
+ self_dict[derived] = anonymous_counter + 1
+ value = f"{derived}_{anonymous_counter}"
+ self_dict[key] = value
+ return value
+
+
+cdef class cache_anon_map(dict):
+ cdef int _index
+
+ def __init__(self):
+ self._index = 0
+
+ def get_anon(self, obj):
+ cdef long idself
+ cdef str id_
+ cdef dict self_dict = self
+
+ idself = id(obj)
+ if idself in self_dict:
+ return self_dict[idself], True
+ else:
+ id_ = self.__missing__(idself)
+ return id_, False
+
+ def __missing__(self, key):
+ cdef str val
+ cdef dict self_dict = self
+
+ self_dict[key] = val = str(self._index)
+ self._index += 1
+ return val
+
diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py
new file mode 100644
index 000000000..ceb637609
--- /dev/null
+++ b/lib/sqlalchemy/sql/_py_util.py
@@ -0,0 +1,59 @@
+# sql/_py_util.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = f"{derived}_{anonymous_counter}"
+ self[key] = value
+ return value
+
+
+class cache_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ _index = 0
+
+ def get_anon(self, object_):
+
+ idself = id(object_)
+ if idself in self:
+ return self[idself], True
+ else:
+ # inline of __missing__
+ self[idself] = id_ = str(self._index)
+ self._index += 1
+
+ return id_, False
+
+ def __missing__(self, key):
+ self[key] = val = str(self._index)
+ self._index += 1
+ return val
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 4165751ca..b5a20830d 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -30,6 +30,12 @@ from .. import util
from ..util import HasMemoized
from ..util import hybridmethod
+try:
+ from sqlalchemy.cyextension.util import prefix_anon_map # noqa
+except ImportError:
+ from ._py_util import prefix_anon_map # noqa
+
+
coercions = None
elements = None
type_api = None
@@ -1012,27 +1018,6 @@ class Executable(roles.StatementRole, Generative):
return self._execution_options
-class prefix_anon_map(dict):
- """A map that creates new keys for missing key access.
-
- Considers keys of the form "<ident> <name>" to produce
- new symbols "<name>_<index>", where "index" is an incrementing integer
- corresponding to <name>.
-
- Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
- is otherwise usually used for this type of operation.
-
- """
-
- def __missing__(self, key):
- (ident, derived) = key.split(" ", 1)
- anonymous_counter = self.get(derived, 1)
- self[derived] = anonymous_counter + 1
- value = derived + "_" + str(anonymous_counter)
- self[key] = value
- return value
-
-
class SchemaEventTarget:
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 00270c9b5..08c993820 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -1659,14 +1659,9 @@ class BindParameter(roles.InElementRole, ColumnElement):
anon_map[NO_CACHE] = True
return None
- idself = id(self)
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
+ id_, found = anon_map.get_anon(self)
+ if found:
+ return (id_, self.__class__)
if bindparams is not None:
bindparams.append(self)
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index d58b5c2bb..22398e7c1 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -12,6 +12,12 @@ from .. import util
from ..inspection import inspect
from ..util import HasMemoized
+try:
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+except ImportError:
+ from ._py_util import cache_anon_map as anon_map # noqa
+
+
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
@@ -177,16 +183,11 @@ class HasCacheKey:
"""
- idself = id(self)
cls = self.__class__
- if idself in anon_map:
- return (anon_map[idself], cls)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
+ id_, found = anon_map.get_anon(self)
+ if found:
+ return (id_, cls)
try:
dispatcher = cls.__dict__["_generated_cache_key_traversal"]
@@ -1030,27 +1031,6 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
return name
-class anon_map(dict):
- """A map that creates new keys for missing key access.
-
- Produces an incrementing sequence given a series of unique keys.
-
- This is similar to the compiler prefix_anon_map class although simpler.
-
- Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
- is otherwise usually used for this type of operation.
-
- """
-
- def __init__(self):
- self.index = 0
-
- def __missing__(self, key):
- self[key] = val = str(self.index)
- self.index += 1
- return val
-
-
class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 7bc88a14b..2a6691fc8 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -16,6 +16,7 @@ is pytest.
import abc
import configparser
import logging
+import os
import re
import sys
@@ -370,6 +371,20 @@ def _monkeypatch_cdecimal(options, file_config):
@post
+def __ensure_cext(opt, file_config):
+ if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
+ from sqlalchemy.util import has_compiled_ext
+
+ try:
+ has_compiled_ext(raise_=True)
+ except ImportError as err:
+ raise AssertionError(
+ "REQUIRE_SQLALCHEMY_CEXT is set but can't import the "
+ "cython extensions"
+ ) from err
+
+
+@post
def _init_symbols(options, file_config):
from sqlalchemy.testing import config
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index b759490c5..dc08b0494 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1897,7 +1897,7 @@ def repr_tuple_names(names):
return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
-def has_compiled_ext():
+def has_compiled_ext(raise_=False):
try:
from sqlalchemy.cyextension import collections # noqa F401
from sqlalchemy.cyextension import immutabledict # noqa F401
@@ -1907,4 +1907,6 @@ def has_compiled_ext():
return True
except ImportError:
+ if raise_:
+ raise
return False