summaryrefslogtreecommitdiff
path: root/test/sql/constraints.py
blob: 2908e07da929ef10da54ed59a20567c99cf120af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import exceptions
from testlib import *
from testlib import config, engines

class ConstraintTest(TestBase, AssertsExecutionResults):
    """Tests that create constraints and indexes against ``testing.db``.

    Each test defines its tables on the module-level ``metadata`` object,
    which ``setUp`` rebinds to a fresh ``MetaData(testing.db)`` and
    ``tearDown`` cleans up with ``drop_all()``.
    """

    def setUp(self):
        # The test bodies look ``metadata`` up as a module global, so it is
        # replaced with a fresh MetaData before every test.
        global metadata
        metadata = MetaData(testing.db)

    def tearDown(self):
        metadata.drop_all()

    def _assert_insert_rejected(self, table, **values):
        """Insert ``values`` into ``table`` and require an ``SQLError``.

        Returns normally when the insert raises ``exceptions.SQLError``;
        otherwise fails with an AssertionError naming the offending values.
        Any other exception type propagates unchanged.
        """
        try:
            table.insert().execute(**values)
        except exceptions.SQLError:
            return
        assert False, "insert of %r was expected to raise SQLError" % (values,)

    def test_constraint(self):
        """Create a composite primary key and a composite foreign key to it."""
        Table('employees', metadata,
            Column('id', Integer),
            Column('soc', String(40)),
            Column('name', String(30)),
            PrimaryKeyConstraint('id', 'soc')
            )
        Table('elements', metadata,
            Column('id', Integer),
            Column('stuff', String(30)),
            Column('emp_id', Integer),
            Column('emp_soc', String(40)),
            PrimaryKeyConstraint('id', name='elements_primkey'),
            ForeignKeyConstraint(['emp_id', 'emp_soc'], ['employees.id', 'employees.soc'])
            )
        metadata.create_all()

    def test_circular_constraint(self):
        """Create two tables whose table-level foreign keys reference each other.

        The ``b`` -> ``a`` constraint is declared with ``use_alter=True``
        (presumably so it can be added after both tables exist).
        """
        Table("a", metadata,
            Column('id', Integer, primary_key=True),
            Column('bid', Integer),
            ForeignKeyConstraint(["bid"], ["b.id"], name="afk")
            )
        Table("b", metadata,
            Column('id', Integer, primary_key=True),
            Column("aid", Integer),
            ForeignKeyConstraint(["aid"], ["a.id"], use_alter=True, name="bfk")
            )
        metadata.create_all()

    def test_circular_constraint_2(self):
        """Same mutual reference as above, using column-level ``ForeignKey``."""
        Table("a", metadata,
            Column('id', Integer, primary_key=True),
            Column('bid', Integer, ForeignKey("b.id")),
            )
        Table("b", metadata,
            Column('id', Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id", use_alter=True, name="bfk")),
            )
        metadata.create_all()

    @testing.unsupported('mysql')
    def test_check_constraint(self):
        """Table-level and column-level CHECK constraints reject bad rows."""
        foo = Table('foo', metadata,
            Column('id', Integer, primary_key=True),
            Column('x', Integer),
            Column('y', Integer),
            CheckConstraint('x>y'))
        bar = Table('bar', metadata,
            Column('id', Integer, primary_key=True),
            Column('x', Integer, CheckConstraint('x>7')),
            Column('z', Integer)
            )

        metadata.create_all()

        # x > y holds for the first row and is violated by the second.
        foo.insert().execute(id=1,x=9,y=5)
        self._assert_insert_rejected(foo, id=2, x=5, y=9)

        # x > 7 holds for the first row and is violated by the second.
        bar.insert().execute(id=1,x=10)
        self._assert_insert_rejected(bar, id=2, x=5)

    def test_unique_constraint(self):
        """Column-level ``unique=True`` and ``UniqueConstraint`` reject duplicates."""
        foo = Table('foo', metadata,
            Column('id', Integer, primary_key=True),
            Column('value', String(30), unique=True))
        bar = Table('bar', metadata,
            Column('id', Integer, primary_key=True),
            Column('value', String(30)),
            Column('value2', String(30)),
            UniqueConstraint('value', 'value2', name='uix1')
            )
        metadata.create_all()
        foo.insert().execute(id=1, value='value1')
        foo.insert().execute(id=2, value='value2')
        bar.insert().execute(id=1, value='a', value2='a')
        bar.insert().execute(id=2, value='a', value2='b')

        # Duplicate of foo row 1's value.
        self._assert_insert_rejected(foo, id=3, value='value1')
        # Duplicate of bar row 2's (value, value2) pair.
        self._assert_insert_rejected(bar, id=3, value='a', value2='b')

    def test_index_create(self):
        """Indexes created after the table are registered in ``Table.indexes``."""
        employees = Table('employees', metadata,
                          Column('id', Integer, primary_key=True),
                          Column('first_name', String(30)),
                          Column('last_name', String(30)),
                          Column('email_address', String(30)))
        employees.create()

        i = Index('employee_name_index',
                  employees.c.last_name, employees.c.first_name)
        i.create()
        assert i in employees.indexes

        i2 = Index('employee_email_index',
                   employees.c.email_address, unique=True)
        i2.create()
        assert i2 in employees.indexes

    def test_index_create_camelcase(self):
        """test that mixed-case index identifiers are legal"""
        employees = Table('companyEmployees', metadata,
                          Column('id', Integer, primary_key=True),
                          Column('firstName', String(30)),
                          Column('lastName', String(30)),
                          Column('emailAddress', String(30)))

        employees.create()

        i = Index('employeeNameIndex',
                  employees.c.lastName, employees.c.firstName)
        i.create()

        i = Index('employeeEmailIndex',
                  employees.c.emailAddress, unique=True)
        i.create()

        # Check that the table is useable. This is mostly for pg,
        # which can be somewhat sticky with mixed-case identifiers
        employees.insert().execute(firstName='Joe', lastName='Smith', id=0)
        ss = employees.select().execute().fetchall()
        assert ss[0].firstName == 'Joe'
        assert ss[0].lastName == 'Smith'

    def test_index_create_inline(self):
        """Test indexes defined with tables"""

        events = Table('events', metadata,
                       Column('id', Integer, primary_key=True),
                       Column('name', String(30), index=True, unique=True),
                       Column('location', String(30), index=True),
                       Column('sport', String(30)),
                       Column('announcer', String(30)),
                       Column('winner', String(30)))

        Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
        Index('idx_winners', events.c.winner)

        index_names = [ ix.name for ix in events.indexes ]
        assert 'ix_events_name' in index_names
        assert 'ix_events_location' in index_names
        assert 'sport_announcer' in index_names
        assert 'idx_winners' in index_names
        assert len(index_names) == 4

        # ``capt`` receives, for every executed statement, the statement text
        # followed by the repr of its parameters (two entries per statement).
        capt = []
        connection = testing.db.connect()
        try:
            # TODO: hacky, put a real connection proxy in
            ex = connection._Connection__execute_raw
            def proxy(context):
                capt.append(context.statement)
                capt.append(repr(context.parameters))
                ex(context)
            connection._Connection__execute_raw = proxy
            schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection)
            schemagen.traverse(events)

            assert capt[0].strip().startswith('CREATE TABLE events')

            # Even offsets after the CREATE TABLE pair hold the four index
            # statements; compare as a set since their order is not asserted.
            s = set([capt[x].strip() for x in [2,4,6,8]])

            assert s == set([
                'CREATE UNIQUE INDEX ix_events_name ON events (name)',
                'CREATE INDEX ix_events_location ON events (location)',
                'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
                'CREATE INDEX idx_winners ON events (winner)'
                ])

            # verify that the table is functional
            events.insert().execute(id=1, name='hockey finals', location='rink',
                                    sport='hockey', announcer='some canadian',
                                    winner='sweden')
            events.select().execute().fetchall()
        finally:
            # Release the explicitly checked-out connection even when an
            # assertion above fails.
            connection.close()


