diff options
author | Benjamin Trofatter <bentrofatter@gmail.com> | 2012-10-30 03:52:22 -0500 |
---|---|---|
committer | Benjamin Trofatter <bentrofatter@gmail.com> | 2012-10-30 03:52:22 -0500 |
commit | 05314919616cde74f4cb3393b25dbf5d062fa64c (patch) | |
tree | 942af93f1b37f8317877e6e0135f8db00e79e11c /lib/sqlalchemy/dialects/sybase/base.py | |
parent | 47da8e06da5067b87f1342d0f781696e790e9005 (diff) | |
download | sqlalchemy-05314919616cde74f4cb3393b25dbf5d062fa64c.tar.gz |
Added reflection to sqlalchemy.dialects.sybase
Added missing types supported by Sybase to ischema_names mapping
Created a SybaseInspector similar to the PGInspector, with a cached table_id
lookup, and added it to the SybaseDialect as the default inspector.
Added the following methods to SybaseDialect:
get_table_id
get_columns
_get_column_info : support method for get_columns
get_foreign_keys
get_indexes
get_pk_constraint
get_schema_names
get_view_definition
get_view_names
Rewrote the following methods to conform to the style of the rest:
get_table_names
has_table
Reordered colspec builder to put default clause after "NULL/NOT NULL",
instead of before. This fixed a syntax error.
Diffstat (limited to 'lib/sqlalchemy/dialects/sybase/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/sybase/base.py | 469 |
1 files changed, 424 insertions, 45 deletions
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 2d213ed5b..e62d37447 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -21,8 +21,9 @@ and database reflection features are not implemented. """ - import operator +import re + from sqlalchemy.sql import compiler, expression, text, bindparam from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes @@ -31,10 +32,10 @@ from sqlalchemy import schema as sa_schema from sqlalchemy import util, sql, exc from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT,DATE,DATETIME, FLOAT, NUMERIC,\ - BIGINT,INT, INTEGER, SMALLINT, BINARY,\ + TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ + BIGINT, INT, INTEGER, SMALLINT, BINARY,\ VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ - UnicodeText + UnicodeText, REAL RESERVED_WORDS = set([ "add", "all", "alter", "and", @@ -173,32 +174,68 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler): return "UNIQUEIDENTIFIER" ischema_names = { - 'integer' : INTEGER, - 'unsigned int' : INTEGER, # TODO: unsigned flags - 'unsigned smallint' : SMALLINT, # TODO: unsigned flags - 'unsigned bigint' : BIGINT, # TODO: unsigned flags 'bigint': BIGINT, + 'int' : INTEGER, + 'integer' : INTEGER, 'smallint' : SMALLINT, 'tinyint' : TINYINT, - 'varchar' : VARCHAR, - 'long varchar' : TEXT, # TODO - 'char' : CHAR, - 'decimal' : DECIMAL, + 'unsigned bigint' : BIGINT, # TODO: unsigned flags + 'unsigned int' : INTEGER, # TODO: unsigned flags + 'unsigned smallint' : SMALLINT, # TODO: unsigned flags 'numeric' : NUMERIC, + 'decimal' : DECIMAL, + 'dec' : DECIMAL, 'float' : FLOAT, 'double' : NUMERIC, # TODO + 'double precision' : NUMERIC, # TODO + 'real': REAL, + 'smallmoney': SMALLMONEY, + 'money': MONEY, + 'smalldatetime': DATETIME, + 'datetime': DATETIME, + 'date': DATE, + 'time': TIME, + 'char' : CHAR, + 'character' : CHAR, + 'varchar' : VARCHAR, + 'character varying' : VARCHAR, + 'char varying' : VARCHAR, + 'unichar' : UNICHAR, + 'unicode character' : UNIVARCHAR, + 'nchar': NCHAR, + 'national char': NCHAR, + 'national character': NCHAR, + 'nvarchar': NVARCHAR, + 'nchar varying': NVARCHAR, + 'national char varying': NVARCHAR, + 'national character varying': NVARCHAR, + 'text': TEXT, + 'unitext': UNITEXT, 'binary' : BINARY, 'varbinary' : VARBINARY, - 'bit': BIT, 'image' : IMAGE, + 'bit': BIT, + +# not in documentation for ASE 15.7 + 'long varchar' : TEXT, # TODO 'timestamp': TIMESTAMP, - 'money': MONEY, - 'smallmoney': MONEY, 'uniqueidentifier': UNIQUEIDENTIFIER, } +class SybaseInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_id(self, table_name, schema=None): + """Return the table id from `table_name` and `schema`.""" + + return self.dialect.get_table_id(self.bind, table_name, schema, + info_cache=self.info_cache) + + class SybaseExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False @@ -246,7 +283,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext): self.root_connection.connection.connection, True) - def post_exec(self): if self.isddl: self.set_ddl_autocommit(self.root_connection, False) @@ -348,16 +384,16 @@ class SybaseDDLCompiler(compiler.DDLCompiler): # TODO: need correct syntax for this colspec += " IDENTITY(%s,%s)" % (start, increment) else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + if column.nullable is not None: if not column.nullable or column.primary_key: colspec += " NOT NULL" else: colspec += " NULL" - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - return colspec def visit_drop_index(self, drop): @@ -388,6 +424,7 @@ class SybaseDialect(default.DefaultDialect): statement_compiler = SybaseSQLCompiler ddl_compiler = SybaseDDLCompiler preparer = SybaseIdentifierPreparer + inspector = SybaseInspector def _get_default_schema_name(self, connection): return connection.scalar( @@ -404,38 +441,380 @@ class SybaseDialect(default.DefaultDialect): self.max_identifier_length = 255 @reflection.cache + def get_table_id(self, connection, table_name, schema=None, **kw): + """Fetch the id for schema.table_name. + + Several reflection methods require the table id. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + + table_id = None + if schema is None: + schema = self.default_schema_name + + TABLEID_SQL = text(""" + SELECT o.id AS id + FROM sysobjects o JOIN sysusers u ON o.uid=u.uid + WHERE u.name = :schema_name + AND o.name = :table_name + AND o.type = 'U' + """) + + # Py2K + if isinstance(schema, unicode): + schema = schema.encode("ascii") + if isinstance(table_name, unicode): + table_name = table_name.encode("ascii") + # end Py2K + result = connection.execute(TABLEID_SQL, + schema_name=schema, + table_name=table_name) + table_id = result.scalar() + if table_id is None: + raise exc.NoSuchTableError(table_name) + return table_id + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + COLUMN_SQL = text(""" + SELECT col.name AS name, + t.name AS type, + (col.status & 8) AS nullable, + (col.status & 128) AS autoincrement, + com.text AS 'default', + col.prec AS precision, + col.scale AS scale, + col.length AS length + FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON + col.cdefault = com.id + WHERE col.usertype = t.usertype + AND col.id = :table_id + ORDER BY col.colid + """) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + + columns = [] + for (name, type_, nullable, autoincrement, default, precision, scale, + length) in results: + col_info = self._get_column_info(name, type_, bool(nullable), + bool(autoincrement), default, precision, scale, + length) + columns.append(col_info) + + return columns + + def _get_column_info(self, name, type_, nullable, autoincrement, default, + precision, scale, length): + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + + if coltype in (NUMERIC, DECIMAL): + args = (precision, scale) + elif coltype == FLOAT: + args = (precision,) + elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR): + args = (length,) + else: + args = () + + if coltype: + coltype = coltype(*args, **kwargs) + #is this necessary + #if is_array: + # coltype = ARRAY(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + coltype = sqltypes.NULLTYPE + + if default: + default = re.sub("DEFAULT", "", default).strip() + default = re.sub("^'(.*)'$", lambda m: m.group(1), default) + else: + default = None + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + return column_info + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + table_cache = {} + column_cache = {} + foreign_keys = [] + + table_cache[table_id] = table_name + + COLUMN_SQL = text(""" + SELECT c.colid AS id, c.name AS name + FROM syscolumns c + WHERE c.id = :table_id + """) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + columns = {} + for col in results: + columns[col["id"]] = col["name"] + column_cache[table_id] = columns + + REFCONSTRAINT_SQL = text(""" + SELECT o.name AS name, r.reftabid AS reftable_id, + r.keycnt AS 'count', + r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3, + r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6, + r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9, + r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12, + r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15, + r.fokey16 AS fokey16, + r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3, + r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6, + r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9, + r.refkey10 AS refkey10, r.refkey11 AS refkey11, + r.refkey12 AS refkey12, r.refkey13 AS refkey13, + r.refkey14 AS refkey14, r.refkey15 AS refkey15, + r.refkey16 AS refkey16 + FROM sysreferences r JOIN sysobjects o on r.tableid = o.id + WHERE r.tableid = :table_id + """) + referential_constraints = connection.execute(REFCONSTRAINT_SQL, + table_id=table_id) + + REFTABLE_SQL = text(""" + SELECT o.id AS id, o.name AS name, u.name AS 'schema' + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE o.id = :table_id + """) + + for r in referential_constraints: + + reftable_id = r["reftable_id"] + + if reftable_id not in table_cache: + c = connection.execute(REFTABLE_SQL, table_id=reftable_id) + reftable = c.fetchone() + c.close() + table_cache[reftable_id] = {"name": reftable["name"], + "schema": reftable["schema"]} + + results = connection.execute(COLUMN_SQL, table_id=reftable_id) + reftable_columns = {} + for col in results: + reftable_columns[col["id"]] = col["name"] + column_cache[reftable_id] = reftable_columns + + reftable = table_cache[reftable_id] + reftable_columns = column_cache[reftable_id] + + constrained_columns = [] + referred_columns = [] + for i in range(1, r["count"]+1): + constrained_columns.append(columns[r["fokey%i" % i]]) + referred_columns.append(reftable_columns[r["refkey%i" % i]]) + + fk_info = { + "constrained_columns": constrained_columns, + "referred_schema": reftable["schema"], + "referred_table": reftable["name"], + "referred_columns": referred_columns, + "name": r["name"] + } + + foreign_keys.append(fk_info) + + return foreign_keys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + INDEX_SQL = text(""" + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + (i.status & 0x2) AS 'unique', + index_col(object_name(i.id), i.indid, 1) AS col_1, + index_col(object_name(i.id), i.indid, 2) AS col_2, + index_col(object_name(i.id), i.indid, 3) AS col_3, + index_col(object_name(i.id), i.indid, 4) AS col_4, + index_col(object_name(i.id), i.indid, 5) AS col_5, + index_col(object_name(i.id), i.indid, 6) AS col_6, + index_col(object_name(i.id), i.indid, 7) AS col_7, + index_col(object_name(i.id), i.indid, 8) AS col_8, + index_col(object_name(i.id), i.indid, 9) AS col_9, + index_col(object_name(i.id), i.indid, 10) AS col_10, + index_col(object_name(i.id), i.indid, 11) AS col_11, + index_col(object_name(i.id), i.indid, 12) AS col_12, + index_col(object_name(i.id), i.indid, 13) AS col_13, + index_col(object_name(i.id), i.indid, 14) AS col_14, + index_col(object_name(i.id), i.indid, 15) AS col_15, + index_col(object_name(i.id), i.indid, 16) AS col_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 0 + AND i.indid BETWEEN 1 AND 254 + AND o.type = 'U' + """) + + results = connection.execute(INDEX_SQL, table_id=table_id) + indexes = [] + for r in results: + column_names = [] + for i in range(1, r["count"]): + column_names.append(r["col_%i" % (i,)]) + index_info = {"name": r["name"], + "unique": bool(r["unique"]), + "column_names": column_names} + indexes.append(index_info) + + return indexes + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + PK_SQL = text(""" + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + index_col(object_name(i.id), i.indid, 1) AS pk_1, + index_col(object_name(i.id), i.indid, 2) AS pk_2, + index_col(object_name(i.id), i.indid, 3) AS pk_3, + index_col(object_name(i.id), i.indid, 4) AS pk_4, + index_col(object_name(i.id), i.indid, 5) AS pk_5, + index_col(object_name(i.id), i.indid, 6) AS pk_6, + index_col(object_name(i.id), i.indid, 7) AS pk_7, + index_col(object_name(i.id), i.indid, 8) AS pk_8, + index_col(object_name(i.id), i.indid, 9) AS pk_9, + index_col(object_name(i.id), i.indid, 10) AS pk_10, + index_col(object_name(i.id), i.indid, 11) AS pk_11, + index_col(object_name(i.id), i.indid, 12) AS pk_12, + index_col(object_name(i.id), i.indid, 13) AS pk_13, + index_col(object_name(i.id), i.indid, 14) AS pk_14, + index_col(object_name(i.id), i.indid, 15) AS pk_15, + index_col(object_name(i.id), i.indid, 16) AS pk_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 2048 + AND i.indid BETWEEN 1 AND 254 + AND o.type = 'U' + """) + + results = connection.execute(PK_SQL, table_id=table_id) + pks = results.fetchone() + results.close() + + constrained_columns = [] + for i in range(1, pks["count"]+1): + constrained_columns.append(pks["pk_%i" % (i,)]) + return {"constrained_columns": constrained_columns, + "name": pks["name"]} + + @reflection.cache + def get_schema_names(self, connection, **kw): + + SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u") + + schemas = connection.execute(SCHEMA_SQL) + + return [s["name"] for s in schemas] + + @reflection.cache def get_table_names(self, connection, schema=None, **kw): if schema is None: schema = self.default_schema_name - result = connection.execute( - text("select sysobjects.name from sysobjects, sysusers " - "where sysobjects.uid=sysusers.uid and " - "sysusers.name=:schemaname and " - "sysobjects.type='U'", - bindparams=[ - bindparam('schemaname', schema) - ]) - ) - return [r[0] for r in result] - - def has_table(self, connection, tablename, schema=None): + TABLE_SQL = text(""" + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'U' + """) + + # Py2K + if isinstance(schema, unicode): + schema = schema.encode("ascii") + # end Py2K + tables = connection.execute(TABLE_SQL, schema_name=schema) + + return [t["name"] for t in tables] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_DEF_SQL = text(""" + SELECT c.text + FROM syscomments c JOIN sysobjects o ON c.id = o.id + WHERE o.name = :view_name + AND o.type = 'V' + """) + + # Py2K + if isinstance(view_name, unicode): + view_name = view_name.encode("ascii") + # end Py2K + view = connection.execute(VIEW_DEF_SQL, view_name=view_name) + + return view.scalar() + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_SQL = text(""" + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'V' + """) + + # Py2K + if isinstance(schema, unicode): + schema = schema.encode("ascii") + # end Py2K + views = connection.execute(VIEW_SQL, schema_name=schema) + + return [v["name"] for v in views] + + def has_table(self, connection, table_name, schema=None): if schema is None: schema = self.default_schema_name - result = connection.execute( - text("select sysobjects.name from sysobjects, sysusers " - "where sysobjects.uid=sysusers.uid and " - "sysobjects.name=:tablename and " - "sysusers.name=:schemaname and " - "sysobjects.type='U'", - bindparams=[ - bindparam('tablename', tablename), - bindparam('schemaname', schema) - ]) - ) + HAS_TABLE_SQL = text(""" + SELECT o.name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE o.name = :table_name + AND u.name = :schema_name + AND o.type = 'U' + """) + + # Py2K + if isinstance(schema, unicode): + schema = schema.encode("ascii") + if isinstance(table_name, unicode): + table_name = table_name.encode("ascii") + # end Py2K + result = connection.execute(HAS_TABLE_SQL, table_name=table_name, + schema_name=schema) return result.scalar() is not None - def reflecttable(self, connection, table, include_columns): - raise NotImplementedError() + #def reflecttable(self, connection, table, include_columns): + # raise NotImplementedError() |