summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql_util.py')
-rw-r--r--lib/sqlalchemy/sql_util.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index bfbcff554..4c6cd4d07 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -107,3 +107,51 @@ class Aliasizer(sql.ClauseVisitor):
binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left)
if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table):
binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right)
+
+
+class ClauseAdapter(sql.ClauseVisitor):
+ """given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable,
+ and changes those columns to be that of the selectable.
+
+ such as:
+
+ table1 = Table('sometable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+ table2 = Table('someothertable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+
+ condition = table1.c.col1 == table2.c.col1
+
+ and make an alias of table1:
+
+ s = table1.alias('foo')
+
+ calling condition.accept_visitor(ClauseAdapter(s)) converts condition to read:
+
+ s.c.col1 == table2.c.col1
+
+ """
+ def __init__(self, selectable):
+ self.selectable = selectable
+ def visit_binary(self, binary):
+ if isinstance(binary.left, sql.ColumnElement):
+ col = self.selectable.corresponding_column(binary.left, raiseerr=False, keys_ok=False)
+ if col is not None:
+ binary.left = col
+ if isinstance(binary.right, sql.ColumnElement):
+ col = self.selectable.corresponding_column(binary.right, raiseerr=False, keys_ok=False)
+ if col is not None:
+ binary.right = col
+
+class ColumnsInClause(sql.ClauseVisitor):
+ """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
+ def __init__(self, selectable):
+ self.selectable = selectable
+ self.result = False
+ def visit_column(self, column):
+ if self.selectable.c.get(column.key) is column:
+ self.result = True