summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py143
1 files changed, 44 insertions, 99 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 3e2d4ec31..d3e89d57e 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -3,105 +3,50 @@ from sqlalchemy.sql import expression, visitors
"""Utility functions that build upon SQL and Schema constructs."""
-# TODO: replace with plain list. break out sorting funcs into module-level funcs
-class TableCollection(object):
- def __init__(self, tables=None):
- self.tables = tables or []
-
- def __len__(self):
- return len(self.tables)
-
- def __getitem__(self, i):
- return self.tables[i]
-
- def __iter__(self):
- return iter(self.tables)
-
- def __contains__(self, obj):
- return obj in self.tables
-
- def __add__(self, obj):
- return self.tables + list(obj)
-
- def add(self, table):
- self.tables.append(table)
- if hasattr(self, '_sorted'):
- del self._sorted
-
- def sort(self, reverse=False):
- try:
- sorted = self._sorted
- except AttributeError, e:
- self._sorted = self._do_sort()
- sorted = self._sorted
- if reverse:
- x = sorted[:]
- x.reverse()
- return x
- else:
- return sorted
-
- def _do_sort(self):
- tuples = []
- class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(_self, fkey):
- if fkey.use_alter:
- return
- parent_table = fkey.column.table
- if parent_table in self:
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
- for table in self.tables:
- vis.traverse(table)
- sorter = topological.QueueDependencySorter( tuples, self.tables )
- head = sorter.sort()
- sequence = []
- def to_sequence( node, seq=sequence):
- seq.append( node.item )
- for child in node.children:
- to_sequence( child )
- if head is not None:
- to_sequence( head )
- return sequence
-
-
-# TODO: replace with plain module-level func
-class TableFinder(TableCollection, visitors.NoColumnVisitor):
- """locate all Tables within a clause."""
-
- def __init__(self, clause, check_columns=False, include_aliases=False):
- TableCollection.__init__(self)
- self.check_columns = check_columns
- self.include_aliases = include_aliases
- for clause in util.to_list(clause):
- self.traverse(clause)
-
- def visit_alias(self, alias):
- if self.include_aliases:
- self.tables.append(alias)
-
- def visit_table(self, table):
- self.tables.append(table)
-
- def visit_column(self, column):
- if self.check_columns:
- self.tables.append(column.table)
-
-class ColumnFinder(visitors.ClauseVisitor):
- def __init__(self):
- self.columns = util.Set()
-
- def visit_column(self, c):
- self.columns.add(c)
-
- def __iter__(self):
- return iter(self.columns)
-
-def find_columns(selectable):
- cf = ColumnFinder()
- cf.traverse(selectable)
- return iter(cf)
+def sort_tables(tables, reverse=False):
+ tuples = []
+ class TVisitor(schema.SchemaVisitor):
+ def visit_foreign_key(_self, fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in tables:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
+ vis = TVisitor()
+ for table in tables:
+ vis.traverse(table)
+ sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
+ if reverse:
+ sequence.reverse()
+ return sequence
+
+def find_tables(clause, check_columns=False, include_aliases=False):
+ tables = []
+ kwargs = {}
+ if include_aliases:
+ def visit_alias(alias):
+ tables.append(alias)
+ kwargs['visit_alias'] = visit_alias
+
+ if check_columns:
+ def visit_column(column):
+ tables.append(column.table)
+ kwargs['visit_column'] = visit_column
+
+ def visit_table(table):
+ tables.append(table)
+ kwargs['visit_table'] = visit_table
+
+ visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+ return tables
+
+def find_columns(clause):
+ cols = util.Set()
+ def visit_column(col):
+ cols.add(col)
+ visitors.traverse(clause, visit_column=visit_column)
+ return cols
class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns