summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py283
1 files changed, 171 insertions, 112 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d299982cf..944a68def 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1,4 +1,4 @@
-from sqlalchemy import exceptions, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
@@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False):
"""sort a collection of Table objects in order of their foreign-key dependency."""
tuples = []
- class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(_self, fkey):
- if fkey.use_alter:
- return
- parent_table = fkey.column.table
- if parent_table in tables:
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
+ def visit_foreign_key(fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in tables:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
+
for table in tables:
- vis.traverse(table)
+ visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})
sequence = topological.sort(tuples, tables)
if reverse:
return util.reversed(sequence)
else:
return sequence
-def find_tables(clause, check_columns=False, include_aliases=False):
+def search(clause, target):
+ if not clause:
+ return False
+ for elem in visitors.iterate(clause, {'column_collections':False}):
+ if elem is target:
+ return True
+ else:
+ return False
+
+def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
"""locate Table objects within the given expression."""
tables = []
- kwargs = {}
+ _visitors = {}
+
+ def visit_something(elem):
+ tables.append(elem)
+
+ if include_selects:
+ _visitors['select'] = _visitors['compound_select'] = visit_something
+
+ if include_joins:
+ _visitors['join'] = visit_something
+
if include_aliases:
- def visit_alias(alias):
- tables.append(alias)
- kwargs['visit_alias'] = visit_alias
+ _visitors['alias'] = visit_something
if check_columns:
def visit_column(column):
tables.append(column.table)
- kwargs['visit_column'] = visit_column
+ _visitors['column'] = visit_column
- def visit_table(table):
- tables.append(table)
- kwargs['visit_table'] = visit_table
+ _visitors['table'] = visit_something
- visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+ visitors.traverse(clause, {'column_collections':False}, _visitors)
return tables
def find_columns(clause):
@@ -53,7 +67,7 @@ def find_columns(clause):
cols = util.Set()
def visit_column(col):
cols.add(col)
- visitors.traverse(clause, visit_column=visit_column)
+ visitors.traverse(clause, {}, {'column':visit_column})
return cols
def join_condition(a, b, ignore_nonexistent_tables=False):
@@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
for fk in b.foreign_keys:
try:
col = fk.get_referent(a)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
@@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
-
if a is not b:
for fk in a.foreign_keys:
try:
col = fk.get_referent(b)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
raise
-
+
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
if len(crit) == 0:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't find any foreign key relationships "
"between '%s' and '%s'" % (a.description, b.description))
elif len(constraints) > 1:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't determine join between '%s' and '%s'; "
"tables have more than one foreign key "
"constraint relationship between them. "
@@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
return (crit[0])
else:
return sql.and_(*crit)
+
+class Annotated(object):
+ """clones a ClauseElement and applies an 'annotations' dictionary.
+
+ Unlike regular clones, this clone also mimics __hash__() and
+ __cmp__() of the original element so that it takes its place
+ in hashed collections.
+ A reference to the original element is maintained, for the important
+ reason of keeping its hash value current. When GC'ed, the
+ hash value may be reused, causing conflicts.
+
+ """
+ def __new__(cls, *args):
+ if not args:
+ return object.__new__(cls)
+ else:
+ element, values = args
+ return object.__new__(
+ type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {})
+ )
+
+ def __init__(self, element, values):
+ self.__dict__ = element.__dict__.copy()
+ self.__element = element
+ self._annotations = values
+
+ def _annotate(self, values):
+ _values = self._annotations.copy()
+ _values.update(values)
+ clone = self.__class__.__new__(self.__class__)
+ clone.__dict__ = self.__dict__.copy()
+ clone._annotations = _values
+ return clone
+
+ def __hash__(self):
+ return hash(self.__element)
+
+ def __cmp__(self, other):
+ return cmp(hash(self.__element), hash(other))
+
+def splice_joins(left, right, stop_on=None):
+ if left is None:
+ return right
+
+ stack = [(right, None)]
+
+ adapter = ClauseAdapter(left)
+ ret = None
+ while stack:
+ (right, prevright) = stack.pop()
+ if isinstance(right, expression.Join) and right is not stop_on:
+ right = right._clone()
+ right._reset_exported()
+ right.onclause = adapter.traverse(right.onclause)
+ stack.append((right.left, right))
+ else:
+ right = adapter.traverse(right)
+ if prevright:
+ prevright.left = right
+ if not ret:
+ ret = right
+
+ return ret
def reduce_columns(columns, *clauses):
"""given a list of columns, return a 'reduced' set based on natural equivalents.
@@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses):
omit.add(c)
break
for clause in clauses:
- visitors.traverse(clause, visit_binary=visit_binary)
+ visitors.traverse(clause, {}, {'binary':visit_binary})
return expression.ColumnSet(columns.difference(omit))
@@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+ raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
def visit_binary(binary):
if not any_operator and binary.operator != operators.eq:
@@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
pairs = []
- visitors.traverse(expression, visit_binary=visit_binary)
+ visitors.traverse(expression, {}, {'binary':visit_binary})
return pairs
def folded_equivalents(join, equivs=None):
@@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None):
This function is used by Join.select(fold_equivalents=True).
TODO: deprecate ?
- """
+ """
if equivs is None:
equivs = util.Set()
def visit_binary(binary):
if binary.operator == operators.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
- visitors.traverse(join.onclause, visit_binary=visit_binary)
+ visitors.traverse(join.onclause, {}, {'binary':visit_binary})
collist = []
if isinstance(join.left, expression.Join):
left = folded_equivalents(join.left, equivs)
@@ -246,43 +322,8 @@ class AliasedRow(object):
def keys(self):
return self.row.keys()
-def row_adapter(from_, equivalent_columns=None):
- """create a row adapter callable against a selectable."""
-
- if equivalent_columns is None:
- equivalent_columns = {}
-
- def locate_col(col):
- c = from_.corresponding_column(col)
- if c:
- return c
- elif col in equivalent_columns:
- for c2 in equivalent_columns[col]:
- corr = from_.corresponding_column(c2)
- if corr:
- return corr
- return col
-
- map = util.PopulateDict(locate_col)
-
- def adapt(row):
- return AliasedRow(row, map)
- return adapt
-
-class ColumnsInClause(visitors.ClauseVisitor):
- """Given a selectable, visit clauses and determine if any columns
- from the clause are in the selectable.
- """
-
- def __init__(self, selectable):
- self.selectable = selectable
- self.result = False
-
- def visit_column(self, column):
- if self.selectable.c.get(column.key) is column:
- self.result = True
-class ClauseAdapter(visitors.ClauseVisitor):
+class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Given a clause (like as in a WHERE criterion), locate columns
which are embedded within a given selectable, and changes those
columns to be that of the selectable.
@@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor):
condition to read::
s.c.col1 == table2.c.col1
- """
-
- __traverse_options__ = {'column_collections':False}
- def __init__(self, selectable, include=None, exclude=None, equivalents=None):
- self.__traverse_options__ = self.__traverse_options__.copy()
- self.__traverse_options__['stop_on'] = [selectable]
+ """
+ def __init__(self, selectable, equivalents=None, include=None, exclude=None):
+ self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
self.selectable = selectable
self.include = include
self.exclude = exclude
- self.equivalents = equivalents
-
- def traverse(self, obj, clone=True):
- if not clone:
- raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
- return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-
- def copy_and_chain(self, adapter):
- """create a copy of this adapter and chain to the given adapter.
-
- currently this adapter must be unchained to start, raises
- an exception if it's already chained.
-
- Does not modify the given adapter.
- """
+ self.equivalents = equivalents or {}
- if adapter is None:
- return self
+ def _corresponding_column(self, col, require_embedded):
+ newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
- if hasattr(self, '_next'):
- raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
-
- ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
- ca._next = adapter
- return ca
+ if not newcol and col in self.equivalents:
+ for equiv in self.equivalents[col]:
+ newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded)
+ if newcol:
+ return newcol
+ return newcol
- def before_clone(self, col):
+ def replace(self, col):
if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
+
if not isinstance(col, expression.ColumnElement):
return None
- if self.include is not None:
- if col not in self.include:
- return None
- if self.exclude is not None:
- if col in self.exclude:
- return None
- newcol = self.selectable.corresponding_column(col, require_embedded=True)
- if newcol is None and self.equivalents is not None and col in self.equivalents:
- for equiv in self.equivalents[col]:
- newcol = self.selectable.corresponding_column(equiv, require_embedded=True)
- if newcol:
- return newcol
- return newcol
+
+ if self.include and col not in self.include:
+ return None
+ elif self.exclude and col in self.exclude:
+ return None
+
+ return self._corresponding_column(col, True)
+
+class ColumnAdapter(ClauseAdapter):
+
+ def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
+ ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
+ if chain_to:
+ self.chain(chain_to)
+ self.columns = util.PopulateDict(self._locate_col)
+
+ def wrap(self, adapter):
+ ac = self.__class__.__new__(self.__class__)
+ ac.__dict__ = self.__dict__.copy()
+ ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
+ ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
+ ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+ ac.columns = util.PopulateDict(ac._locate_col)
+ return ac
+
+ adapt_clause = ClauseAdapter.traverse
+ adapt_list = ClauseAdapter.copy_and_process
+
+ def _wrap(self, local, wrapped):
+ def locate(col):
+ col = local(col)
+ return wrapped(col)
+ return locate
+
+ def _locate_col(self, col):
+ c = self._corresponding_column(col, False)
+ if not c:
+ c = self.adapt_clause(col)
+
+ # anonymize labels in case they have a hardcoded name
+ if isinstance(c, expression._Label):
+ c = c.label(None)
+ return c
+
+ def adapted_row(self, row):
+ return AliasedRow(row, self.columns)
+