summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql_util.py
blob: 10d4495d931e64264373af5c884dc5bdbe250021 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from sqlalchemy import sql, util, schema, topological

"""utility functions that build upon SQL and Schema constructs"""


class TableCollection(object):
    """An ordered collection of Table objects which can also be returned in
    foreign-key dependency order.

    The dependency order is computed lazily by sort() and cached on the
    instance as ``_sorted``; add() discards that cache so the next sort()
    recomputes it.
    """
    def __init__(self, tables=None):
        # a falsy argument (None or an empty list) is replaced by a fresh list
        self.tables = tables or []
    def __len__(self):
        return len(self.tables)
    def __getitem__(self, i):
        return self.tables[i]
    def __iter__(self):
        return iter(self.tables)
    def __contains__(self, obj):
        return obj in self.tables
    def __add__(self, obj):
        """Return a new plain list: this collection's tables followed by the
        items of ``obj``."""
        return self.tables + list(obj)
    def add(self, table):
        """Append ``table`` and invalidate any cached dependency ordering."""
        self.tables.append(table)
        if hasattr(self, '_sorted'):
            del self._sorted
    def sort(self, reverse=False):
        """Return the tables in dependency order.

        reverse - when true, return a reversed copy of the ordering; the
          cached ordering itself is left untouched.

        When ``reverse`` is false the cached list itself is returned.
        """
        try:
            ordered = self._sorted
        except AttributeError:
            ordered = self._sorted = self._do_sort()
        if reverse:
            return ordered[::-1]
        return ordered

    def _do_sort(self):
        """Compute the dependency ordering of ``self.tables``.

        Each non-``use_alter`` foreign key whose referenced table is in this
        collection contributes a (referenced table, referencing table) pair.
        These pairs are handed to topological.QueueDependencySorter, and the
        tree it returns is flattened in pre-order.
        """
        tuples = []
        class TVisitor(schema.SchemaVisitor):
            def visit_foreign_key(_self, fkey):
                # NOTE(review): use_alter keys are skipped, presumably because
                # they are emitted separately and impose no creation order.
                if fkey.use_alter:
                    return
                parent_table = fkey.column.table
                # only dependencies between tables of this collection matter
                if parent_table in self:
                    child_table = fkey.parent.table
                    tuples.append((parent_table, child_table))
        vis = TVisitor()
        for table in self.tables:
            table.accept_schema_visitor(vis)
        sorter = topological.QueueDependencySorter(tuples, self.tables)
        head = sorter.sort()
        sequence = []
        if head is None:
            return sequence
        # Iterative pre-order walk. A long foreign-key chain yields a deep
        # tree, which a recursive walk could push past the recursion limit.
        stack = [head]
        while stack:
            node = stack.pop()
            sequence.append(node.item)
            # push children reversed so they are visited in their own order
            stack.extend(list(node.children)[::-1])
        return sequence
        

class TableFinder(TableCollection, sql.ClauseVisitor):
    """Visits a clause and records every Table encountered in ``self.tables``.

    table - the clause to scan; when None, the finder starts out empty.
    check_columns - when true, visiting a Column also visits that column's
      table, so tables referenced only through columns are collected too.
    """
    def __init__(self, table, check_columns=False):
        TableCollection.__init__(self)
        self.check_columns = check_columns
        if table is None:
            return
        table.accept_visitor(self)
    def visit_table(self, table):
        self.tables.append(table)
    def visit_column(self, column):
        if not self.check_columns:
            return
        column.table.accept_visitor(self)

class ColumnFinder(sql.ClauseVisitor):
    """Collects every Column visited within a clause into a util.Set.

    Iterating over the finder yields the collected columns.
    """
    def __init__(self):
        collected = util.Set()
        self.columns = collected
    def visit_column(self, c):
        self.columns.add(c)
    def __iter__(self):
        return iter(self.columns)
            
