summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-11-17 11:37:32 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-11-17 11:37:32 -0500
commit123a5349d2c13c756ecf50e26e2cecf3d0e30c98 (patch)
tree73ead49c00a5c2615130d8009b365d42fb4ab714
parent7c3de81ee06b3cda03839bbbf85f89fd572551bf (diff)
parente7d0ba7f760c1646db3af8fefa10eea4b0136005 (diff)
downloadsqlalchemy-123a5349d2c13c756ecf50e26e2cecf3d0e30c98.tar.gz
- merge ben's patch with updates
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py462
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--test/requirements.py24
3 files changed, 432 insertions, 56 deletions
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index 2d213ed5b..dfa26a170 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -17,12 +17,12 @@
The Sybase dialect functions on current SQLAlchemy versions
but is not regularly tested, and may have many issues and
- caveats not currently handled. In particular, the table
- and database reflection features are not implemented.
+ caveats not currently handled.
"""
-
import operator
+import re
+
from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
@@ -31,10 +31,10 @@ from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc
from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
- TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
- BIGINT,INT, INTEGER, SMALLINT, BINARY,\
+ TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
+ BIGINT, INT, INTEGER, SMALLINT, BINARY,\
VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
- UnicodeText
+ UnicodeText, REAL
RESERVED_WORDS = set([
"add", "all", "alter", "and",
@@ -173,32 +173,68 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
return "UNIQUEIDENTIFIER"
ischema_names = {
- 'integer' : INTEGER,
- 'unsigned int' : INTEGER, # TODO: unsigned flags
- 'unsigned smallint' : SMALLINT, # TODO: unsigned flags
- 'unsigned bigint' : BIGINT, # TODO: unsigned flags
'bigint': BIGINT,
+ 'int' : INTEGER,
+ 'integer' : INTEGER,
'smallint' : SMALLINT,
'tinyint' : TINYINT,
- 'varchar' : VARCHAR,
- 'long varchar' : TEXT, # TODO
- 'char' : CHAR,
- 'decimal' : DECIMAL,
+ 'unsigned bigint' : BIGINT, # TODO: unsigned flags
+ 'unsigned int' : INTEGER, # TODO: unsigned flags
+ 'unsigned smallint' : SMALLINT, # TODO: unsigned flags
'numeric' : NUMERIC,
+ 'decimal' : DECIMAL,
+ 'dec' : DECIMAL,
'float' : FLOAT,
'double' : NUMERIC, # TODO
+ 'double precision' : NUMERIC, # TODO
+ 'real': REAL,
+ 'smallmoney': SMALLMONEY,
+ 'money': MONEY,
+ 'smalldatetime': DATETIME,
+ 'datetime': DATETIME,
+ 'date': DATE,
+ 'time': TIME,
+ 'char' : CHAR,
+ 'character' : CHAR,
+ 'varchar' : VARCHAR,
+ 'character varying' : VARCHAR,
+ 'char varying' : VARCHAR,
+ 'unichar' : UNICHAR,
+ 'unicode character' : UNIVARCHAR,
+ 'nchar': NCHAR,
+ 'national char': NCHAR,
+ 'national character': NCHAR,
+ 'nvarchar': NVARCHAR,
+ 'nchar varying': NVARCHAR,
+ 'national char varying': NVARCHAR,
+ 'national character varying': NVARCHAR,
+ 'text': TEXT,
+ 'unitext': UNITEXT,
'binary' : BINARY,
'varbinary' : VARBINARY,
- 'bit': BIT,
'image' : IMAGE,
+ 'bit': BIT,
+
+# not in documentation for ASE 15.7
+ 'long varchar' : TEXT, # TODO
'timestamp': TIMESTAMP,
- 'money': MONEY,
- 'smallmoney': MONEY,
'uniqueidentifier': UNIQUEIDENTIFIER,
}
+class SybaseInspector(reflection.Inspector):
+
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_id(self, table_name, schema=None):
+ """Return the table id from `table_name` and `schema`."""
+
+ return self.dialect.get_table_id(self.bind, table_name, schema,
+ info_cache=self.info_cache)
+
+
class SybaseExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
@@ -246,7 +282,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
self.root_connection.connection.connection,
True)
-
def post_exec(self):
if self.isddl:
self.set_ddl_autocommit(self.root_connection, False)
@@ -306,6 +341,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
return 'DATEPART("%s", %s)' % (
field, self.process(extract.expr, **kw))
+ def visit_now_func(self, fn, **kw):
+ return "GETDATE()"
+
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR"
# which SQLAlchemy doesn't use
@@ -348,16 +386,16 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
# TODO: need correct syntax for this
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
if column.nullable is not None:
if not column.nullable or column.primary_key:
colspec += " NOT NULL"
else:
colspec += " NULL"
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
return colspec
def visit_drop_index(self, drop):
@@ -388,6 +426,7 @@ class SybaseDialect(default.DefaultDialect):
statement_compiler = SybaseSQLCompiler
ddl_compiler = SybaseDDLCompiler
preparer = SybaseIdentifierPreparer
+ inspector = SybaseInspector
def _get_default_schema_name(self, connection):
return connection.scalar(
@@ -403,39 +442,364 @@ class SybaseDialect(default.DefaultDialect):
else:
self.max_identifier_length = 255
+ def get_table_id(self, connection, table_name, schema=None, **kw):
+ """Fetch the id for schema.table_name.
+
+ Several reflection methods require the table id. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+
+ table_id = None
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLEID_SQL = text("""
+ SELECT o.id AS id
+ FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
+ WHERE u.name = :schema_name
+ AND o.name = :table_name
+ AND o.type in ('U', 'V')
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ if isinstance(table_name, unicode):
+ table_name = table_name.encode("ascii")
+ # end Py2K
+ result = connection.execute(TABLEID_SQL,
+ schema_name=schema,
+ table_name=table_name)
+ table_id = result.scalar()
+ if table_id is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_id
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ COLUMN_SQL = text("""
+ SELECT col.name AS name,
+ t.name AS type,
+ (col.status & 8) AS nullable,
+ (col.status & 128) AS autoincrement,
+ com.text AS 'default',
+ col.prec AS precision,
+ col.scale AS scale,
+ col.length AS length
+ FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
+ col.cdefault = com.id
+ WHERE col.usertype = t.usertype
+ AND col.id = :table_id
+ ORDER BY col.colid
+ """)
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+
+ columns = []
+ for (name, type_, nullable, autoincrement, default, precision, scale,
+ length) in results:
+ col_info = self._get_column_info(name, type_, bool(nullable),
+ bool(autoincrement), default, precision, scale,
+ length)
+ columns.append(col_info)
+
+ return columns
+
+ def _get_column_info(self, name, type_, nullable, autoincrement, default,
+ precision, scale, length):
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+
+ if coltype in (NUMERIC, DECIMAL):
+ args = (precision, scale)
+ elif coltype == FLOAT:
+ args = (precision,)
+ elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
+ args = (length,)
+ else:
+ args = ()
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ #is this necessary
+ #if is_array:
+ # coltype = ARRAY(coltype)
+ else:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (type_, name))
+ coltype = sqltypes.NULLTYPE
+
+ if default:
+ default = re.sub("DEFAULT", "", default).strip()
+ default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
+ else:
+ default = None
+
+ column_info = dict(name=name, type=coltype, nullable=nullable,
+ default=default, autoincrement=autoincrement)
+ return column_info
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ table_cache = {}
+ column_cache = {}
+ foreign_keys = []
+
+ table_cache[table_id] = {"name": table_name, "schema": schema}
+
+ COLUMN_SQL = text("""
+ SELECT c.colid AS id, c.name AS name
+ FROM syscolumns c
+ WHERE c.id = :table_id
+ """)
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+ columns = {}
+ for col in results:
+ columns[col["id"]] = col["name"]
+ column_cache[table_id] = columns
+
+ REFCONSTRAINT_SQL = text("""
+ SELECT o.name AS name, r.reftabid AS reftable_id,
+ r.keycnt AS 'count',
+ r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
+ r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
+ r.fokey7 AS fokey7, r.fokey8 AS fokey8, r.fokey9 AS fokey9,
+ r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
+ r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
+ r.fokey16 AS fokey16,
+ r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
+ r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
+ r.refkey7 AS refkey7, r.refkey8 AS refkey8, r.refkey9 AS refkey9,
+ r.refkey10 AS refkey10, r.refkey11 AS refkey11,
+ r.refkey12 AS refkey12, r.refkey13 AS refkey13,
+ r.refkey14 AS refkey14, r.refkey15 AS refkey15,
+ r.refkey16 AS refkey16
+ FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
+ WHERE r.tableid = :table_id
+ """)
+ referential_constraints = connection.execute(REFCONSTRAINT_SQL,
+ table_id=table_id)
+
+ REFTABLE_SQL = text("""
+ SELECT o.name AS name, u.name AS 'schema'
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE o.id = :table_id
+ """)
+
+ for r in referential_constraints:
+
+ reftable_id = r["reftable_id"]
+
+ if reftable_id not in table_cache:
+ c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
+ reftable = c.fetchone()
+ c.close()
+ table_info = {"name": reftable["name"], "schema": None}
+ if (schema is not None or
+ reftable["schema"] != self.default_schema_name):
+ table_info["schema"] = reftable["schema"]
+
+ table_cache[reftable_id] = table_info
+ results = connection.execute(COLUMN_SQL, table_id=reftable_id)
+ reftable_columns = {}
+ for col in results:
+ reftable_columns[col["id"]] = col["name"]
+ column_cache[reftable_id] = reftable_columns
+
+ reftable = table_cache[reftable_id]
+ reftable_columns = column_cache[reftable_id]
+
+ constrained_columns = []
+ referred_columns = []
+ for i in range(1, r["count"]+1):
+ constrained_columns.append(columns[r["fokey%i" % i]])
+ referred_columns.append(reftable_columns[r["refkey%i" % i]])
+
+ fk_info = {
+ "constrained_columns": constrained_columns,
+ "referred_schema": reftable["schema"],
+ "referred_table": reftable["name"],
+ "referred_columns": referred_columns,
+ "name": r["name"]
+ }
+
+ foreign_keys.append(fk_info)
+
+ return foreign_keys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ INDEX_SQL = text("""
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ (i.status & 0x2) AS 'unique',
+ index_col(object_name(i.id), i.indid, 1) AS col_1,
+ index_col(object_name(i.id), i.indid, 2) AS col_2,
+ index_col(object_name(i.id), i.indid, 3) AS col_3,
+ index_col(object_name(i.id), i.indid, 4) AS col_4,
+ index_col(object_name(i.id), i.indid, 5) AS col_5,
+ index_col(object_name(i.id), i.indid, 6) AS col_6,
+ index_col(object_name(i.id), i.indid, 7) AS col_7,
+ index_col(object_name(i.id), i.indid, 8) AS col_8,
+ index_col(object_name(i.id), i.indid, 9) AS col_9,
+ index_col(object_name(i.id), i.indid, 10) AS col_10,
+ index_col(object_name(i.id), i.indid, 11) AS col_11,
+ index_col(object_name(i.id), i.indid, 12) AS col_12,
+ index_col(object_name(i.id), i.indid, 13) AS col_13,
+ index_col(object_name(i.id), i.indid, 14) AS col_14,
+ index_col(object_name(i.id), i.indid, 15) AS col_15,
+ index_col(object_name(i.id), i.indid, 16) AS col_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 0
+ AND i.indid BETWEEN 1 AND 254
+ """)
+
+ results = connection.execute(INDEX_SQL, table_id=table_id)
+ indexes = []
+ for r in results:
+ column_names = []
+ for i in range(1, r["count"]):
+ column_names.append(r["col_%i" % (i,)])
+ index_info = {"name": r["name"],
+ "unique": bool(r["unique"]),
+ "column_names": column_names}
+ indexes.append(index_info)
+
+ return indexes
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(connection, table_name, schema,
+ info_cache=kw.get("info_cache"))
+
+ PK_SQL = text("""
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ index_col(object_name(i.id), i.indid, 1) AS pk_1,
+ index_col(object_name(i.id), i.indid, 2) AS pk_2,
+ index_col(object_name(i.id), i.indid, 3) AS pk_3,
+ index_col(object_name(i.id), i.indid, 4) AS pk_4,
+ index_col(object_name(i.id), i.indid, 5) AS pk_5,
+ index_col(object_name(i.id), i.indid, 6) AS pk_6,
+ index_col(object_name(i.id), i.indid, 7) AS pk_7,
+ index_col(object_name(i.id), i.indid, 8) AS pk_8,
+ index_col(object_name(i.id), i.indid, 9) AS pk_9,
+ index_col(object_name(i.id), i.indid, 10) AS pk_10,
+ index_col(object_name(i.id), i.indid, 11) AS pk_11,
+ index_col(object_name(i.id), i.indid, 12) AS pk_12,
+ index_col(object_name(i.id), i.indid, 13) AS pk_13,
+ index_col(object_name(i.id), i.indid, 14) AS pk_14,
+ index_col(object_name(i.id), i.indid, 15) AS pk_15,
+ index_col(object_name(i.id), i.indid, 16) AS pk_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 2048
+ AND i.indid BETWEEN 1 AND 254
+ """)
+
+ results = connection.execute(PK_SQL, table_id=table_id)
+ pks = results.fetchone()
+ results.close()
+
+ constrained_columns = []
+ for i in range(1, pks["count"]+1):
+ constrained_columns.append(pks["pk_%i" % (i,)])
+ return {"constrained_columns": constrained_columns,
+ "name": pks["name"]}
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+
+ SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
+
+ schemas = connection.execute(SCHEMA_SQL)
+
+ return [s["name"] for s in schemas]
+
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
- result = connection.execute(
- text("select sysobjects.name from sysobjects, sysusers "
- "where sysobjects.uid=sysusers.uid and "
- "sysusers.name=:schemaname and "
- "sysobjects.type='U'",
- bindparams=[
- bindparam('schemaname', schema)
- ])
- )
- return [r[0] for r in result]
-
- def has_table(self, connection, tablename, schema=None):
+ TABLE_SQL = text("""
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'U'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ # end Py2K
+ tables = connection.execute(TABLE_SQL, schema_name=schema)
+
+ return [t["name"] for t in tables]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
- result = connection.execute(
- text("select sysobjects.name from sysobjects, sysusers "
- "where sysobjects.uid=sysusers.uid and "
- "sysobjects.name=:tablename and "
- "sysusers.name=:schemaname and "
- "sysobjects.type='U'",
- bindparams=[
- bindparam('tablename', tablename),
- bindparam('schemaname', schema)
- ])
- )
- return result.scalar() is not None
-
- def reflecttable(self, connection, table, include_columns):
- raise NotImplementedError()
+ VIEW_DEF_SQL = text("""
+ SELECT c.text
+ FROM syscomments c JOIN sysobjects o ON c.id = o.id
+ WHERE o.name = :view_name
+ AND o.type = 'V'
+ """)
+
+ # Py2K
+ if isinstance(view_name, unicode):
+ view_name = view_name.encode("ascii")
+ # end Py2K
+ view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
+ return view.scalar()
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_SQL = text("""
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'V'
+ """)
+
+ # Py2K
+ if isinstance(schema, unicode):
+ schema = schema.encode("ascii")
+ # end Py2K
+ views = connection.execute(VIEW_SQL, schema_name=schema)
+
+ return [v["name"] for v in views]
+
+ def has_table(self, connection, table_name, schema=None):
+ try:
+ self.get_table_id(connection, table_name, schema)
+ except exc.NoSuchTableError:
+ return False
+ else:
+ return True
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 062c20d84..15bfe03cb 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1729,4 +1729,4 @@ class OptionEngine(Engine):
def _set_has_events(self, value):
self.__dict__['_has_events'] = value
- _has_events = property(_get_has_events, _set_has_events) \ No newline at end of file
+ _has_events = property(_get_has_events, _set_has_events)
diff --git a/test/requirements.py b/test/requirements.py
index 07b2afb2a..d1ee55e7d 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -295,8 +295,8 @@ class DefaultRequirements(SuiteRequirements):
def empty_strings_varchar(self):
"""target database can persist/return an empty string with a varchar."""
- return fails_if("oracle", 'oracle converts empty '
- 'strings to a blank space')
+ return fails_if(["oracle"],
+ 'oracle converts empty strings to a blank space')
@property
def empty_strings_text(self):
@@ -306,6 +306,12 @@ class DefaultRequirements(SuiteRequirements):
return exclusions.open()
@property
+ def unicode_data(self):
+ return skip_if([
+ no_support("sybase", "no unicode driver support")
+ ])
+
+ @property
def unicode_connections(self):
"""Target driver must support some encoding of Unicode across the wire."""
# TODO: expand to exclude MySQLdb versions w/ broken unicode
@@ -338,15 +344,20 @@ class DefaultRequirements(SuiteRequirements):
lambda: not self._has_cextensions(), "C extensions not installed"
)
-
@property
def emulated_lastrowid(self):
""""target dialect retrieves cursor.lastrowid or an equivalent
after an insert() construct executes.
"""
return fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
- 'sqlite+pysqlite', 'mysql+pymysql',
- 'mssql+pyodbc', 'mssql+mxodbc')
+ 'sqlite+pysqlite', 'mysql+pymysql',
+ 'sybase', 'mssql+pyodbc', 'mssql+mxodbc')
+
+ @property
+ def implements_get_lastrowid(self):
+ return skip_if([
+ no_support('sybase', 'not supported by database'),
+ ])
@property
def dbapi_lastrowid(self):
@@ -373,7 +384,8 @@ class DefaultRequirements(SuiteRequirements):
def reflects_pk_names(self):
"""Target driver reflects the name of primary key constraints."""
- return fails_on_everything_except('postgresql', 'oracle', 'mssql')
+ return fails_on_everything_except('postgresql', 'oracle', 'mssql',
+ 'sybase')
@property
def datetime(self):