summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/sybase/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/sybase/base.py')
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py469
1 files changed, 424 insertions, 45 deletions
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index 2d213ed5b..e62d37447 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -21,8 +21,9 @@
and database reflection features are not implemented.
"""
-
import operator
+import re
+
from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
@@ -31,10 +32,10 @@ from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc
from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
- TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
- BIGINT,INT, INTEGER, SMALLINT, BINARY,\
+ TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
+ BIGINT, INT, INTEGER, SMALLINT, BINARY,\
VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
- UnicodeText
+ UnicodeText, REAL
RESERVED_WORDS = set([
"add", "all", "alter", "and",
@@ -173,32 +174,68 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
return "UNIQUEIDENTIFIER"
ischema_names = {
- 'integer' : INTEGER,
- 'unsigned int' : INTEGER, # TODO: unsigned flags
- 'unsigned smallint' : SMALLINT, # TODO: unsigned flags
- 'unsigned bigint' : BIGINT, # TODO: unsigned flags
'bigint': BIGINT,
+ 'int' : INTEGER,
+ 'integer' : INTEGER,
'smallint' : SMALLINT,
'tinyint' : TINYINT,
- 'varchar' : VARCHAR,
- 'long varchar' : TEXT, # TODO
- 'char' : CHAR,
- 'decimal' : DECIMAL,
+ 'unsigned bigint' : BIGINT, # TODO: unsigned flags
+ 'unsigned int' : INTEGER, # TODO: unsigned flags
+ 'unsigned smallint' : SMALLINT, # TODO: unsigned flags
'numeric' : NUMERIC,
+ 'decimal' : DECIMAL,
+ 'dec' : DECIMAL,
'float' : FLOAT,
'double' : NUMERIC, # TODO
+ 'double precision' : NUMERIC, # TODO
+ 'real': REAL,
+ 'smallmoney': SMALLMONEY,
+ 'money': MONEY,
+ 'smalldatetime': DATETIME,
+ 'datetime': DATETIME,
+ 'date': DATE,
+ 'time': TIME,
+ 'char' : CHAR,
+ 'character' : CHAR,
+ 'varchar' : VARCHAR,
+ 'character varying' : VARCHAR,
+ 'char varying' : VARCHAR,
+ 'unichar' : UNICHAR,
+ 'unicode character' : UNIVARCHAR,
+ 'nchar': NCHAR,
+ 'national char': NCHAR,
+ 'national character': NCHAR,
+ 'nvarchar': NVARCHAR,
+ 'nchar varying': NVARCHAR,
+ 'national char varying': NVARCHAR,
+ 'national character varying': NVARCHAR,
+ 'text': TEXT,
+ 'unitext': UNITEXT,
'binary' : BINARY,
'varbinary' : VARBINARY,
- 'bit': BIT,
'image' : IMAGE,
+ 'bit': BIT,
+
+# not in documentation for ASE 15.7
+ 'long varchar' : TEXT, # TODO
'timestamp': TIMESTAMP,
- 'money': MONEY,
- 'smallmoney': MONEY,
'uniqueidentifier': UNIQUEIDENTIFIER,
}
+class SybaseInspector(reflection.Inspector):
+
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_id(self, table_name, schema=None):
+ """Return the table id from `table_name` and `schema`."""
+
+ return self.dialect.get_table_id(self.bind, table_name, schema,
+ info_cache=self.info_cache)
+
+
class SybaseExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
@@ -246,7 +283,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
self.root_connection.connection.connection,
True)
-
def post_exec(self):
if self.isddl:
self.set_ddl_autocommit(self.root_connection, False)
@@ -348,16 +384,16 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
# TODO: need correct syntax for this
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
if column.nullable is not None:
if not column.nullable or column.primary_key:
colspec += " NOT NULL"
else:
colspec += " NULL"
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
return colspec
def visit_drop_index(self, drop):
@@ -388,6 +424,7 @@ class SybaseDialect(default.DefaultDialect):
statement_compiler = SybaseSQLCompiler
ddl_compiler = SybaseDDLCompiler
preparer = SybaseIdentifierPreparer
+ inspector = SybaseInspector
def _get_default_schema_name(self, connection):
return connection.scalar(
@@ -404,38 +441,380 @@ class SybaseDialect(default.DefaultDialect):
self.max_identifier_length = 255
@reflection.cache
+ def get_table_id(self, connection, table_name, schema=None, **kw):
+ """Fetch the id for schema.table_name.
+
+ Several reflection methods require the table id. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+
+ table_id = None
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLEID_SQL = text("""
+ SELECT o.id AS id
+ FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
+ WHERE u.name = :schema_name
+ AND o.name = :table_name
+ AND o.type = 'U'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ if isinstance(table_name, unicode):
+ table_name = table_name.encode("ascii")
+ # end Py2K
+ result = connection.execute(TABLEID_SQL,
+ schema_name=schema,
+ table_name=table_name)
+ table_id = result.scalar()
+ if table_id is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_id
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ COLUMN_SQL = text("""
+ SELECT col.name AS name,
+ t.name AS type,
+ (col.status & 8) AS nullable,
+ (col.status & 128) AS autoincrement,
+ com.text AS 'default',
+ col.prec AS precision,
+ col.scale AS scale,
+ col.length AS length
+ FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
+ col.cdefault = com.id
+ WHERE col.usertype = t.usertype
+ AND col.id = :table_id
+ ORDER BY col.colid
+ """)
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+
+ columns = []
+ for (name, type_, nullable, autoincrement, default, precision, scale,
+ length) in results:
+ col_info = self._get_column_info(name, type_, bool(nullable),
+ bool(autoincrement), default, precision, scale,
+ length)
+ columns.append(col_info)
+
+ return columns
+
+ def _get_column_info(self, name, type_, nullable, autoincrement, default,
+ precision, scale, length):
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+
+ if coltype in (NUMERIC, DECIMAL):
+ args = (precision, scale)
+ elif coltype == FLOAT:
+ args = (precision,)
+ elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
+ args = (length,)
+ else:
+ args = ()
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ #is this necessary
+ #if is_array:
+ # coltype = ARRAY(coltype)
+ else:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (type_, name))
+ coltype = sqltypes.NULLTYPE
+
+ if default:
+ default = re.sub("DEFAULT", "", default).strip()
+ default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
+ else:
+ default = None
+
+ column_info = dict(name=name, type=coltype, nullable=nullable,
+ default=default, autoincrement=autoincrement)
+ return column_info
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ table_cache = {}
+ column_cache = {}
+ foreign_keys = []
+
+ table_cache[table_id] = table_name
+
+ COLUMN_SQL = text("""
+ SELECT c.colid AS id, c.name AS name
+ FROM syscolumns c
+ WHERE c.id = :table_id
+ """)
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+ columns = {}
+ for col in results:
+ columns[col["id"]] = col["name"]
+ column_cache[table_id] = columns
+
+ REFCONSTRAINT_SQL = text("""
+ SELECT o.name AS name, r.reftabid AS reftable_id,
+ r.keycnt AS 'count',
+ r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
+ r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
+ r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9,
+ r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
+ r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
+ r.fokey16 AS fokey16,
+ r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
+ r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
+ r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9,
+ r.refkey10 AS refkey10, r.refkey11 AS refkey11,
+ r.refkey12 AS refkey12, r.refkey13 AS refkey13,
+ r.refkey14 AS refkey14, r.refkey15 AS refkey15,
+ r.refkey16 AS refkey16
+ FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
+ WHERE r.tableid = :table_id
+ """)
+ referential_constraints = connection.execute(REFCONSTRAINT_SQL,
+ table_id=table_id)
+
+ REFTABLE_SQL = text("""
+ SELECT o.id AS id, o.name AS name, u.name AS 'schema'
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE o.id = :table_id
+ """)
+
+ for r in referential_constraints:
+
+ reftable_id = r["reftable_id"]
+
+ if reftable_id not in table_cache:
+ c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
+ reftable = c.fetchone()
+ c.close()
+ table_cache[reftable_id] = {"name": reftable["name"],
+ "schema": reftable["schema"]}
+
+ results = connection.execute(COLUMN_SQL, table_id=reftable_id)
+ reftable_columns = {}
+ for col in results:
+ reftable_columns[col["id"]] = col["name"]
+ column_cache[reftable_id] = reftable_columns
+
+ reftable = table_cache[reftable_id]
+ reftable_columns = column_cache[reftable_id]
+
+ constrained_columns = []
+ referred_columns = []
+ for i in range(1, r["count"]+1):
+ constrained_columns.append(columns[r["fokey%i" % i]])
+ referred_columns.append(reftable_columns[r["refkey%i" % i]])
+
+ fk_info = {
+ "constrained_columns": constrained_columns,
+ "referred_schema": reftable["schema"],
+ "referred_table": reftable["name"],
+ "referred_columns": referred_columns,
+ "name": r["name"]
+ }
+
+ foreign_keys.append(fk_info)
+
+ return foreign_keys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ INDEX_SQL = text("""
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ (i.status & 0x2) AS 'unique',
+ index_col(object_name(i.id), i.indid, 1) AS col_1,
+ index_col(object_name(i.id), i.indid, 2) AS col_2,
+ index_col(object_name(i.id), i.indid, 3) AS col_3,
+ index_col(object_name(i.id), i.indid, 4) AS col_4,
+ index_col(object_name(i.id), i.indid, 5) AS col_5,
+ index_col(object_name(i.id), i.indid, 6) AS col_6,
+ index_col(object_name(i.id), i.indid, 7) AS col_7,
+ index_col(object_name(i.id), i.indid, 8) AS col_8,
+ index_col(object_name(i.id), i.indid, 9) AS col_9,
+ index_col(object_name(i.id), i.indid, 10) AS col_10,
+ index_col(object_name(i.id), i.indid, 11) AS col_11,
+ index_col(object_name(i.id), i.indid, 12) AS col_12,
+ index_col(object_name(i.id), i.indid, 13) AS col_13,
+ index_col(object_name(i.id), i.indid, 14) AS col_14,
+ index_col(object_name(i.id), i.indid, 15) AS col_15,
+ index_col(object_name(i.id), i.indid, 16) AS col_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 0
+ AND i.indid BETWEEN 1 AND 254
+ AND o.type = 'U'
+ """)
+
+ results = connection.execute(INDEX_SQL, table_id=table_id)
+ indexes = []
+ for r in results:
+ column_names = []
+ for i in range(1, r["count"]):
+ column_names.append(r["col_%i" % (i,)])
+ index_info = {"name": r["name"],
+ "unique": bool(r["unique"]),
+ "column_names": column_names}
+ indexes.append(index_info)
+
+ return indexes
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ PK_SQL = text("""
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ index_col(object_name(i.id), i.indid, 1) AS pk_1,
+ index_col(object_name(i.id), i.indid, 2) AS pk_2,
+ index_col(object_name(i.id), i.indid, 3) AS pk_3,
+ index_col(object_name(i.id), i.indid, 4) AS pk_4,
+ index_col(object_name(i.id), i.indid, 5) AS pk_5,
+ index_col(object_name(i.id), i.indid, 6) AS pk_6,
+ index_col(object_name(i.id), i.indid, 7) AS pk_7,
+ index_col(object_name(i.id), i.indid, 8) AS pk_8,
+ index_col(object_name(i.id), i.indid, 9) AS pk_9,
+ index_col(object_name(i.id), i.indid, 10) AS pk_10,
+ index_col(object_name(i.id), i.indid, 11) AS pk_11,
+ index_col(object_name(i.id), i.indid, 12) AS pk_12,
+ index_col(object_name(i.id), i.indid, 13) AS pk_13,
+ index_col(object_name(i.id), i.indid, 14) AS pk_14,
+ index_col(object_name(i.id), i.indid, 15) AS pk_15,
+ index_col(object_name(i.id), i.indid, 16) AS pk_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 2048
+ AND i.indid BETWEEN 1 AND 254
+ AND o.type = 'U'
+ """)
+
+ results = connection.execute(PK_SQL, table_id=table_id)
+ pks = results.fetchone()
+ results.close()
+
+ constrained_columns = []
+ for i in range(1, pks["count"]+1):
+ constrained_columns.append(pks["pk_%i" % (i,)])
+ return {"constrained_columns": constrained_columns,
+ "name": pks["name"]}
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+
+ SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
+
+ schemas = connection.execute(SCHEMA_SQL)
+
+ return [s["name"] for s in schemas]
+
+ @reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
- result = connection.execute(
- text("select sysobjects.name from sysobjects, sysusers "
- "where sysobjects.uid=sysusers.uid and "
- "sysusers.name=:schemaname and "
- "sysobjects.type='U'",
- bindparams=[
- bindparam('schemaname', schema)
- ])
- )
- return [r[0] for r in result]
-
- def has_table(self, connection, tablename, schema=None):
+ TABLE_SQL = text("""
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'U'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ # end Py2K
+ tables = connection.execute(TABLE_SQL, schema_name=schema)
+
+ return [t["name"] for t in tables]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_DEF_SQL = text("""
+ SELECT c.text
+ FROM syscomments c JOIN sysobjects o ON c.id = o.id
+ WHERE o.name = :view_name
+ AND o.type = 'V'
+ """)
+
+ # Py2K
+ if isinstance(view_name, unicode):
+ view_name = view_name.encode("ascii")
+ # end Py2K
+ view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
+
+ return view.scalar()
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_SQL = text("""
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'V'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ # end Py2K
+ views = connection.execute(VIEW_SQL, schema_name=schema)
+
+ return [v["name"] for v in views]
+
+ def has_table(self, connection, table_name, schema=None):
if schema is None:
schema = self.default_schema_name
- result = connection.execute(
- text("select sysobjects.name from sysobjects, sysusers "
- "where sysobjects.uid=sysusers.uid and "
- "sysobjects.name=:tablename and "
- "sysusers.name=:schemaname and "
- "sysobjects.type='U'",
- bindparams=[
- bindparam('tablename', tablename),
- bindparam('schemaname', schema)
- ])
- )
+ HAS_TABLE_SQL = text("""
+ SELECT o.name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE o.name = :table_name
+ AND u.name = :schema_name
+ AND o.type = 'U'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ if isinstance(table_name, unicode):
+ table_name = table_name.encode("ascii")
+ # end Py2K
+ result = connection.execute(HAS_TABLE_SQL, table_name=table_name,
+ schema_name=schema)
return result.scalar() is not None
- def reflecttable(self, connection, table, include_columns):
- raise NotImplementedError()
+ #def reflecttable(self, connection, table, include_columns):
+ # raise NotImplementedError()