class Aliasizer(sql.ClauseVisitor):
    """Converts references to the given tables within an expression into
    references to an alias of those tables.

    Any Column of a target table that appears directly as an operand of a
    binary clause, or as an element of a clause list / compound, is replaced
    in place by the corresponding column of that table's alias.
    """
    def __init__(self, *tables, **kwargs):
        """
        tables - the tables (or Joins) to alias.
        aliases - optional dict of table -> alias. Tables missing from it
          receive a new sql.alias(); the given dict itself is updated in place.
        """
        self.tables = {}
        self.aliases = kwargs.get('aliases', {})
        for t in tables:
            self.tables[t] = t
            if t not in self.aliases:
                self.aliases[t] = sql.alias(t)
            if isinstance(t, sql.Join):
                # each component table of the Join (found through its columns)
                # is routed to the Join's alias
                for col in t.columns:
                    self.tables[col.table] = col
                    self.aliases[col.table] = self.aliases[t]
        self.binary = None
    def get_alias(self, table):
        return self.aliases[table]
    def _is_target(self, element):
        """True if ``element`` is a Column belonging to a table being aliased."""
        return isinstance(element, schema.Column) and element.table in self.tables
    def _alias_column(self, column):
        """Return the alias's column corresponding to ``column``."""
        return self.get_alias(column.table).corresponding_column(column)
    def visit_compound(self, compound):
        self.visit_clauselist(compound)
    def visit_clauselist(self, clist):
        clauses = clist.clauses
        # assigning to the current index does not disturb the enumeration
        for i, clause in enumerate(clauses):
            if self._is_target(clause):
                clauses[i] = self._alias_column(clause)
    def visit_binary(self, binary):
        if self._is_target(binary.left):
            binary.left = self._alias_column(binary.left)
        if self._is_target(binary.right):
            binary.right = self._alias_column(binary.right)


class ClauseAdapter(sql.ClauseVisitor):
    """Given a clause (such as a WHERE criterion), locates columns which
    'correspond' to a given selectable and replaces them with that
    selectable's columns.

        For example, given:

        table1 = Table('sometable', metadata,
            Column('col1', Integer),
            Column('col2', Integer)
            )
        table2 = Table('someothertable', metadata,
            Column('col1', Integer),
            Column('col2', Integer)
            )

        condition = table1.c.col1 == table2.c.col1

        and an alias of table1:

        s = table1.alias('foo')

        calling condition.accept_visitor(ClauseAdapter(s)) rewrites condition as:

        s.c.col1 == table2.c.col1

    include - if given, only columns contained in it are adapted.
    exclude - if given, columns contained in it are never adapted.
    equivalents - optional mapping of column -> equivalent column, used as a
      fallback when the selectable has no column corresponding to the
      original.
    """
    def __init__(self, selectable, include=None, exclude=None, equivalents=None):
        self.selectable = selectable
        self.include = include
        self.exclude = exclude
        self.equivalents = equivalents
    def include_col(self, col):
        """Return the selectable's replacement for ``col``, or None when
        ``col`` should be left as is."""
        if not isinstance(col, sql.ColumnElement):
            return None
        if self.include is not None and col not in self.include:
            return None
        if self.exclude is not None and col in self.exclude:
            return None
        found = self.selectable.corresponding_column(col, raiseerr=False, keys_ok=False)
        if found is not None or self.equivalents is None or col not in self.equivalents:
            return found
        return self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False)
    def visit_binary(self, binary):
        # adapt each operand in turn, left first, only when a replacement exists
        for side in ('left', 'right'):
            replacement = self.include_col(getattr(binary, side))
            if replacement is not None:
                setattr(binary, side, replacement)

class ColumnsInClause(sql.ClauseVisitor):
    """Visits a clause and sets ``self.result`` to True if any visited column
    is itself the selectable's column under that column's key."""
    def __init__(self, selectable):
        self.selectable = selectable
        self.result = False
    def visit_column(self, column):
        match = self.selectable.c.get(column.key)
        if match is not column:
            return
        self.result = True