summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r--lib/sqlalchemy/sql/visitors.py270
1 files changed, 136 insertions, 134 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 87fe36944..70c4dc133 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -32,6 +32,11 @@ from .. import util
from ..util import langhelpers
from ..util import symbol
+try:
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+except ImportError:
+ from ._py_util import cache_anon_map as anon_map # noqa
+
__all__ = [
"iterate",
"traverse_using",
@@ -39,88 +44,77 @@ __all__ = [
"cloned_traverse",
"replacement_traverse",
"Traversible",
- "TraversibleType",
"ExternalTraversal",
"InternalTraversal",
]
-def _generate_compiler_dispatch(cls):
- """Generate a _compiler_dispatch() external traversal on classes with a
- __visit_name__ attribute.
-
- """
- visit_name = cls.__visit_name__
-
- if "_compiler_dispatch" in cls.__dict__:
- # class has a fixed _compiler_dispatch() method.
- # copy it to "original" so that we can get it back if
- # sqlalchemy.ext.compiles overrides it.
- cls._original_compiler_dispatch = cls._compiler_dispatch
- return
-
- if not isinstance(visit_name, str):
- raise exc.InvalidRequestError(
- "__visit_name__ on class %s must be a string at the class level"
- % cls.__name__
- )
-
- name = "visit_%s" % visit_name
- getter = operator.attrgetter(name)
-
- def _compiler_dispatch(self, visitor, **kw):
- """Look for an attribute named "visit_<visit_name>" on the
- visitor, and call it with the same kw params.
-
- """
- try:
- meth = getter(visitor)
- except AttributeError as err:
- return visitor.visit_unsupported_compilation(self, err, **kw)
-
- else:
- return meth(self, **kw)
-
- cls._compiler_dispatch = (
- cls._original_compiler_dispatch
- ) = _compiler_dispatch
+class Traversible:
+ """Base class for visitable objects."""
+ __slots__ = ()
-class TraversibleType(type):
- """Metaclass which assigns dispatch attributes to various kinds of
- "visitable" classes.
+ __visit_name__: str
- Attributes include:
+ def __init_subclass__(cls) -> None:
+ if "__visit_name__" in cls.__dict__:
+ cls._generate_compiler_dispatch()
+ super().__init_subclass__()
- * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
- This is called "external traversal" because the caller of each visit()
- method is responsible for sub-traversing the inner elements of each
- object. This is appropriate for string compilers and other traversals
- that need to call upon the inner elements in a specific pattern.
+ @classmethod
+ def _generate_compiler_dispatch(cls):
+ """Assign dispatch attributes to various kinds of
+ "visitable" classes.
- * internal traversal collections ``_children_traversal``,
- ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
- an optional ``_traverse_internals`` collection of symbols which comes
- from the :class:`.InternalTraversal` list of symbols. This is called
- "internal traversal" MARKMARK
+ Attributes include:
- """
+ * The ``_compiler_dispatch`` method, corresponding to
+ ``__visit_name__``. This is called "external traversal" because the
+ caller of each visit() method is responsible for sub-traversing the
+ inner elements of each object. This is appropriate for string
+ compilers and other traversals that need to call upon the inner
+ elements in a specific pattern.
- def __init__(cls, clsname, bases, clsdict):
- if clsname != "Traversible":
- if "__visit_name__" in clsdict:
- _generate_compiler_dispatch(cls)
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
+ from an optional ``_traverse_internals`` collection of symbols which
+ comes from the :class:`.InternalTraversal` list of symbols. This is
+ called "internal traversal".
- super(TraversibleType, cls).__init__(clsname, bases, clsdict)
+ """
+ visit_name = cls.__visit_name__
+
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
+ if not isinstance(visit_name, str):
+ raise exc.InvalidRequestError(
+ f"__visit_name__ on class {cls.__name__} must be a string "
+ "at the class level"
+ )
+ name = "visit_%s" % visit_name
+ getter = operator.attrgetter(name)
-class Traversible(metaclass=TraversibleType):
- """Base class for visitable objects, applies the
- :class:`.visitors.TraversibleType` metaclass.
+ def _compiler_dispatch(self, visitor, **kw):
+ """Look for an attribute named "visit_<visit_name>" on the
+ visitor, and call it with the same kw params.
- """
+ """
+ try:
+ meth = getter(visitor)
+ except AttributeError as err:
+ return visitor.visit_unsupported_compilation(self, err, **kw)
+ else:
+ return meth(self, **kw)
- __slots__ = ()
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
def __class_getitem__(cls, key):
# allow generic classes in py3.9+
@@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType):
)
-class _InternalTraversalType(type):
- def __init__(cls, clsname, bases, clsdict):
- if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
- lookup = {}
- for key, sym in clsdict.items():
- if key.startswith("dp_"):
- visit_key = key.replace("dp_", "visit_")
- sym_name = sym.name
- assert sym_name not in lookup, sym_name
- lookup[sym] = lookup[sym_name] = visit_key
- if hasattr(cls, "_dispatch_lookup"):
- lookup.update(cls._dispatch_lookup)
- cls._dispatch_lookup = lookup
+class _HasTraversalDispatch:
+ r"""Define infrastructure for the :class:`.InternalTraversal` class.
- super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+ .. versionadded:: 2.0
+ """
-def _generate_dispatcher(visitor, internal_dispatch, method_name):
- names = []
- for attrname, visit_sym in internal_dispatch:
- meth = visitor.dispatch(visit_sym)
- if meth:
- visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
- names.append((attrname, visit_name))
-
- code = (
- (" return [\n")
- + (
- ", \n".join(
- " (%r, self.%s, visitor.%s)"
- % (attrname, attrname, visit_name)
- for attrname, visit_name in names
+ def __init_subclass__(cls) -> None:
+ cls._generate_traversal_dispatch()
+ super().__init_subclass__()
+
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`._HasTraversalDispatch`, return the
+ corresponding method on a subclass.
+
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self, target_cls, internal_dispatch, generate_dispatcher_name
+ ):
+ dispatcher = self._generate_dispatcher(
+ internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ @classmethod
+ def _generate_traversal_dispatch(cls):
+ lookup = {}
+ clsdict = cls.__dict__
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ def _generate_dispatcher(self, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = self.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[
+ visit_sym
+ ]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
)
+ + ("\n ]\n")
)
- + ("\n ]\n")
- )
- meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
- # print(meth_text)
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-class InternalTraversal(metaclass=_InternalTraversalType):
+class InternalTraversal(_HasTraversalDispatch):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
@@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType):
"""
- def dispatch(self, visit_symbol):
- """Given a method from :class:`.InternalTraversal`, return the
- corresponding method on a subclass.
-
- """
- name = self._dispatch_lookup[visit_symbol]
- return getattr(self, name, None)
-
- def run_generated_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
- ):
- try:
- dispatcher = target.__class__.__dict__[generate_dispatcher_name]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = self.generate_dispatch(
- target.__class__, internal_dispatch, generate_dispatcher_name
- )
- return dispatcher(target, self)
-
- def generate_dispatch(
- self, target_cls, internal_dispatch, generate_dispatcher_name
- ):
- dispatcher = _generate_dispatcher(
- self, internal_dispatch, generate_dispatcher_name
- )
- # assert isinstance(target_cls, type)
- setattr(target_cls, generate_dispatcher_name, dispatcher)
- return dispatcher
-
dp_has_cache_key = symbol("HC")
"""Visit a :class:`.HasCacheKey` object."""
@@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
# backwards compatibility
Visitable = Traversible
-VisitableType = TraversibleType
ClauseVisitor = ExternalTraversal
CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal