summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/expression.py9
-rw-r--r--lib/sqlalchemy/sql/util.py83
-rw-r--r--lib/sqlalchemy/sql/visitors.py37
3 files changed, 65 insertions, 64 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index f4611de6d..3d95948cb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1621,6 +1621,15 @@ class FromClause(Selectable):
from sqlalchemy.sql.util import ClauseAdapter
return ClauseAdapter(alias).traverse(self, clone=True)
+ def correspond_on_equivalents(self, column, equivalents):
+ col = self.corresponding_column(column, require_embedded=True)
+ if col is None and col in equivalents:
+ for equiv in equivalents[col]:
+ nc = self.corresponding_column(equiv, require_embedded=True)
+ if nc:
+ return nc
+ return col
+
def corresponding_column(self, column, require_embedded=False):
"""Given a ``ColumnElement``, return the exported ``ColumnElement``
object from this ``Selectable`` which corresponds to that
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d4163b73b..8ed561e5f 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -93,46 +93,51 @@ def reduce_columns(columns, *clauses):
return expression.ColumnSet(columns.difference(omit))
-def row_adapter(from_, to, equivalent_columns=None):
- """create a row adapter between two selectables.
+class AliasedRow(object):
+
+ def __init__(self, row, map):
+ # AliasedRow objects don't nest, so un-nest
+ # if another AliasedRow was passed
+ if isinstance(row, AliasedRow):
+ self.row = row.row
+ else:
+ self.row = row
+ self.map = map
+
+ def __contains__(self, key):
+ return self.map[key] in self.row
- The returned adapter is a class that can be instantiated repeatedly for any number
- of rows; this is an inexpensive process. However, the creation of the row
- adapter class itself *is* fairly expensive so caching should be used to prevent
- repeated calls to this function.
- """
+ def has_key(self, key):
+ return key in self
- map = {}
- for c in to.c:
- corr = from_.corresponding_column(c)
- if corr:
- map[c] = corr
- elif equivalent_columns:
- if c in equivalent_columns:
- for c2 in equivalent_columns[c]:
- corr = from_.corresponding_column(c2)
- if corr:
- map[c] = corr
- break
-
- class AliasedRow(object):
- def __init__(self, row):
- self.row = row
- def __contains__(self, key):
- if key in map:
- return map[key] in self.row
- else:
- return key in self.row
- def has_key(self, key):
- return key in self
- def __getitem__(self, key):
- if key in map:
- key = map[key]
- return self.row[key]
- def keys(self):
- return map.keys()
- AliasedRow.map = map
- return AliasedRow
+ def __getitem__(self, key):
+ return self.row[self.map[key]]
+
+ def keys(self):
+ return self.row.keys()
+
+def row_adapter(from_, equivalent_columns=None):
+ """create a row adapter against a selectable."""
+
+ if equivalent_columns is None:
+ equivalent_columns = {}
+
+ def locate_col(col):
+ c = from_.corresponding_column(col)
+ if c:
+ return c
+ elif col in equivalent_columns:
+ for c2 in equivalent_columns[col]:
+ corr = from_.corresponding_column(c2)
+ if corr:
+ return corr
+ return col
+
+ map = util.PopulateDict(locate_col)
+
+ def adapt(row):
+ return AliasedRow(row, map)
+ return adapt
class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns
@@ -189,7 +194,7 @@ class ClauseAdapter(visitors.ClauseVisitor):
if not clone:
raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-
+
def copy_and_chain(self, adapter):
"""create a copy of this adapter and chain to the given adapter.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 09d5a0982..7eccc9b89 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -37,17 +37,17 @@ class ClauseVisitor(object):
traverse_chained = traverse_single
def iterate(self, obj):
- """traverse the given expression structure, and return an iterator of all elements."""
+ """traverse the given expression structure, returning an iterator of all elements."""
stack = [obj]
- traversal = []
- while len(stack) > 0:
+ traversal = util.deque()
+ while stack:
t = stack.pop()
- yield t
- traversal.insert(0, t)
+ traversal.appendleft(t)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
-
+ return iter(traversal)
+
def traverse(self, obj, clone=False):
"""traverse and visit the given expression structure.
@@ -119,32 +119,19 @@ class ClauseVisitor(object):
def clone(element):
return self._clone_element(element, stop_on, cloned)
elem._copy_internals(clone=clone)
-
- for v in self._iterate_visitors:
- meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
- if meth:
- meth(elem)
+
+ self.traverse_single(elem)
for e in elem.get_children(**self.__traverse_options__):
if e not in stop_on:
self._cloned_traversal_impl(e, stop_on, cloned)
return elem
-
+
def _non_cloned_traversal(self, obj):
"""a non-recursive, non-cloning traversal."""
-
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- for v in self._iterate_visitors:
- meth = getattr(v, "visit_%s" % target.__visit_name__, None)
- if meth:
- meth(target)
+
+ for target in self.iterate(obj):
+ self.traverse_single(target)
return obj
def _iterate_visitors(self):