summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/sql/util.py115
-rw-r--r--lib/sqlalchemy/sql/visitors.py144
2 files changed, 127 insertions, 132 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 9954811d6..d4163b73b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -147,102 +147,7 @@ class ColumnsInClause(visitors.ClauseVisitor):
if self.selectable.c.get(column.key) is column:
self.result = True
-class AbstractClauseProcessor(object):
- """Traverse and copy a ClauseElement, replacing selected elements based on rules.
-
- This class implements its own visit-and-copy strategy but maintains the
- same public interface as visitors.ClauseVisitor.
-
- The convert_element() method receives the *un-copied* version of each element.
- It can return a new element or None for no change. If None, the element
- will be cloned afterwards and added to the new structure. Note this is the
- opposite behavior of visitors.traverse(clone=True), where visitors receive
- the cloned element so that it can be mutated.
- """
-
- __traverse_options__ = {'column_collections':False}
-
- def __init__(self, stop_on=None):
- self.stop_on = stop_on
-
- def convert_element(self, elem):
- """Define the *conversion* method for this ``AbstractClauseProcessor``."""
-
- raise NotImplementedError()
-
- def chain(self, visitor):
- # chaining AbstractClauseProcessor and other ClauseVisitor
- # objects separately. All the ACP objects are chained on
- # their convert_element() method whereas regular visitors
- # chain on their visit_XXX methods.
- if isinstance(visitor, AbstractClauseProcessor):
- attr = '_next_acp'
- else:
- attr = '_next'
-
- tail = self
- while getattr(tail, attr, None) is not None:
- tail = getattr(tail, attr)
- setattr(tail, attr, visitor)
- return self
-
- def copy_and_process(self, list_):
- """Copy the given list to a new list, with each element traversed individually."""
-
- list_ = list(list_)
- stop_on = util.Set(self.stop_on or [])
- cloned = {}
- for i in range(0, len(list_)):
- list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True)
- return list_
-
- def _convert_element(self, elem, stop_on, cloned):
- v = self
- while v is not None:
- newelem = v.convert_element(elem)
- if newelem:
- stop_on.add(newelem)
- return newelem
- v = getattr(v, '_next_acp', None)
-
- if elem not in cloned:
- # the full traversal will only make a clone of a particular element
- # once.
- cloned[elem] = elem._clone()
- return cloned[elem]
-
- def traverse(self, elem, clone=True):
- if not clone:
- raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
-
- return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True)
-
- def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
- if elem in stop_on:
- return elem
-
- if _clone_toplevel:
- elem = self._convert_element(elem, stop_on, cloned)
- if elem in stop_on:
- return elem
-
- def clone(element):
- return self._convert_element(element, stop_on, cloned)
- elem._copy_internals(clone=clone)
-
- v = getattr(self, '_next', None)
- while v is not None:
- meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
- if meth:
- meth(elem)
- v = getattr(v, '_next', None)
-
- for e in elem.get_children(**self.__traverse_options__):
- if e not in stop_on:
- self._traverse(e, stop_on, cloned)
- return elem
-
-class ClauseAdapter(AbstractClauseProcessor):
+class ClauseAdapter(visitors.ClauseVisitor):
"""Given a clause (like as in a WHERE criterion), locate columns
which are embedded within a given selectable, and changes those
columns to be that of the selectable.
@@ -270,13 +175,21 @@ class ClauseAdapter(AbstractClauseProcessor):
s.c.col1 == table2.c.col1
"""
+ __traverse_options__ = {'column_collections':False}
+
def __init__(self, selectable, include=None, exclude=None, equivalents=None):
- AbstractClauseProcessor.__init__(self, [selectable])
+ self.__traverse_options__ = self.__traverse_options__.copy()
+ self.__traverse_options__['stop_on'] = [selectable]
self.selectable = selectable
self.include = include
self.exclude = exclude
self.equivalents = equivalents
-
+
+ def traverse(self, obj, clone=True):
+ if not clone:
+ raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
+ return visitors.ClauseVisitor.traverse(self, obj, clone=True)
+
def copy_and_chain(self, adapter):
"""create a copy of this adapter and chain to the given adapter.
@@ -289,14 +202,14 @@ class ClauseAdapter(AbstractClauseProcessor):
if adapter is None:
return self
- if hasattr(self, '_next_acp') or hasattr(self, '_next'):
+ if hasattr(self, '_next'):
raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
- ca._next_acp = adapter
+ ca._next = adapter
return ca
- def convert_element(self, col):
+ def before_clone(self, col):
if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index bb63ab09c..57dfb4b96 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -1,7 +1,9 @@
+from sqlalchemy import util
+
class ClauseVisitor(object):
"""Traverses and visits ``ClauseElement`` structures.
- Calls visit_XXX() methods dynamically generated for each particular
+ Calls visit_XXX() methods for each particular
``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
@@ -25,19 +27,18 @@ class ClauseVisitor(object):
__traverse_options__ = {}
def traverse_single(self, obj, **kwargs):
- meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
- if meth:
- return meth(obj, **kwargs)
-
- def traverse_chained(self, obj, **kwargs):
- v = self
- while v is not None:
- meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ """visit a single element, without traversing its child elements."""
+
+ for v in self._iterate_visitors:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
- meth(obj, **kwargs)
- v = getattr(v, '_next', None)
+ return meth(obj, **kwargs)
+
+ traverse_chained = traverse_single
def iterate(self, obj):
+ """traverse the given expression structure, and return an iterator of all elements."""
+
stack = [obj]
traversal = []
while len(stack) > 0:
@@ -48,39 +49,118 @@ class ClauseVisitor(object):
stack.append(c)
def traverse(self, obj, clone=False):
+ """traverse the given expression structure.
+
+ Returns the structure given, or a copy of the structure if
+ clone=True.
+ When the copy operation takes place, the before_clone() method
+ will receive each element before it is copied. If the method
+ returns a non-None value, the return value is taken as the
+ "copied" element and traversal will not descend further.
+
+ The visit_XXX() methods receive the element *after* it's been
+ copied. To compare an element to another regardless of
+ one element being a cloned copy of the original, the
+ '_cloned_set' attribute of ClauseElement can be used for the compare,
+ i.e.::
+
+ original in copied._cloned_set
+
+
+ """
if clone:
- cloned = {}
- def do_clone(obj):
- # the full traversal will only make a clone of a particular element
- # once.
- if obj not in cloned:
- cloned[obj] = obj._clone()
- return cloned[obj]
+ return self._cloned_traversal(obj)
+ else:
+ return self._non_cloned_traversal(obj)
+
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return the new list."""
+
+ return [self._cloned_traversal(x) for x in list_]
+
+ def before_clone(self, elem):
+ """receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def _clone_element(self, elem, stop_on, cloned):
+ for v in self._iterate_visitors:
+ newelem = v.before_clone(elem)
+ if newelem:
+ stop_on.add(newelem)
+ return newelem
+
+ if elem not in cloned:
+ # the full traversal will only make a clone of a particular element
+ # once.
+ cloned[elem] = elem._clone()
+ return cloned[elem]
- obj = do_clone(obj)
+ def _cloned_traversal(self, obj):
+ """a recursive traversal which creates copies of elements, returning the new structure."""
+
+ stop_on = self.__traverse_options__.get('stop_on', [])
+ return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
+
+ def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
+ if elem in stop_on:
+ return elem
+
+ if _clone_toplevel:
+ elem = self._clone_element(elem, stop_on, cloned)
+ if elem in stop_on:
+ return elem
+
+ def clone(element):
+ return self._clone_element(element, stop_on, cloned)
+ elem._copy_internals(clone=clone)
+
+ for v in self._iterate_visitors:
+ meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
+ if meth:
+ meth(elem)
+
+ for e in elem.get_children(**self.__traverse_options__):
+ if e not in stop_on:
+ self._cloned_traversal_impl(e, stop_on, cloned)
+ return elem
+
+ def _non_cloned_traversal(self, obj):
+ """a non-recursive, non-cloning traversal."""
+
stack = [obj]
traversal = []
while len(stack) > 0:
t = stack.pop()
traversal.insert(0, t)
- if clone:
- t._copy_internals(clone=do_clone)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
for target in traversal:
- v = self
- while v is not None:
+ for v in self._iterate_visitors:
meth = getattr(v, "visit_%s" % target.__visit_name__, None)
if meth:
meth(target)
- v = getattr(v, '_next', None)
return obj
+ def _iterate_visitors(self):
+ """iterate through this visitor and each 'chained' visitor."""
+
+ v = self
+ while v is not None:
+ yield v
+ v = getattr(v, '_next', None)
+ _iterate_visitors = property(_iterate_visitors)
+
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
- the chained visitor will receive all visit events after this one."""
+ the chained visitor will receive all visit events after this one.
+ """
tail = self
while getattr(tail, '_next', None) is not None:
tail = tail._next
@@ -96,14 +176,16 @@ class NoColumnVisitor(ClauseVisitor):
__traverse_options__ = {'column_collections':False}
+
def traverse(clause, **kwargs):
+ """traverse the given clause, applying visit functions passed in as keyword arguments."""
+
clone = kwargs.pop('clone', False)
class Vis(ClauseVisitor):
__traverse_options__ = kwargs.pop('traverse_options', {})
- def __getattr__(self, key):
- if key in kwargs:
- return kwargs[key]
- else:
- return None
- return Vis().traverse(clause, clone=clone)
+ vis = Vis()
+ for key in kwargs:
+ if key.startswith('visit_'):
+ setattr(vis, key, kwargs[key])
+ return vis.traverse(clause, clone=clone)