summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql_util.py
blob: 509a7fe3ed00d6341e97f291ca60907f70f7a99f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import sqlalchemy.sql as sql
import sqlalchemy.schema as schema

"""utility functions that build upon SQL and Schema constructs"""


class TableCollection(object):
    """An ordered collection of tables that can be sorted by dependency.

    Tables are kept in ``self.tables`` in insertion order.  The result of
    the dependency sort is cached in ``self._sorted`` and is dropped again
    whenever ``add()`` is called.
    """

    def __init__(self):
        self.tables = []

    def add(self, table):
        """Append ``table`` and discard any cached sort result."""
        self.tables.append(table)
        # The cached ordering no longer covers every table; drop it so the
        # next sort() call rebuilds it.
        if hasattr(self, '_sorted'):
            del self._sorted

    def sort(self, reverse=False):
        """Return the tables in the order produced by ``_do_sort()``.

        The ordering is computed on first use and cached.  With
        ``reverse=False`` the cached list itself is returned (not a copy);
        with ``reverse=True`` a reversed copy is returned and the cache is
        left untouched.
        """
        try:
            ordered = self._sorted
        except AttributeError:
            # No cached result yet (or add() invalidated it).
            ordered = self._sorted = self._do_sort()
        if reverse:
            return ordered[::-1]
        return ordered

    def _do_sort(self):
        """Build the dependency-ordered list of tables.

        Every foreign key visited on the collected tables contributes one
        ``(parent_table, child_table)`` tuple, where the parent is
        ``fkey.column.table`` and the child is ``fkey.parent.table``.  The
        tuples and the full table list are handed to
        ``QueueDependencySorter``; the node structure it returns is then
        flattened, each node's item before the items of its children.

        Returns a new list, which is empty when the sorter returns ``None``
        as its head node.
        """
        # Imported inside the method, as the original code did
        # (presumably to avoid an import cycle at module load -- TODO confirm).
        import sqlalchemy.orm.topological

        tuples = []

        class TVisitor(schema.SchemaVisitor):
            def visit_foreign_key(self, fkey):
                parent_table = fkey.column.table
                child_table = fkey.parent.table
                tuples.append((parent_table, child_table))

        vis = TVisitor()
        for table in self.tables:
            table.accept_schema_visitor(vis)

        sorter = sqlalchemy.orm.topological.QueueDependencySorter(
            tuples, self.tables)
        head = sorter.sort()

        sequence = []

        def to_sequence(node):
            # Pre-order walk: record the node, then descend into children.
            sequence.append(node.item)
            for child in node.children:
                to_sequence(child)

        if head is not None:
            to_sequence(head)
        return sequence
        

class TableFinder(TableCollection, sql.ClauseVisitor):
    """Given a Clause, locates all the Tables within it into a list.

    The finder acts as a visitor: every table it is shown via
    ``visit_table()`` is recorded.  It also behaves like a read-only
    sequence over the tables found (``len()``, indexing, iteration, ``in``
    and ``+``).

    Arguments:

    table
        the object to traverse; its ``accept_visitor()`` method is called
        with this finder.  Pass ``None`` to start with an empty finder.

    check_columns
        when true, each visited column additionally has its ``table``
        traversed with this finder.
    """

    def __init__(self, table, check_columns=False):
        TableCollection.__init__(self)
        self.check_columns = check_columns
        if table is not None:
            table.accept_visitor(self)

    def visit_table(self, table):
        # Go through add() rather than appending directly, so that a sort
        # result cached by an earlier sort() call is invalidated.
        self.add(table)

    def visit_column(self, column):
        # NOTE(review): no de-duplication happens here or in visit_table(),
        # so if traversing column.table reaches visit_table() the same table
        # may be recorded more than once -- confirm this is intended.
        if self.check_columns:
            column.table.accept_visitor(self)

    def __len__(self):
        return len(self.tables)

    def __getitem__(self, i):
        return self.tables[i]

    def __iter__(self):
        return iter(self.tables)

    def __contains__(self, obj):
        return obj in self.tables

    def __add__(self, obj):
        """Return a plain list: the found tables followed by ``obj``'s items."""
        return self.tables + list(obj)