class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
    """Check the DDL text generated for DEFERRABLE / INITIALLY options.

    ``setUp`` builds an engine with ``strategy='mock'`` whose executor is an
    ``accum`` instance, so the statements issued by ``Table.create()`` are
    collected in ``self.sql`` and can be searched for substrings.
    """

    class accum(object):
        """Callable that records every SQL statement it is called with.

        ``substring in accum_instance`` is true when any recorded statement
        contains ``substring``.
        """
        def __init__(self):
            self.statements = []
        def __call__(self, sql, *a, **kw):
            # Extra positional/keyword arguments are accepted and ignored.
            self.statements.append(sql)
        def __contains__(self, substring):
            for s in self.statements:
                if substring in s:
                    return True
            return False
        def __str__(self):
            return '\n'.join([repr(x) for x in self.statements])
        def clear(self):
            del self.statements[:]

    def setUp(self):
        self.sql = self.accum()
        opts = config.db_opts.copy()
        opts['strategy'] = 'mock'
        opts['executor'] = self.sql
        self.engine = engines.testing_engine(options=opts)

    def _create_tbl(self, meta, constraint):
        """Define table ``tbl`` (columns ``a``, ``b``) with ``constraint`` and create it."""
        t = Table('tbl', meta,
                  Column('a', Integer),
                  Column('b', Integer),
                  constraint)
        t.create()

    def _test_deferrable(self, constraint_factory):
        """Exercise ``deferrable`` / ``initially`` on a table-level constraint.

        ``constraint_factory`` is called with the keyword arguments under test
        and must return the constraint to attach to ``tbl``.  Four variants
        are created in turn; the recorded SQL and the MetaData are cleared
        between variants so each check only sees its own DDL.
        """
        meta = MetaData(self.engine)

        # deferrable=True: DEFERRABLE must appear, but not as NOT DEFERRABLE
        # (the plain substring check alone would match both).
        self._create_tbl(meta, constraint_factory(deferrable=True))
        assert 'DEFERRABLE' in self.sql, self.sql
        assert 'NOT DEFERRABLE' not in self.sql, self.sql
        self.sql.clear()
        meta.clear()

        # deferrable=False: explicit NOT DEFERRABLE.
        self._create_tbl(meta, constraint_factory(deferrable=False))
        assert 'NOT DEFERRABLE' in self.sql, self.sql
        self.sql.clear()
        meta.clear()

        # deferrable=True with initially='IMMEDIATE'.
        self._create_tbl(meta, constraint_factory(deferrable=True,
                                                  initially='IMMEDIATE'))
        assert 'NOT DEFERRABLE' not in self.sql, self.sql
        assert 'INITIALLY IMMEDIATE' in self.sql, self.sql
        self.sql.clear()
        meta.clear()

        # deferrable=True with initially='DEFERRED'.
        self._create_tbl(meta, constraint_factory(deferrable=True,
                                                  initially='DEFERRED'))
        assert 'NOT DEFERRABLE' not in self.sql, self.sql
        assert 'INITIALLY DEFERRED' in self.sql, self.sql

    def test_deferrable_pk(self):
        factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
        self._test_deferrable(factory)

    def test_deferrable_table_fk(self):
        factory = lambda **kw: ForeignKeyConstraint(['b'], ['tbl.a'], **kw)
        self._test_deferrable(factory)

    def test_deferrable_column_fk(self):
        """Column-level ``ForeignKey`` with deferrable=True, initially='DEFERRED'."""
        meta = MetaData(self.engine)
        t = Table('tbl', meta,
                  Column('a', Integer),
                  Column('b', Integer,
                         ForeignKey('tbl.a', deferrable=True,
                                    initially='DEFERRED')))
        t.create()
        assert 'DEFERRABLE' in self.sql, self.sql
        # 'DEFERRABLE' is also a substring of 'NOT DEFERRABLE'; rule that out,
        # matching the checks in _test_deferrable.
        assert 'NOT DEFERRABLE' not in self.sql, self.sql
        assert 'INITIALLY DEFERRED' in self.sql, self.sql

    def test_deferrable_unique(self):
        factory = lambda **kw: UniqueConstraint('b', **kw)
        self._test_deferrable(factory)

    def test_deferrable_table_check(self):
        factory = lambda **kw: CheckConstraint('a < b', **kw)
        self._test_deferrable(factory)

    def test_deferrable_column_check(self):
        """Column-level ``CheckConstraint`` with deferrable=True, initially='DEFERRED'."""
        meta = MetaData(self.engine)
        t = Table('tbl', meta,
                  Column('a', Integer),
                  Column('b', Integer,
                         CheckConstraint('a < b',
                                         deferrable=True,
                                         initially='DEFERRED')))
        t.create()
        assert 'DEFERRABLE' in self.sql, self.sql
        # 'DEFERRABLE' is also a substring of 'NOT DEFERRABLE'; rule that out,
        # matching the checks in _test_deferrable.
        assert 'NOT DEFERRABLE' not in self.sql, self.sql
        assert 'INITIALLY DEFERRED' in self.sql, self.sql


# When executed directly as a script, hand control to testenv.main()
# (presumably the test-runner entry point for the classes above).
if __name__ == "__main__":
    testenv.main()