diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 143 |
1 files changed, 44 insertions, 99 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 3e2d4ec31..d3e89d57e 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -3,105 +3,50 @@ from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" -# TODO: replace with plain list. break out sorting funcs into module-level funcs -class TableCollection(object): - def __init__(self, tables=None): - self.tables = tables or [] - - def __len__(self): - return len(self.tables) - - def __getitem__(self, i): - return self.tables[i] - - def __iter__(self): - return iter(self.tables) - - def __contains__(self, obj): - return obj in self.tables - - def __add__(self, obj): - return self.tables + list(obj) - - def add(self, table): - self.tables.append(table) - if hasattr(self, '_sorted'): - del self._sorted - - def sort(self, reverse=False): - try: - sorted = self._sorted - except AttributeError, e: - self._sorted = self._do_sort() - sorted = self._sorted - if reverse: - x = sorted[:] - x.reverse() - return x - else: - return sorted - - def _do_sort(self): - tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in self: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() - for table in self.tables: - vis.traverse(table) - sorter = topological.QueueDependencySorter( tuples, self.tables ) - head = sorter.sort() - sequence = [] - def to_sequence( node, seq=sequence): - seq.append( node.item ) - for child in node.children: - to_sequence( child ) - if head is not None: - to_sequence( head ) - return sequence - - -# TODO: replace with plain module-level func -class TableFinder(TableCollection, visitors.NoColumnVisitor): - """locate all Tables within a clause.""" - - def __init__(self, clause, check_columns=False, include_aliases=False): - TableCollection.__init__(self) - self.check_columns = check_columns - self.include_aliases = include_aliases - for clause in util.to_list(clause): - self.traverse(clause) - - def visit_alias(self, alias): - if self.include_aliases: - self.tables.append(alias) - - def visit_table(self, table): - self.tables.append(table) - - def visit_column(self, column): - if self.check_columns: - self.tables.append(column.table) - -class ColumnFinder(visitors.ClauseVisitor): - def __init__(self): - self.columns = util.Set() - - def visit_column(self, c): - self.columns.add(c) - - def __iter__(self): - return iter(self.columns) - -def find_columns(selectable): - cf = ColumnFinder() - cf.traverse(selectable) - return iter(cf) +def sort_tables(tables, reverse=False): + tuples = [] + class TVisitor(schema.SchemaVisitor): + def visit_foreign_key(_self, fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + vis = TVisitor() + for table in tables: + vis.traverse(table) + sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False) + if reverse: + sequence.reverse() + return sequence + +def find_tables(clause, check_columns=False, include_aliases=False): + tables = [] + kwargs = {} + if include_aliases: + def visit_alias(alias): + tables.append(alias) + kwargs['visit_alias'] = visit_alias + + if check_columns: + def visit_column(column): + tables.append(column.table) + kwargs['visit_column'] = visit_column + + def visit_table(table): + tables.append(table) + kwargs['visit_table'] = visit_table + + visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + return tables + +def find_columns(clause): + cols = util.Set() + def visit_column(col): + cols.add(col) + visitors.traverse(clause, visit_column=visit_column) + return cols class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns |