summaryrefslogtreecommitdiff
path: root/alembic/operations/schemaobj.py
blob: f0f81057f4ed5b5dd76167000d01c71ef518d1f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from sqlalchemy import schema as sa_schema
from sqlalchemy.types import NULLTYPE, Integer
from ..util.compat import string_types
from .. import util


class SchemaObjects(object):

    def __init__(self, migration_context=None):
        """Store the optional migration context.

        :param migration_context: object consulted by :meth:`metadata`
         for a ``target_metadata`` entry in its ``opts``; may be ``None``.
        """
        self.migration_context = migration_context

    def primary_key_constraint(self, name, table_name, cols, schema=None):
        """Build a PrimaryKeyConstraint bound to a placeholder table.

        A throwaway Table named ``table_name`` is created in a fresh
        MetaData with one NULLTYPE column per entry in ``cols``; the
        constraint is assembled from those columns and appended to the
        table.

        :param name: name given to the constraint.
        :param table_name: name of the placeholder table.
        :param cols: column names making up the primary key.
        :param schema: optional schema passed to the Table.
        :return: the PrimaryKeyConstraint.
        """
        metadata = self.metadata()
        placeholder_cols = [
            sa_schema.Column(col_name, NULLTYPE) for col_name in cols
        ]
        table = sa_schema.Table(
            table_name, metadata, *placeholder_cols, schema=schema)

        key_columns = [table.c[col_name] for col_name in cols]
        constraint = sa_schema.PrimaryKeyConstraint(*key_columns, name=name)
        table.append_constraint(constraint)
        return constraint

    def foreign_key_constraint(
        self, name, source, referent,
        local_cols, remote_cols,
        onupdate=None, ondelete=None,
        deferrable=None, source_schema=None,
        referent_schema=None, initially=None,
            match=None, **dialect_kw):
        """Build a ForeignKeyConstraint bound to placeholder tables.

        Placeholder Table objects with NULLTYPE columns are created in a
        fresh MetaData for both ends of the constraint, and the
        constraint is appended to the source table.

        :param name: name given to the constraint.
        :param source: name of the table that holds the foreign key.
        :param referent: name of the table being referenced.
        :param local_cols: column names on the source table.
        :param remote_cols: column names on the referent table.
        :param onupdate: passed through to ForeignKeyConstraint.
        :param ondelete: passed through to ForeignKeyConstraint.
        :param deferrable: passed through to ForeignKeyConstraint.
        :param source_schema: optional schema of the source table.
        :param referent_schema: optional schema of the referent table.
        :param initially: passed through to ForeignKeyConstraint.
        :param match: passed through as ``match`` only when
         ``util.sqla_08`` is true.
        :param dialect_kw: extra keyword arguments passed through to
         ForeignKeyConstraint.
        :return: the ForeignKeyConstraint.
        """
        m = self.metadata()

        # The constraint is self-referential only when both the table
        # name AND the schema agree.  Two tables sharing a name but
        # living in different schemas are distinct tables, so the
        # referent still needs its own placeholder; otherwise the
        # "schema.table.column" target built below would not exist in
        # the MetaData.
        if source == referent and source_schema == referent_schema:
            t1_cols = local_cols + remote_cols
        else:
            t1_cols = local_cols
            sa_schema.Table(
                referent, m,
                *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
                schema=referent_schema)

        t1 = sa_schema.Table(
            source, m,
            *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
            schema=source_schema)

        # schema-qualify the referent table name when a schema is given
        tname = "%s.%s" % (referent_schema, referent) if referent_schema \
            else referent

        if util.sqla_08:
            # "match" kw unsupported in 0.7
            dialect_kw['match'] = match

        f = sa_schema.ForeignKeyConstraint(local_cols,
                                           ["%s.%s" % (tname, n)
                                            for n in remote_cols],
                                           name=name,
                                           onupdate=onupdate,
                                           ondelete=ondelete,
                                           deferrable=deferrable,
                                           initially=initially,
                                           **dialect_kw
                                           )
        t1.append_constraint(f)

        return f

    def unique_constraint(self, name, source, local_cols, schema=None, **kw):
        """Build a UniqueConstraint bound to a placeholder table.

        :param name: name given to the constraint.
        :param source: name of the placeholder table.
        :param local_cols: column names covered by the constraint; each
         becomes a NULLTYPE column on the placeholder table.
        :param schema: optional schema passed to the Table.
        :param kw: extra keyword arguments passed to UniqueConstraint.
        :return: the UniqueConstraint.
        """
        metadata = self.metadata()
        placeholder_cols = [
            sa_schema.Column(col_name, NULLTYPE) for col_name in local_cols
        ]
        table = sa_schema.Table(
            source, metadata, *placeholder_cols, schema=schema)

        constraint_kw = dict(kw, name=name)
        constrained = [table.c[col_name] for col_name in local_cols]
        constraint = sa_schema.UniqueConstraint(*constrained, **constraint_kw)
        # TODO: need event tests to ensure the event
        # is fired off here
        table.append_constraint(constraint)
        return constraint

    def check_constraint(self, name, source, condition, schema=None, **kw):
        """Build a CheckConstraint bound to a placeholder table.

        The placeholder table carries a single Integer column named
        ``x``; it exists only so the constraint has a parent table.

        :param name: name given to the constraint.
        :param source: name of the placeholder table.
        :param condition: condition handed to CheckConstraint.
        :param schema: optional schema passed to the Table.
        :param kw: extra keyword arguments passed to CheckConstraint.
        :return: the CheckConstraint.
        """
        metadata = self.metadata()
        dummy_column = sa_schema.Column('x', Integer)
        table = sa_schema.Table(source, metadata, dummy_column, schema=schema)

        constraint = sa_schema.CheckConstraint(condition, name=name, **kw)
        table.append_constraint(constraint)
        return constraint

    def generic_constraint(self, name, table_name, type_, schema=None, **kw):
        """Build a bare, named constraint of the requested kind.

        The constraint carries only its name and is appended to a
        placeholder table built via :meth:`table`.

        :param name: name given to the constraint.
        :param table_name: name of the placeholder table.
        :param type_: one of ``'foreignkey'``, ``'primary'``,
         ``'unique'``, ``'check'`` or ``None`` (a plain Constraint).
        :param schema: optional schema passed to the Table.
        :param kw: accepted but not used by this method.
        :return: the constraint object.
        :raises TypeError: if ``type_`` is not one of the known kinds.
        """
        placeholder = self.table(table_name, schema=schema)

        # Every factory is invoked as ``factory(name=...)``.
        factories = {
            'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
                [], [], name=name),
            'primary': sa_schema.PrimaryKeyConstraint,
            'unique': sa_schema.UniqueConstraint,
            'check': lambda name: sa_schema.CheckConstraint("", name=name),
            None: sa_schema.Constraint
        }

        try:
            factory = factories[type_]
        except KeyError:
            raise TypeError("'type' can be one of %s" %
                            ", ".join(sorted(repr(key) for key in factories)))

        constraint = factory(name=name)
        placeholder.append_constraint(constraint)
        return constraint

    def metadata(self):
        """Return a new, empty MetaData.

        When a migration context is present and its ``opts`` contain a
        ``target_metadata`` entry that has a ``naming_convention``
        attribute, that naming convention is copied onto the new
        MetaData.
        """
        options = {}
        context = self.migration_context
        if context is not None and 'target_metadata' in context.opts:
            target = context.opts['target_metadata']
            if hasattr(target, 'naming_convention'):
                options['naming_convention'] = target.naming_convention
        return sa_schema.MetaData(**options)

    def table(self, name, *columns, **kw):
        """Build a Table in a fresh MetaData.

        For every foreign key found on the new table,
        :meth:`_ensure_table_for_fk` is called with the same MetaData.

        :param name: table name.
        :param columns: positional items passed to the Table.
        :param kw: keyword arguments passed to the Table.
        :return: the Table.
        """
        metadata = self.metadata()
        new_table = sa_schema.Table(name, metadata, *columns, **kw)
        for foreign_key in new_table.foreign_keys:
            self._ensure_table_for_fk(metadata, foreign_key)
        return new_table

    def column(self, name, type_, **kw):
        """Return a Column built from ``name``, ``type_`` and any extra
        keyword arguments.

        """
        new_column = sa_schema.Column(name, type_, **kw)
        return new_column

    def index(self, name, tablename, columns, schema=None, **kw):
        """Build an Index against a placeholder table.

        :param name: index name.
        :param tablename: name of the placeholder table; when falsy the
         name ``'no_table'`` is used instead.
        :param columns: items to index; each is passed, together with
         the placeholder table, through
         ``util.sqla_compat._textual_index_column``.
        :param schema: optional schema passed to the Table.
        :param kw: extra keyword arguments passed to Index.
        :return: the Index.
        """
        table = sa_schema.Table(
            tablename or 'no_table', self.metadata(),
            schema=schema
        )
        index_columns = [
            util.sqla_compat._textual_index_column(table, col)
            for col in columns
        ]
        return sa_schema.Index(name, *index_columns, **kw)

    def _parse_table_key(self, table_key):
        """Split a table key into ``(schema name, table name)``.

        The text after the last ``.`` is the table name and everything
        before it is the schema name; a key with no ``.`` yields a
        schema name of ``None``.
        """
        schema_part, dot, table_part = table_key.rpartition('.')
        if not dot:
            # no separator present: the whole key is the table name
            return (None, table_key)
        return (schema_part, table_part)

    def _ensure_table_for_fk(self, metadata, fk):
        """create a placeholder Table object for the referent of a
        ForeignKey.

        Nothing is done unless the foreign key's ``_colspec`` is a
        string.  Otherwise the referenced table is looked up in
        ``metadata`` (and created there if absent), and the referenced
        column is appended as a NULLTYPE column if the table does not
        already have it.

        """
        colspec = fk._colspec
        if not isinstance(colspec, string_types):
            return

        # "<table key>.<column name>", split on the last dot
        table_key, column_name = colspec.rsplit('.', 1)
        schema_name, table_name = self._parse_table_key(table_key)

        if table_key in metadata.tables:
            referent = metadata.tables[table_key]
        else:
            referent = sa_schema.Table(
                table_name, metadata, schema=schema_name)

        if column_name not in referent.c:
            referent.append_column(sa_schema.Column(column_name, NULLTYPE))