summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/informix.py
blob: 39d0ee96a50915c4fd364c504216ef4c1554adfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
# informix.py
# Copyright (C) 2005,2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
#
# coding: gbk
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import datetime

from sqlalchemy import sql, schema, exc, pool, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes


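# Cursor wrapper used to emulate OFFSET: leading rows are skipped on the
# client side after the statement has been executed.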

class informix_cursor(object):
    """Thin wrapper around a DB-API cursor that can skip leading rows.

    The wrapper owns ``offset``, ``execute`` and ``rowcount``; every other
    attribute is forwarded to the wrapped cursor obtained from
    ``con.cursor()``.
    """

    # Names that belong to the wrapper itself and are never forwarded to
    # the wrapped cursor.
    _own_names = ('offset', '__cursor', 'rowcount', '__del__', 'execute')

    def __init__(self, con):
        """Open a cursor on ``con`` and start with a row count of zero."""
        self.__cursor = con.cursor()
        self.rowcount = 0

    def offset(self, n):
        """Discard the first ``n`` rows and adjust ``rowcount`` accordingly.

        When ``n`` is not positive nothing is fetched and ``rowcount`` simply
        mirrors the wrapped cursor.  The adjusted count never goes below 0.
        """
        if n > 0:
            self.__cursor.fetchmany(n)
            self.rowcount = max(self.__cursor.rowcount - n, 0)
        else:
            self.rowcount = self.__cursor.rowcount

    def execute(self, sql, params=None):
        """Execute ``sql`` on the wrapped cursor.

        A missing or empty ``params`` is normalised to an empty list before
        being handed to the wrapped cursor.
        """
        if params is None or len(params) == 0:
            params = []

        return self.__cursor.execute(sql, params)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped cursor.

        Raises AttributeError for the wrapper's own names and for the
        wrapped-cursor slot itself.  The latter matters when ``__init__``
        did not complete: looking the slot up here again would otherwise
        recurse without bound.
        """
        if name in self._own_names or name == '_informix_cursor__cursor':
            raise AttributeError(name)
        return getattr(self.__cursor, name)

class InfoNumeric(sqltypes.Numeric):
    """NUMERIC column type for Informix."""

    def get_col_spec(self):
        """Return the DDL type name.

        With a truthy ``precision`` the result is ``NUMERIC(p, l)`` where
        ``l`` is taken from ``self.length``; otherwise a bare ``NUMERIC``.
        """
        if self.precision:
            spec = {'precision': self.precision, 'length': self.length}
            return "NUMERIC(%(precision)s, %(length)s)" % spec
        return 'NUMERIC'

class InfoInteger(sqltypes.Integer):
    """INTEGER column type for Informix."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        ddl_name = "INTEGER"
        return ddl_name

class InfoSmallInteger(sqltypes.Smallinteger):
    """SMALLINT column type for Informix."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        ddl_name = "SMALLINT"
        return ddl_name

class InfoDate(sqltypes.Date):
    """DATE column type for Informix."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        ddl_name = "DATE"
        return ddl_name

class InfoDateTime(sqltypes.DateTime):
    """Date-time column type rendered as ``DATETIME YEAR TO SECOND``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "DATETIME YEAR TO SECOND"

    def bind_processor(self, dialect):
        """Return a callable that drops microseconds from bound values.

        ``None`` and values whose ``microsecond`` is zero pass through
        unchanged.
        """
        def strip_microseconds(value):
            if value is None or not value.microsecond:
                return value
            return value.replace(microsecond=0)
        return strip_microseconds

class InfoTime(sqltypes.Time):
    """Time column type rendered as ``DATETIME HOUR TO SECOND``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "DATETIME HOUR TO SECOND"

    def bind_processor(self, dialect):
        """Return a callable that drops microseconds from bound values.

        ``None`` and values whose ``microsecond`` is zero pass through
        unchanged.
        """
        def strip_microseconds(value):
            if value is None or not value.microsecond:
                return value
            return value.replace(microsecond=0)
        return strip_microseconds

    def result_processor(self, dialect):
        """Return a callable that reduces ``datetime`` results to a time.

        Anything that is not a ``datetime.datetime`` is returned as-is.
        """
        def to_time(value):
            if not isinstance(value, datetime.datetime):
                return value
            return value.time()
        return to_time

class InfoText(sqltypes.String):
    """Text column type, always rendered as a fixed ``VARCHAR(255)``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        ddl_name = "VARCHAR(255)"
        return ddl_name

class InfoString(sqltypes.String):
    """VARCHAR column type for Informix."""

    def get_col_spec(self):
        """Return ``VARCHAR(n)`` with ``n`` taken from ``self.length``."""
        spec = {'length': self.length}
        return "VARCHAR(%(length)s)" % spec

    def bind_processor(self, dialect):
        """Return a callable that sends empty strings to the database as NULL."""
        def empty_to_null(value):
            if value == '':
                return None
            return value
        return empty_to_null

class InfoChar(sqltypes.CHAR):
    """CHAR column type for Informix."""

    def get_col_spec(self):
        """Return ``CHAR(n)`` with ``n`` taken from ``self.length``."""
        spec = {'length': self.length}
        return "CHAR(%(length)s)" % spec

class InfoBinary(sqltypes.Binary):
    """Binary column type rendered as Informix ``BYTE``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        ddl_name = "BYTE"
        return ddl_name

class InfoBoolean(sqltypes.Boolean):
    """Boolean column type stored in a ``SMALLINT`` column."""

    default_type = 'NUM'

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "SMALLINT"

    def result_processor(self, dialect):
        """Return a callable turning fetched values into ``bool``.

        ``None`` is preserved so NULLs stay distinguishable from False.
        """
        def to_bool(value):
            if value is None:
                return None
            return bool(value)
        return to_bool

    def bind_processor(self, dialect):
        """Return a callable converting bound values for the database.

        ``True``/``False`` become ``1``/``0``, ``None`` stays ``None`` and
        any other object is reduced to its truth value.
        """
        def to_db(value):
            if value is True:
                return 1
            if value is False:
                return 0
            if value is None:
                return None
            return bool(value)
        return to_db

# Generic SQLAlchemy type -> Informix-specific implementation.  Passed to
# sqltypes.adapt_type() by InfoDialect.type_descriptor().
colspecs = {
    sqltypes.Integer:      InfoInteger,
    sqltypes.Smallinteger: InfoSmallInteger,
    sqltypes.Numeric:      InfoNumeric,
    sqltypes.Float:        InfoNumeric,
    sqltypes.DateTime:     InfoDateTime,
    sqltypes.Date:         InfoDate,
    sqltypes.Time:         InfoTime,
    sqltypes.String:       InfoString,
    sqltypes.Binary:       InfoBinary,
    sqltypes.Boolean:      InfoBoolean,
    sqltypes.Text:         InfoText,
    sqltypes.CHAR:         InfoChar,
}


# Informix column type code (low byte of syscolumns.coltype) -> type class
# used when reflecting a table.
ischema_names = {
    0   : InfoString,       # CHAR
    1   : InfoSmallInteger, # SMALLINT
    2   : InfoInteger,      # INT
    3   : InfoNumeric,      # Float
    # SMALLFLOAT has its own code, 4.  It was previously listed under a
    # second key 3, which the FLOAT entry made redundant and left code 4
    # unmapped.
    4   : InfoNumeric,      # SmallFloat
    5   : InfoNumeric,      # DECIMAL
    6   : InfoInteger,      # Serial
    7   : InfoDate,         # DATE
    8   : InfoNumeric,      # MONEY
    10  : InfoDateTime,     # DATETIME
    11  : InfoBinary,       # BYTE
    12  : InfoText,         # TEXT
    13  : InfoString,       # VARCHAR
    15  : InfoString,       # NCHAR
    16  : InfoString,       # NVARCHAR
    17  : InfoInteger,      # INT8
    18  : InfoInteger,      # Serial8
    43  : InfoString,       # LVARCHAR
    # BLOB and CLOB were both listed under the placeholder key -1; only the
    # later (CLOB) entry ever took effect, so a single entry is kept.
    -1  : InfoText,         # CLOB (BLOB shared this placeholder key)
}


class InfoExecutionContext(default.DefaultExecutionContext):
    """Execution context that picks up serial values and applies offsets.

    Layout of ``cursor.sqlerrd``:

    * 0 - estimated number of rows returned
    * 1 - serial value after insert or ISAM error code
    * 2 - number of rows processed
    * 3 - estimated cost
    * 4 - offset of the error into the SQL statement
    * 5 - rowid after insert
    """

    def create_cursor(self):
        """Return an ``informix_cursor`` wrapping the raw DB-API connection."""
        raw_connection = self.connection.connection
        return informix_cursor(raw_connection)

    def post_exec(self):
        """Record the generated serial value or skip the offset rows.

        For an insert whose inserted ids are not yet known, the serial
        value is read from ``sqlerrd[1]``.  Otherwise, when the compiled
        statement carries an ``offset`` attribute, that many rows are
        skipped on the cursor.
        """
        compiled = self.compiled
        if getattr(compiled, "isinsert", False) and self.last_inserted_ids() is None:
            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
        elif hasattr(compiled, 'offset'):
            self.cursor.offset(compiled.offset)
        super(InfoExecutionContext, self).post_exec()

class InfoDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Informix, using the ``informixdb`` DB-API module."""

    name = 'informix'
    default_paramstyle = 'qmark'
    # for informix 7.31
    max_identifier_length = 18
    supports_simple_order_by_label = False

    def __init__(self, use_ansi=True, **kwargs):
        """Store ``use_ansi`` and pass remaining options to the base dialect."""
        self.use_ansi = use_ansi
        default.DefaultDialect.__init__(self, **kwargs)

    def dbapi(cls):
        """Import and return the ``informixdb`` module."""
        import informixdb
        return informixdb
    dbapi = classmethod(dbapi)

    def is_disconnect(self, e):
        """Return True if ``e`` is an OperationalError reporting a lost connection."""
        if not isinstance(e, self.dbapi.OperationalError):
            return False
        message = str(e)
        return 'closed the connection' in message or 'connection not open' in message

    def do_begin(self , connect ):
        """Set the lock mode to WAIT on ``connect`` when a transaction begins.

        The helper cursor is closed afterwards, also when the statement
        fails, so it does not stay open for the life of the connection.
        """
        cu = connect.cursor()
        try:
            cu.execute( 'SET LOCK MODE TO WAIT' )
            #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
        finally:
            cu.close()

    def type_descriptor(self, typeobj):
        """Adapt ``typeobj`` to its Informix implementation via ``colspecs``."""
        return sqltypes.adapt_type(typeobj, colspecs)

    def create_connect_args(self, url):
        """Build ``([dsn], options)`` for the DB-API ``connect`` call.

        The DSN is ``database@host`` when the URL has a host, otherwise just
        the database name.  User and password are only passed when the URL
        carries a username.
        """
        if url.host:
            dsn = '%s@%s' % ( url.database , url.host )
        else:
            dsn = url.database

        if url.username:
            opt = { 'user':url.username , 'password': url.password }
        else:
            opt = {}

        return ([dsn], opt)

    def create_execution_context(self , *args, **kwargs):
        """Return an ``InfoExecutionContext`` bound to this dialect."""
        return InfoExecutionContext(self, *args, **kwargs)

    def oid_column_name(self, column):
        """Return the name of the row identifier column."""
        return "rowid"

    def table_names(self, connection, schema):
        """Return every ``tabname`` in ``systables``.

        NOTE(review): ``schema`` is accepted but not used to filter the
        result.
        """
        s = "select tabname from systables"
        return [row[0] for row in connection.execute(s)]

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``systables`` has a row for the lower-cased name."""
        cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
        return cursor.fetchone() is not None

    def reflecttable(self, connection, table, include_columns):
        """Load columns, foreign keys and the primary key of ``table``.

        Raises ``exc.NoSuchTableError`` when the table or its columns cannot
        be found, and ``AssertionError`` when the owner is ambiguous or does
        not match ``table.owner``.
        """
        table_name = table.name.lower()

        c = connection.execute ("select distinct OWNER from systables where tabname=?", table_name )
        rows = c.fetchall()
        if not rows :
            raise exc.NoSuchTableError(table.name)

        # Work out which owner's table is meant.
        if table.owner is not None:
            owner = table.owner.lower()
            if owner not in [r[0] for r in rows]:
                raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
        elif len(rows)==1:
            owner = rows[0][0]
        else:
            raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)

        c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
                                      and t3.tabid = t2.tabid and t3.colno = t1.colno
                                    order by t1.colno""", table_name, owner )
        rows = c.fetchall()

        if not rows:
            raise exc.NoSuchTableError(table.name)

        for name , colattr , collength , default , colno in rows:
            name = name.lower()
            if include_columns and name not in include_columns:
                continue

            # in 7.31, coltype = 0x000
            #                       ^^-- column type
            #                      ^-- 1 not null , 0 null
            nullable , coltype = divmod( colattr , 256 )
            if coltype not in ( 0 , 13 ) and default:
                # for non-character columns only the last whitespace-separated
                # token of the stored default is used
                default = default.split()[-1]

            if coltype == 0 or coltype == 13: # char , varchar
                coltype = ischema_names.get(coltype, InfoString)(collength)
                if default:
                    default = "'%s'" % default
            elif coltype == 5: # decimal
                # collength packs the precision in the high byte and the
                # scale in the low byte; a scale of 255 is treated as 0
                precision , scale = ( collength & 0xFF00 ) >> 8 , collength & 0xFF
                if scale == 255:
                    scale = 0
                coltype = InfoNumeric(precision, scale)
            else:
                try:
                    coltype = ischema_names[coltype]
                except KeyError:
                    util.warn("Did not recognize type '%s' of column '%s'" %
                              (coltype, name))
                    coltype = sqltypes.NULLTYPE

            colargs = []
            if default is not None:
                colargs.append(schema.DefaultClause(sql.text(default)))

            table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))

        # FK
        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
                                         t4.colname as local_column , t7.tabname as remote_table ,
                                         t6.colname as remote_column
                                    from sysconstraints as t1 , systables as t2 ,
                                         sysindexes as t3 , syscolumns as t4 ,
                                         sysreferences as t5 , syscolumns as t6 , systables as t7 ,
                                         sysconstraints as t8 , sysindexes as t9
                                   where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R'
                                     and t3.tabid = t2.tabid and t3.idxname = t1.idxname
                                     and t4.tabid = t2.tabid and t4.colno = t3.part1
                                     and t5.constrid = t1.constrid and t8.constrid = t5.primary
                                     and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
                                     and t7.tabid = t5.ptabid""", table_name, owner )
        rows = c.fetchall()
        # constraint name -> (local column names, "table.column" targets)
        fks = {}
        for cons_name, cons_type, local_column, remote_table, remote_column in rows:
            fk = fks.setdefault(cons_name, ([], []))
            refspec = ".".join([remote_table, remote_column])
            # make sure the referenced table is reflected into the metadata
            schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
            if local_column not in fk[0]:
                fk[0].append(local_column)
            if refspec not in fk[1]:
                fk[1].append(refspec)

        for local_columns, remote_specs in fks.values():
            table.append_constraint(schema.ForeignKeyConstraint(local_columns, remote_specs , None ))

        # PK
        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
                                         t4.colname as local_column
                                    from sysconstraints as t1 , systables as t2 ,
                                         sysindexes as t3 , syscolumns as t4
                                   where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'P'
                                     and t3.tabid = t2.tabid and t3.idxname = t1.idxname
                                     and t4.tabid = t2.tabid and t4.colno = t3.part1""", table_name, owner )
        rows = c.fetchall()
        for cons_name, cons_type, local_column in rows:
            table.primary_key.add( table.c[local_column] )

class InfoCompiler(compiler.DefaultCompiler):
    """Statement compiler for Informix.

    A LIMIT is rendered as ``FIRST n`` in front of the select list and no
    trailing limit clause is emitted.  The requested offset is recorded on
    the compiler as ``self.offset`` so that ``InfoExecutionContext`` can
    skip those rows on the cursor after execution.  The current date/time
    functions are translated to their Informix spellings.
    """

    def __init__(self, *args, **kwargs):
        """Initialise limit/offset bookkeeping, then the base compiler."""
        self.limit = 0
        self.offset = 0

        compiler.DefaultCompiler.__init__( self , *args, **kwargs )

    def default_from(self):
        """Return the FROM clause used when a select has no FROM of its own."""
        return " from systables where tabname = 'systables' "

    def get_select_precolumns( self , select ):
        """Return the text placed between SELECT and the column list.

        ``FIRST`` covers limit plus offset because the offset rows are
        fetched and discarded on the client.  Informix expects ``FIRST``
        ahead of ``DISTINCT``, so it is emitted first.
        """
        s = ""
        # only has limit
        if select._limit:
            off = select._offset or 0
            s += " FIRST %s " % ( select._limit + off )
        if select._distinct:
            s += "DISTINCT "
        return s

    def visit_select(self, select):
        """Compile ``select``, remembering its offset and limit.

        Columns that appear in ORDER BY but not in the select list are
        appended to the select, except for a column named ``oid``.
        """
        if select._offset:
            self.offset = select._offset
            self.limit  = select._limit or 0
        # the column in order by clause must in select too

        def label_of( c ):
            # best effort: elements without a usable label compare as ''
            try:
                return c._label.lower()
            except Exception:
                return ''

        # TODO: dont modify the original select, generate a new one
        a = [ label_of(c) for c in select._raw_columns ]
        for c in select._order_by_clause.clauses:
            if ( label_of(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
                select.append_column( c )

        return compiler.DefaultCompiler.visit_select(self, select)

    def limit_clause(self, select):
        """Return an empty string; the limit is rendered via ``FIRST``."""
        return ""

    def visit_function( self , func ):
        """Translate current date/time functions, defer anything else."""
        name = func.name.lower()
        if name == 'current_date':
            return "today"
        elif name == 'current_time':
            return "CURRENT HOUR TO SECOND"
        elif name in ( 'current_timestamp' , 'now' ):
            return "CURRENT YEAR TO SECOND"
        else:
            return compiler.DefaultCompiler.visit_function( self , func )

    def visit_clauselist(self, list, within_order_by=False, **kwargs):
        """Render a comma-separated clause list without ``oid`` columns.

        Each clause is checked on its own, so one clause lacking a ``name``
        attribute no longer disables the ``oid`` filter for the whole list.
        The parameter keeps the name ``list`` for interface compatibility.
        """
        li = [ c for c in list.clauses if getattr( c , 'name' , None ) != 'oid' ]
        return ', '.join([s for s in [self.process(c, within_order_by=within_order_by) for c in li] if s is not None])

class InfoSchemaGenerator(compiler.SchemaGenerator):
    """DDL generator for Informix.

    Several methods clear ``constraint.name`` while delegating to the base
    class; the name is always put back, also when the base class raises.
    """

    def get_column_specification(self, column, first_pk=False):
        """Return the column definition used inside CREATE TABLE.

        The first autoincrementing integer primary key column that has no
        foreign keys becomes ``SERIAL``; ``has_serial`` makes sure this
        happens at most once per table.  Every other column gets its type
        and, if present, its DEFAULT.
        """
        colspec = self.preparer.format_column(column)
        if (column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and
                isinstance(column.type, sqltypes.Integer) and
                not getattr( self , 'has_serial' , False ) and first_pk):
            colspec += " SERIAL"
            self.has_serial = True
        else:
            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"

        return colspec

    def post_create_table(self, table):
        """Reset the per-table SERIAL marker; nothing is appended to the DDL."""
        if hasattr( self , 'has_serial' ):
            del self.has_serial
        return ''

    def visit_primary_key_constraint(self, constraint):
        """Emit the primary key constraint without its name."""
        # for informix 7.31 not support constraint name
        name = constraint.name
        constraint.name = None
        try:
            super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint)
        finally:
            constraint.name = name

    def visit_unique_constraint(self, constraint):
        """Emit the unique constraint without its name."""
        # for informix 7.31 not support constraint name
        name = constraint.name
        constraint.name = None
        try:
            super(InfoSchemaGenerator, self).visit_unique_constraint(constraint)
        finally:
            constraint.name = name

    def visit_foreign_key_constraint( self , constraint ):
        """Emit unnamed foreign keys inline; flag named ones ``use_alter``.

        A named constraint is not rendered here; it is only marked with
        ``use_alter = True``.
        """
        if constraint.name is not None:
            constraint.use_alter = True
        else:
            super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint )

    def define_foreign_key(self, constraint):
        """Render a foreign key definition.

        For ``use_alter`` constraints the output is ``CONSTRAINT``, the
        unnamed definition from the base class, and then
        ``CONSTRAINT <name>`` when the constraint has a name.
        """
        # for informix 7.31 not support constraint name
        if constraint.use_alter:
            name = constraint.name
            constraint.name = None
            try:
                self.append( "CONSTRAINT " )
                super(InfoSchemaGenerator, self).define_foreign_key(constraint)
            finally:
                constraint.name = name
            if name is not None:
                self.append( " CONSTRAINT " + name )
        else:
            super(InfoSchemaGenerator, self).define_foreign_key(constraint)

    def visit_index(self, index):
        """Create the index unless it is a single foreign key column."""
        if len( index.columns ) == 1 and index.columns[0].foreign_key:
            return
        super(InfoSchemaGenerator, self).visit_index(index)

class InfoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that reports no identifier as needing quotes."""

    def _requires_quotes(self, value):
        """Always return False, whatever ``value`` is."""
        return False

    def __init__(self, dialect):
        """Configure the base preparer with a single quote as opening quote."""
        base = super(InfoIdentifierPreparer, self)
        base.__init__(dialect, initial_quote="'")

class InfoSchemaDropper(compiler.SchemaDropper):
    """Schema dropper that only drops foreign keys which have a name."""

    def drop_foreignkey(self, constraint):
        """Delegate to the base class unless the constraint is unnamed."""
        if constraint.name is None:
            return
        super(InfoSchemaDropper, self).drop_foreignkey(constraint)

# Attach the Informix-specific compiler, DDL and quoting classes to the
# dialect, then publish the module-level names `dialect` and `poolclass`.
InfoDialect.statement_compiler = InfoCompiler
InfoDialect.schemagenerator = InfoSchemaGenerator
InfoDialect.schemadropper = InfoSchemaDropper
InfoDialect.preparer = InfoIdentifierPreparer

dialect = InfoDialect
poolclass = pool.SingletonThreadPool