diff options
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 270 |
1 files changed, 136 insertions, 134 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 87fe36944..70c4dc133 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -32,6 +32,11 @@ from .. import util from ..util import langhelpers from ..util import symbol +try: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa +except ImportError: + from ._py_util import cache_anon_map as anon_map # noqa + __all__ = [ "iterate", "traverse_using", @@ -39,88 +44,77 @@ __all__ = [ "cloned_traverse", "replacement_traverse", "Traversible", - "TraversibleType", "ExternalTraversal", "InternalTraversal", ] -def _generate_compiler_dispatch(cls): - """Generate a _compiler_dispatch() external traversal on classes with a - __visit_name__ attribute. - - """ - visit_name = cls.__visit_name__ - - if "_compiler_dispatch" in cls.__dict__: - # class has a fixed _compiler_dispatch() method. - # copy it to "original" so that we can get it back if - # sqlalchemy.ext.compiles overrides it. - cls._original_compiler_dispatch = cls._compiler_dispatch - return - - if not isinstance(visit_name, str): - raise exc.InvalidRequestError( - "__visit_name__ on class %s must be a string at the class level" - % cls.__name__ - ) - - name = "visit_%s" % visit_name - getter = operator.attrgetter(name) - - def _compiler_dispatch(self, visitor, **kw): - """Look for an attribute named "visit_<visit_name>" on the - visitor, and call it with the same kw params. - - """ - try: - meth = getter(visitor) - except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) - - else: - return meth(self, **kw) - - cls._compiler_dispatch = ( - cls._original_compiler_dispatch - ) = _compiler_dispatch +class Traversible: + """Base class for visitable objects.""" + __slots__ = () -class TraversibleType(type): - """Metaclass which assigns dispatch attributes to various kinds of - "visitable" classes. + __visit_name__: str - Attributes include: + def __init_subclass__(cls) -> None: + if "__visit_name__" in cls.__dict__: + cls._generate_compiler_dispatch() + super().__init_subclass__() - * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. - This is called "external traversal" because the caller of each visit() - method is responsible for sub-traversing the inner elements of each - object. This is appropriate for string compilers and other traversals - that need to call upon the inner elements in a specific pattern. + @classmethod + def _generate_compiler_dispatch(cls): + """Assign dispatch attributes to various kinds of + "visitable" classes. - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from - an optional ``_traverse_internals`` collection of symbols which comes - from the :class:`.InternalTraversal` list of symbols. This is called - "internal traversal" MARKMARK + Attributes include: - """ + * The ``_compiler_dispatch`` method, corresponding to + ``__visit_name__``. This is called "external traversal" because the + caller of each visit() method is responsible for sub-traversing the + inner elements of each object. This is appropriate for string + compilers and other traversals that need to call upon the inner + elements in a specific pattern. - def __init__(cls, clsname, bases, clsdict): - if clsname != "Traversible": - if "__visit_name__" in clsdict: - _generate_compiler_dispatch(cls) + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated + from an optional ``_traverse_internals`` collection of symbols which + comes from the :class:`.InternalTraversal` list of symbols. This is + called "internal traversal". - super(TraversibleType, cls).__init__(clsname, bases, clsdict) + """ + visit_name = cls.__visit_name__ + + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + + if not isinstance(visit_name, str): + raise exc.InvalidRequestError( + f"__visit_name__ on class {cls.__name__} must be a string " + "at the class level" + ) + name = "visit_%s" % visit_name + getter = operator.attrgetter(name) -class Traversible(metaclass=TraversibleType): - """Base class for visitable objects, applies the - :class:`.visitors.TraversibleType` metaclass. + def _compiler_dispatch(self, visitor, **kw): + """Look for an attribute named "visit_<visit_name>" on the + visitor, and call it with the same kw params. - """ + """ + try: + meth = getter(visitor) + except AttributeError as err: + return visitor.visit_unsupported_compilation(self, err, **kw) + else: + return meth(self, **kw) - __slots__ = () + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch def __class_getitem__(cls, key): # allow generic classes in py3.9+ @@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType): ) -class _InternalTraversalType(type): - def __init__(cls, clsname, bases, clsdict): - if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): - lookup = {} - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup +class _HasTraversalDispatch: + r"""Define infrastructure for the :class:`.InternalTraversal` class. - super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + .. versionadded:: 2.0 + """ -def _generate_dispatcher(visitor, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = visitor.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names + def __init_subclass__(cls) -> None: + cls._generate_traversal_dispatch() + super().__init_subclass__() + + def dispatch(self, visit_symbol): + """Given a method from :class:`._HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. + dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, target_cls, internal_dispatch, generate_dispatcher_name + ): + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + @classmethod + def _generate_traversal_dispatch(cls): + lookup = {} + clsdict = cls.__dict__ + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + def _generate_dispatcher(self, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[ + visit_sym + ] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) ) + + ("\n ]\n") ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - # print(meth_text) - return langhelpers._exec_code_in_env(meth_text, {}, method_name) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) -class InternalTraversal(metaclass=_InternalTraversalType): +class InternalTraversal(_HasTraversalDispatch): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType): """ - def dispatch(self, visit_symbol): - """Given a method from :class:`.InternalTraversal`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. - dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = _generate_dispatcher( - self, internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" @@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal): # backwards compatibility Visitable = Traversible -VisitableType = TraversibleType ClauseVisitor = ExternalTraversal CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal |