From ed4fc64bb0ac61c27bc4af32962fb129e74a36bf Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 27 Jul 2007 04:08:53 +0000 Subject: merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3. --- lib/sqlalchemy/sql_util.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) (limited to 'lib/sqlalchemy/sql_util.py') diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 9235b9c4e..d91fbe4b5 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -53,7 +53,7 @@ class TableCollection(object): for table in self.tables: vis.traverse(table) sorter = topological.QueueDependencySorter( tuples, self.tables ) - head = sorter.sort() + head = sorter.sort() sequence = [] def to_sequence( node, seq=sequence): seq.append( node.item ) @@ -67,12 +67,12 @@ class TableCollection(object): class TableFinder(TableCollection, sql.NoColumnVisitor): """locate all Tables within a clause.""" - def __init__(self, table, check_columns=False, include_aliases=False): + def __init__(self, clause, check_columns=False, include_aliases=False): TableCollection.__init__(self) self.check_columns = check_columns self.include_aliases = include_aliases - if table is not None: - self.traverse(table) + for clause in util.to_list(clause): + self.traverse(clause) def visit_alias(self, alias): if self.include_aliases: @@ -83,7 +83,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor): def visit_column(self, column): if self.check_columns: - self.traverse(column.table) + self.tables.append(column.table) class ColumnFinder(sql.ClauseVisitor): def __init__(self): @@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): process the new list. """ - list_ = [o.copy_container() for o in list_] + list_ = list(list_) self.process_list(list_) return list_ @@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): if elem is not None: list_[i] = elem else: - self.traverse(list_[i]) + list_[i] = self.traverse(list_[i], clone=True) def visit_grouping(self, grouping): elem = self.convert_element(grouping.elem) @@ -162,8 +162,24 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): elem = self.convert_element(binary.right) if elem is not None: binary.right = elem - - # TODO: visit_select(). + + def visit_select(self, select): + fr = util.OrderedSet() + for elem in select._froms: + n = self.convert_element(elem) + if n is not None: + fr.add((elem, n)) + select._recorrelate_froms(fr) + + col = [] + for elem in select._raw_columns: + print "RAW COLUMN", elem + n = self.convert_element(elem) + if n is None: + col.append(elem) + else: + col.append(n) + select._raw_columns = col class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns @@ -200,6 +216,9 @@ class ClauseAdapter(AbstractClauseProcessor): self.equivalents = equivalents def convert_element(self, col): + if isinstance(col, sql.FromClause): + if self.selectable.is_derived_from(col): + return self.selectable if not isinstance(col, sql.ColumnElement): return None if self.include is not None: @@ -214,4 +233,9 @@ class ClauseAdapter(AbstractClauseProcessor): newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False) if newcol: return newcol + #if newcol is None: + # self.traverse(col) + # return col return newcol + + -- cgit v1.2.1