diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-03-24 23:55:21 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-03-24 23:55:21 +0000 |
commit | bade0092d13de5c28b9bcccdaaa5b9b6e00e1ed2 (patch) | |
tree | 1c3eca8743dc3fdc52f5f2fa73338b24d69d9080 | |
parent | dde64666607698c887775c6f3704e242a413dbac (diff) | |
download | sqlalchemy-bade0092d13de5c28b9bcccdaaa5b9b6e00e1ed2.tar.gz |
removed AbstractClauseProcessor, merged its copy-and-visit behavior into ClauseVisitor
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 115 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 144 | ||||
-rw-r--r-- | test/sql/generative.py | 33 |
3 files changed, 159 insertions, 133 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9954811d6..d4163b73b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -147,102 +147,7 @@ class ColumnsInClause(visitors.ClauseVisitor): if self.selectable.c.get(column.key) is column: self.result = True -class AbstractClauseProcessor(object): - """Traverse and copy a ClauseElement, replacing selected elements based on rules. - - This class implements its own visit-and-copy strategy but maintains the - same public interface as visitors.ClauseVisitor. - - The convert_element() method receives the *un-copied* version of each element. - It can return a new element or None for no change. If None, the element - will be cloned afterwards and added to the new structure. Note this is the - opposite behavior of visitors.traverse(clone=True), where visitors receive - the cloned element so that it can be mutated. - """ - - __traverse_options__ = {'column_collections':False} - - def __init__(self, stop_on=None): - self.stop_on = stop_on - - def convert_element(self, elem): - """Define the *conversion* method for this ``AbstractClauseProcessor``.""" - - raise NotImplementedError() - - def chain(self, visitor): - # chaining AbstractClauseProcessor and other ClauseVisitor - # objects separately. All the ACP objects are chained on - # their convert_element() method whereas regular visitors - # chain on their visit_XXX methods. - if isinstance(visitor, AbstractClauseProcessor): - attr = '_next_acp' - else: - attr = '_next' - - tail = self - while getattr(tail, attr, None) is not None: - tail = getattr(tail, attr) - setattr(tail, attr, visitor) - return self - - def copy_and_process(self, list_): - """Copy the given list to a new list, with each element traversed individually.""" - - list_ = list(list_) - stop_on = util.Set(self.stop_on or []) - cloned = {} - for i in range(0, len(list_)): - list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True) - return list_ - - def _convert_element(self, elem, stop_on, cloned): - v = self - while v is not None: - newelem = v.convert_element(elem) - if newelem: - stop_on.add(newelem) - return newelem - v = getattr(v, '_next_acp', None) - - if elem not in cloned: - # the full traversal will only make a clone of a particular element - # once. - cloned[elem] = elem._clone() - return cloned[elem] - - def traverse(self, elem, clone=True): - if not clone: - raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - - return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True) - - def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False): - if elem in stop_on: - return elem - - if _clone_toplevel: - elem = self._convert_element(elem, stop_on, cloned) - if elem in stop_on: - return elem - - def clone(element): - return self._convert_element(element, stop_on, cloned) - elem._copy_internals(clone=clone) - - v = getattr(self, '_next', None) - while v is not None: - meth = getattr(v, "visit_%s" % elem.__visit_name__, None) - if meth: - meth(elem) - v = getattr(v, '_next', None) - - for e in elem.get_children(**self.__traverse_options__): - if e not in stop_on: - self._traverse(e, stop_on, cloned) - return elem - -class ClauseAdapter(AbstractClauseProcessor): +class ClauseAdapter(visitors.ClauseVisitor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those columns to be that of the selectable. @@ -270,13 +175,21 @@ class ClauseAdapter(AbstractClauseProcessor): s.c.col1 == table2.c.col1 """ + __traverse_options__ = {'column_collections':False} + def __init__(self, selectable, include=None, exclude=None, equivalents=None): - AbstractClauseProcessor.__init__(self, [selectable]) + self.__traverse_options__ = self.__traverse_options__.copy() + self.__traverse_options__['stop_on'] = [selectable] self.selectable = selectable self.include = include self.exclude = exclude self.equivalents = equivalents - + + def traverse(self, obj, clone=True): + if not clone: + raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True") + return visitors.ClauseVisitor.traverse(self, obj, clone=True) + def copy_and_chain(self, adapter): """create a copy of this adapter and chain to the given adapter. @@ -289,14 +202,14 @@ class ClauseAdapter(AbstractClauseProcessor): if adapter is None: return self - if hasattr(self, '_next_acp') or hasattr(self, '_next'): + if hasattr(self, '_next'): raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) - ca._next_acp = adapter + ca._next = adapter return ca - def convert_element(self, col): + def before_clone(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index bb63ab09c..57dfb4b96 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,7 +1,9 @@ +from sqlalchemy import util + class ClauseVisitor(object): """Traverses and visits ``ClauseElement`` structures. - Calls visit_XXX() methods dynamically generated for each particular + Calls visit_XXX() methods for each particular ``ClauseElement`` subclass encountered. Traversal of a hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` method, which is passed the lead @@ -25,19 +27,18 @@ class ClauseVisitor(object): __traverse_options__ = {} def traverse_single(self, obj, **kwargs): - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) - if meth: - return meth(obj, **kwargs) - - def traverse_chained(self, obj, **kwargs): - v = self - while v is not None: - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + """visit a single element, without traversing its child elements.""" + + for v in self._iterate_visitors: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: - meth(obj, **kwargs) - v = getattr(v, '_next', None) + return meth(obj, **kwargs) + + traverse_chained = traverse_single def iterate(self, obj): + """traverse the given expression structure, and return an iterator of all elements.""" + stack = [obj] traversal = [] while len(stack) > 0: @@ -48,39 +49,118 @@ class ClauseVisitor(object): stack.append(c) def traverse(self, obj, clone=False): + """traverse the given expression structure. + + Returns the structure given, or a copy of the structure if + clone=True. + When the copy operation takes place, the before_clone() method + will receive each element before it is copied. If the method + returns a non-None value, the return value is taken as the + "copied" element and traversal will not descend further. + + The visit_XXX() methods receive the element *after* it's been + copied. To compare an element to another regardless of + one element being a cloned copy of the original, the + '_cloned_set' attribute of ClauseElement can be used for the compare, + i.e.:: + + original in copied._cloned_set + + + """ if clone: - cloned = {} - def do_clone(obj): - # the full traversal will only make a clone of a particular element - # once. - if obj not in cloned: - cloned[obj] = obj._clone() - return cloned[obj] + return self._cloned_traversal(obj) + else: + return self._non_cloned_traversal(obj) + + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return the new list.""" + + return [self._cloned_traversal(x) for x in list_] + + def before_clone(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def _clone_element(self, elem, stop_on, cloned): + for v in self._iterate_visitors: + newelem = v.before_clone(elem) + if newelem: + stop_on.add(newelem) + return newelem + + if elem not in cloned: + # the full traversal will only make a clone of a particular element + # once. + cloned[elem] = elem._clone() + return cloned[elem] - obj = do_clone(obj) + def _cloned_traversal(self, obj): + """a recursive traversal which creates copies of elements, returning the new structure.""" + + stop_on = self.__traverse_options__.get('stop_on', []) + return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True) + + def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False): + if elem in stop_on: + return elem + + if _clone_toplevel: + elem = self._clone_element(elem, stop_on, cloned) + if elem in stop_on: + return elem + + def clone(element): + return self._clone_element(element, stop_on, cloned) + elem._copy_internals(clone=clone) + + for v in self._iterate_visitors: + meth = getattr(v, "visit_%s" % elem.__visit_name__, None) + if meth: + meth(elem) + + for e in elem.get_children(**self.__traverse_options__): + if e not in stop_on: + self._cloned_traversal_impl(e, stop_on, cloned) + return elem + + def _non_cloned_traversal(self, obj): + """a non-recursive, non-cloning traversal.""" + stack = [obj] traversal = [] while len(stack) > 0: t = stack.pop() traversal.insert(0, t) - if clone: - t._copy_internals(clone=do_clone) for c in t.get_children(**self.__traverse_options__): stack.append(c) for target in traversal: - v = self - while v is not None: + for v in self._iterate_visitors: meth = getattr(v, "visit_%s" % target.__visit_name__, None) if meth: meth(target) - v = getattr(v, '_next', None) return obj + def _iterate_visitors(self): + """iterate through this visitor and each 'chained' visitor.""" + + v = self + while v is not None: + yield v + v = getattr(v, '_next', None) + _iterate_visitors = property(_iterate_visitors) + def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. - the chained visitor will receive all visit events after this one.""" + the chained visitor will receive all visit events after this one. + """ tail = self while getattr(tail, '_next', None) is not None: tail = tail._next @@ -96,14 +176,16 @@ class NoColumnVisitor(ClauseVisitor): __traverse_options__ = {'column_collections':False} + def traverse(clause, **kwargs): + """traverse the given clause, applying visit functions passed in as keyword arguments.""" + clone = kwargs.pop('clone', False) class Vis(ClauseVisitor): __traverse_options__ = kwargs.pop('traverse_options', {}) - def __getattr__(self, key): - if key in kwargs: - return kwargs[key] - else: - return None - return Vis().traverse(clause, clone=clone) + vis = Vis() + for key in kwargs: + if key.startswith('visit_'): + setattr(vis, key, kwargs[key]) + return vis.traverse(clause, clone=clone) diff --git a/test/sql/generative.py b/test/sql/generative.py index 3d7c88972..0994491d9 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -92,7 +92,7 @@ class TraversalTest(TestBase, AssertsExecutionResults): s2 = vis.traverse(struct, clone=True) assert struct == s2 assert not struct.is_other(s2) - + def test_no_clone(self): struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) @@ -430,7 +430,38 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL): "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10) AS anon_1 "\ "LEFT OUTER JOIN table1 AS bar ON anon_1.col1 = bar.col1") + + def test_recursive(self): + metadata = MetaData() + a = Table('a', metadata, + Column('id', Integer, primary_key=True)) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + c = Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('bid', Integer, ForeignKey('b.id')), + ) + + d = Table('d', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + u = union( + a.join(b).select().apply_labels(), + a.join(d).select().apply_labels() + ).alias() + + self.assert_compile( + sql_util.ClauseAdapter(u).traverse(select([c.c.bid]).where(c.c.bid==u.c.b_aid)), + "SELECT c.bid "\ + "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid "\ + "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id AS d_id, d.aid AS d_aid "\ + "FROM a JOIN d ON a.id = d.aid) AS anon_1 "\ + "WHERE c.bid = anon_1.b_aid" + ) class SelectTest(TestBase, AssertsCompiledSQL): """tests the generative capability of Select""" |