diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 283 |
1 files changed, 171 insertions, 112 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d299982cf..944a68def 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,4 +1,4 @@ -from sqlalchemy import exceptions, schema, topological, util, sql +from sqlalchemy import exc, schema, topological, util, sql from sqlalchemy.sql import expression, operators, visitors from itertools import chain @@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False): """sort a collection of Table objects in order of their foreign-key dependency.""" tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() + def visit_foreign_key(fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + for table in tables: - vis.traverse(table) + visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) sequence = topological.sort(tuples, tables) if reverse: return util.reversed(sequence) else: return sequence -def find_tables(clause, check_columns=False, include_aliases=False): +def search(clause, target): + if not clause: + return False + for elem in visitors.iterate(clause, {'column_collections':False}): + if elem is target: + return True + else: + return False + +def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False): """locate Table objects within the given expression.""" tables = [] - kwargs = {} + _visitors = {} + + def visit_something(elem): + tables.append(elem) + + if include_selects: + _visitors['select'] = _visitors['compound_select'] = visit_something + + if include_joins: + _visitors['join'] = visit_something + if include_aliases: - def visit_alias(alias): - tables.append(alias) - kwargs['visit_alias'] = visit_alias + _visitors['alias'] = visit_something if check_columns: def visit_column(column): tables.append(column.table) - kwargs['visit_column'] = visit_column + _visitors['column'] = visit_column - def visit_table(table): - tables.append(table) - kwargs['visit_table'] = visit_table + _visitors['table'] = visit_something - visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + visitors.traverse(clause, {'column_collections':False}, _visitors) return tables def find_columns(clause): @@ -53,7 +67,7 @@ def find_columns(clause): cols = util.Set() def visit_column(col): cols.add(col) - visitors.traverse(clause, visit_column=visit_column) + visitors.traverse(clause, {}, {'column':visit_column}) return cols def join_condition(a, b, ignore_nonexistent_tables=False): @@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False): for fk in b.foreign_keys: try: col = fk.get_referent(a) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: @@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False): if col: crit.append(col == fk.parent) constraints.add(fk.constraint) - if a is not b: for fk in a.foreign_keys: try: col = fk.get_referent(b) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: raise - + if col: crit.append(col == fk.parent) constraints.add(fk.constraint) if len(crit) == 0: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't find any foreign key relationships " "between '%s' and '%s'" % (a.description, b.description)) elif len(constraints) > 1: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " @@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False): return (crit[0]) else: return sql.and_(*crit) + +class Annotated(object): + """clones a ClauseElement and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + """ + def __new__(cls, *args): + if not args: + return object.__new__(cls) + else: + element, values = args + return object.__new__( + type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) + ) + + def __init__(self, element, values): + self.__dict__ = element.__dict__.copy() + self.__element = element + self._annotations = values + + def _annotate(self, values): + _values = self._annotations.copy() + _values.update(values) + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone._annotations = _values + return clone + + def __hash__(self): + return hash(self.__element) + + def __cmp__(self, other): + return cmp(hash(self.__element), hash(other)) + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, expression.Join) and right is not stop_on: + right = right._clone() + right._reset_exported() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright: + prevright.left = right + if not ret: + ret = right + + return ret def reduce_columns(columns, *clauses): """given a list of columns, return a 'reduced' set based on natural equivalents. @@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses): omit.add(c) break for clause in clauses: - visitors.traverse(clause, visit_binary=visit_binary) + visitors.traverse(clause, {}, {'binary':visit_binary}) return expression.ColumnSet(columns.difference(omit)) @@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") + raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") def visit_binary(binary): if not any_operator and binary.operator != operators.eq: @@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] - visitors.traverse(expression, visit_binary=visit_binary) + visitors.traverse(expression, {}, {'binary':visit_binary}) return pairs def folded_equivalents(join, equivs=None): @@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None): This function is used by Join.select(fold_equivalents=True). TODO: deprecate ? - """ + """ if equivs is None: equivs = util.Set() def visit_binary(binary): if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) - visitors.traverse(join.onclause, visit_binary=visit_binary) + visitors.traverse(join.onclause, {}, {'binary':visit_binary}) collist = [] if isinstance(join.left, expression.Join): left = folded_equivalents(join.left, equivs) @@ -246,43 +322,8 @@ class AliasedRow(object): def keys(self): return self.row.keys() -def row_adapter(from_, equivalent_columns=None): - """create a row adapter callable against a selectable.""" - - if equivalent_columns is None: - equivalent_columns = {} - - def locate_col(col): - c = from_.corresponding_column(col) - if c: - return c - elif col in equivalent_columns: - for c2 in equivalent_columns[col]: - corr = from_.corresponding_column(c2) - if corr: - return corr - return col - - map = util.PopulateDict(locate_col) - - def adapt(row): - return AliasedRow(row, map) - return adapt - -class ColumnsInClause(visitors.ClauseVisitor): - """Given a selectable, visit clauses and determine if any columns - from the clause are in the selectable. - """ - - def __init__(self, selectable): - self.selectable = selectable - self.result = False - - def visit_column(self, column): - if self.selectable.c.get(column.key) is column: - self.result = True -class ClauseAdapter(visitors.ClauseVisitor): +class ClauseAdapter(visitors.ReplacingCloningVisitor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those columns to be that of the selectable. @@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor): condition to read:: s.c.col1 == table2.c.col1 - """ - - __traverse_options__ = {'column_collections':False} - def __init__(self, selectable, include=None, exclude=None, equivalents=None): - self.__traverse_options__ = self.__traverse_options__.copy() - self.__traverse_options__['stop_on'] = [selectable] + """ + def __init__(self, selectable, equivalents=None, include=None, exclude=None): + self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]} self.selectable = selectable self.include = include self.exclude = exclude - self.equivalents = equivalents - - def traverse(self, obj, clone=True): - if not clone: - raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True") - return visitors.ClauseVisitor.traverse(self, obj, clone=True) - - def copy_and_chain(self, adapter): - """create a copy of this adapter and chain to the given adapter. - - currently this adapter must be unchained to start, raises - an exception if it's already chained. - - Does not modify the given adapter. - """ + self.equivalents = equivalents or {} - if adapter is None: - return self + def _corresponding_column(self, col, require_embedded): + newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) - if hasattr(self, '_next'): - raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") - - ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) - ca._next = adapter - return ca + if not newcol and col in self.equivalents: + for equiv in self.equivalents[col]: + newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded) + if newcol: + return newcol + return newcol - def before_clone(self, col): + def replace(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable + if not isinstance(col, expression.ColumnElement): return None - if self.include is not None: - if col not in self.include: - return None - if self.exclude is not None: - if col in self.exclude: - return None - newcol = self.selectable.corresponding_column(col, require_embedded=True) - if newcol is None and self.equivalents is not None and col in self.equivalents: - for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, require_embedded=True) - if newcol: - return newcol - return newcol + + if self.include and col not in self.include: + return None + elif self.exclude and col in self.exclude: + return None + + return self._corresponding_column(col, True) + +class ColumnAdapter(ClauseAdapter): + + def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None): + ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) + if chain_to: + self.chain(chain_to) + self.columns = util.PopulateDict(self._locate_col) + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__ = self.__dict__.copy() + ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) + ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) + ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) + ac.columns = util.PopulateDict(ac._locate_col) + return ac + + adapt_clause = ClauseAdapter.traverse + adapt_list = ClauseAdapter.copy_and_process + + def _wrap(self, local, wrapped): + def locate(col): + col = local(col) + return wrapped(col) + return locate + + def _locate_col(self, col): + c = self._corresponding_column(col, False) + if not c: + c = self.adapt_clause(col) + + # anonymize labels in case they have a hardcoded name + if isinstance(c, expression._Label): + c = c.label(None) + return c + + def adapted_row(self, row): + return AliasedRow(row, self.columns) + |