diff options
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index b39ec8167..bf1743643 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -29,11 +29,20 @@ from .. import util import operator from .. import exc -__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', - 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', - 'iterate_depthfirst', 'traverse_using', 'traverse', - 'traverse_depthfirst', - 'cloned_traverse', 'replacement_traverse'] +__all__ = [ + "VisitableType", + "Visitable", + "ClauseVisitor", + "CloningVisitor", + "ReplacingCloningVisitor", + "iterate", + "iterate_depthfirst", + "traverse_using", + "traverse", + "traverse_depthfirst", + "cloned_traverse", + "replacement_traverse", +] class VisitableType(type): @@ -53,8 +62,7 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if clsname != 'Visitable' and \ - hasattr(cls, '__visit_name__'): + if clsname != "Visitable" and hasattr(cls, "__visit_name__"): _generate_dispatch(cls) super(VisitableType, cls).__init__(clsname, bases, clsdict) @@ -64,7 +72,7 @@ def _generate_dispatch(cls): """Return an optimized visit dispatch function for the cls for use by the compiler. """ - if '__visit_name__' in cls.__dict__: + if "__visit_name__" in cls.__dict__: visit_name = cls.__visit_name__ if isinstance(visit_name, str): # There is an optimization opportunity here because the @@ -79,12 +87,13 @@ def _generate_dispatch(cls): raise exc.UnsupportedCompilationError(visitor, cls) else: return meth(self, **kw) + else: # The optimization opportunity is lost for this case because the # __visit_name__ is not yet a string. As a result, the visit # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): - visit_attr = 'visit_%s' % self.__visit_name__ + visit_attr = "visit_%s" % self.__visit_name__ try: meth = getattr(visitor, visit_attr) except AttributeError: @@ -92,8 +101,7 @@ def _generate_dispatch(cls): else: return meth(self, **kw) - _compiler_dispatch.__doc__ = \ - """Look for an attribute named "visit_" + self.__visit_name__ + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ on the visitor, and call it with the same kw params. """ cls._compiler_dispatch = _compiler_dispatch @@ -137,7 +145,7 @@ class ClauseVisitor(object): visitors = {} for name in dir(self): - if name.startswith('visit_'): + if name.startswith("visit_"): visitors[name[6:]] = getattr(self, name) return visitors @@ -148,7 +156,7 @@ class ClauseVisitor(object): v = self while v: yield v - v = getattr(v, '_next', None) + v = getattr(v, "_next", None) def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. @@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor): """traverse and visit the given expression structure.""" return cloned_traverse( - obj, self.__traverse_options__, self._visitor_dict) + obj, self.__traverse_options__, self._visitor_dict + ) class ReplacingCloningVisitor(CloningVisitor): @@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor): e = v.replace(elem) if e is not None: return e + return replacement_traverse(obj, self.__traverse_options__, replace) @@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors): modifications by visitors.""" cloned = {} - stop_on = set(opts.get('stop_on', [])) + stop_on = set(opts.get("stop_on", [])) def clone(elem): if elem in stop_on: @@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = {} - stop_on = {id(x) for x in opts.get('stop_on', [])} + stop_on = {id(x) for x in opts.get("stop_on", [])} def clone(elem, **kw): - if id(elem) in stop_on or \ - 'no_replacement_traverse' in elem._annotations: + if ( + id(elem) in stop_on + or "no_replacement_traverse" in elem._annotations + ): return elem else: newelem = replace(elem) |