diff options
Diffstat (limited to 'lib/sqlalchemy/databases')
-rw-r--r-- | lib/sqlalchemy/databases/__init__.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 443 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 768 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 193 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 493 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/maxdb.py | 1099 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 1771 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mxODBC.py | 60 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 2732 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 904 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 889 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 646 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sybase.py | 875 |
13 files changed, 16 insertions, 10875 deletions
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 6588be0ae..16cabd47f 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -4,6 +4,20 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from sqlalchemy.dialects.sqlite import base as sqlite +from sqlalchemy.dialects.postgresql import base as postgresql +postgres = postgresql +from sqlalchemy.dialects.mysql import base as mysql +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.dialects.firebird import base as firebird +from sqlalchemy.dialects.maxdb import base as maxdb +from sqlalchemy.dialects.informix import base as informix +from sqlalchemy.dialects.mssql import base as mssql +from sqlalchemy.dialects.access import base as access +from sqlalchemy.dialects.sybase import base as sybase + + + __all__ = ( 'access', @@ -12,8 +26,8 @@ __all__ = ( 'maxdb', 'mssql', 'mysql', - 'oracle', - 'postgres', + 'postgresql', 'sqlite', + 'oracle', 'sybase', ) diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py deleted file mode 100644 index 56c28b8cc..000000000 --- a/lib/sqlalchemy/databases/access.py +++ /dev/null @@ -1,443 +0,0 @@ -# access.py -# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk -# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -from sqlalchemy import sql, schema, types, exc, pool -from sqlalchemy.sql import compiler, expression -from sqlalchemy.engine import default, base - - -class AcNumeric(types.Numeric): - def result_processor(self, dialect): - return None - - def bind_processor(self, dialect): - def process(value): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) - return process - - def get_col_spec(self): - return "NUMERIC" - -class AcFloat(types.Float): - def get_col_spec(self): - return "FLOAT" - - def bind_processor(self, dialect): - """By converting to string, we can use Decimal types round-trip.""" - def process(value): - if not value is None: - return str(value) - return None - return process - -class AcInteger(types.Integer): - def get_col_spec(self): - return "INTEGER" - -class AcTinyInteger(types.Integer): - def get_col_spec(self): - return "TINYINT" - -class AcSmallInteger(types.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class AcDateTime(types.DateTime): - def __init__(self, *a, **kw): - super(AcDateTime, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - -class AcDate(types.Date): - def __init__(self, *a, **kw): - super(AcDate, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - -class AcText(types.Text): - def get_col_spec(self): - return "MEMO" - -class AcString(types.String): - def get_col_spec(self): - return "TEXT" + (self.length and ("(%d)" % self.length) or "") - -class AcUnicode(types.Unicode): - def get_col_spec(self): - return "TEXT" + (self.length and ("(%d)" % self.length) or "") - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - return None - -class AcChar(types.CHAR): - def get_col_spec(self): - return "TEXT" + (self.length and ("(%d)" % self.length) or "") - -class AcBinary(types.Binary): - def get_col_spec(self): - return "BINARY" - -class AcBoolean(types.Boolean): - def get_col_spec(self): - return "YESNO" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -class AcTimeStamp(types.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - -class AccessExecutionContext(default.DefaultExecutionContext): - def _has_implicit_sequence(self, column): - if column.primary_key and column.autoincrement: - if isinstance(column.type, types.Integer) and not column.foreign_keys: - if column.default is None or (isinstance(column.default, schema.Sequence) and \ - column.default.optional): - return True - return False - - def post_exec(self): - """If we inserted into a row with a COUNTER column, fetch the ID""" - - if self.compiled.isinsert: - tbl = self.compiled.statement.table - if not hasattr(tbl, 'has_sequence'): - tbl.has_sequence = None - for column in tbl.c: - if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): - tbl.has_sequence = column - break - - if bool(tbl.has_sequence): - # TBD: for some reason _last_inserted_ids doesn't exist here - # (but it does at corresponding point in mssql???) - #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:] - # print "LAST ROW ID", self._last_inserted_ids - - super(AccessExecutionContext, self).post_exec() - - -const, daoEngine = None, None -class AccessDialect(default.DefaultDialect): - colspecs = { - types.Unicode : AcUnicode, - types.Integer : AcInteger, - types.Smallinteger: AcSmallInteger, - types.Numeric : AcNumeric, - types.Float : AcFloat, - types.DateTime : AcDateTime, - types.Date : AcDate, - types.String : AcString, - types.Binary : AcBinary, - types.Boolean : AcBoolean, - types.Text : AcText, - types.CHAR: AcChar, - types.TIMESTAMP: AcTimeStamp, - } - name = 'access' - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - - def type_descriptor(self, typeobj): - newobj = types.adapt_type(typeobj, self.colspecs) - return newobj - - def __init__(self, **params): - super(AccessDialect, self).__init__(**params) - self.text_as_varchar = False - self._dtbs = None - - def dbapi(cls): - import win32com.client, pythoncom - - global const, daoEngine - if const is None: - const = win32com.client.constants - for suffix in (".36", ".35", ".30"): - try: - daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix) - break - except pythoncom.com_error: - pass - else: - raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.") - - import pyodbc as module - return module - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args() - connectors = ["Driver={Microsoft Access Driver (*.mdb)}"] - connectors.append("Dbq=%s" % opts["database"]) - user = opts.get("username", None) - if user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % opts.get("password", "")) - return [[";".join(connectors)], {}] - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def do_execute(self, cursor, statement, params, **kwargs): - if params == {}: - params = () - super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs) - - def _execute(self, c, statement, parameters): - try: - if parameters == {}: - parameters = () - c.execute(statement, parameters) - self.context.rowcount = c.rowcount - except Exception, e: - raise exc.DBAPIError.instance(statement, parameters, e) - - def has_table(self, connection, tablename, schema=None): - # This approach seems to be more reliable that using DAO - try: - connection.execute('select top 1 * from [%s]' % tablename) - return True - except Exception, e: - return False - - def reflecttable(self, connection, table, include_columns): - # This is defined in the function, as it relies on win32com constants, - # that aren't imported until dbapi method is called - if not hasattr(self, 'ischema_names'): - self.ischema_names = { - const.dbByte: AcBinary, - const.dbInteger: AcInteger, - const.dbLong: AcInteger, - const.dbSingle: AcFloat, - const.dbDouble: AcFloat, - const.dbDate: AcDateTime, - const.dbLongBinary: AcBinary, - const.dbMemo: AcText, - const.dbBoolean: AcBoolean, - const.dbText: AcUnicode, # All Access strings are unicode - const.dbCurrency: AcNumeric, - } - - # A fresh DAO connection is opened for each reflection - # This is necessary, so we get the latest updates - dtbs = daoEngine.OpenDatabase(connection.engine.url.database) - - try: - for tbl in dtbs.TableDefs: - if tbl.Name.lower() == table.name.lower(): - break - else: - raise exc.NoSuchTableError(table.name) - - for col in tbl.Fields: - coltype = self.ischema_names[col.Type] - if col.Type == const.dbText: - coltype = coltype(col.Size) - - colargs = \ - { - 'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField), - } - default = col.DefaultValue - - if col.Attributes & const.dbAutoIncrField: - colargs['default'] = schema.Sequence(col.Name + '_seq') - elif default: - if col.Type == const.dbBoolean: - default = default == 'Yes' and '1' or '0' - colargs['server_default'] = schema.DefaultClause(sql.text(default)) - - table.append_column(schema.Column(col.Name, coltype, **colargs)) - - # TBD: check constraints - - # Find primary key columns first - for idx in tbl.Indexes: - if idx.Primary: - for col in idx.Fields: - thecol = table.c[col.Name] - table.primary_key.add(thecol) - if isinstance(thecol.type, AcInteger) and \ - not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)): - thecol.autoincrement = False - - # Then add other indexes - for idx in tbl.Indexes: - if not idx.Primary: - if len(idx.Fields) == 1: - col = table.c[idx.Fields[0].Name] - if not col.primary_key: - col.index = True - col.unique = idx.Unique - else: - pass # TBD: multi-column indexes - - - for fk in dtbs.Relations: - if fk.ForeignTable != table.name: - continue - scols = [c.ForeignName for c in fk.Fields] - rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields] - table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True)) - - finally: - dtbs.Close() - - def table_names(self, connection, schema): - # A fresh DAO connection is opened for each reflection - # This is necessary, so we get the latest updates - dtbs = daoEngine.OpenDatabase(connection.engine.url.database) - - names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"] - dtbs.Close() - return names - - -class AccessCompiler(compiler.DefaultCompiler): - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'month': 'm', - 'day': 'd', - 'year': 'yyyy', - 'second': 's', - 'hour': 'h', - 'doy': 'y', - 'minute': 'n', - 'quarter': 'q', - 'dow': 'w', - 'week': 'ww' - }) - - def visit_select_precolumns(self, select): - """Access puts TOP, it's version of LIMIT here """ - s = select.distinct and "DISTINCT " or "" - if select.limit: - s += "TOP %s " % (select.limit) - if select.offset: - raise exc.InvalidRequestError('Access does not support LIMIT with an offset') - return s - - def limit_clause(self, select): - """Limit in access is after the select keyword""" - return "" - - def binary_operator_string(self, binary): - """Access uses "mod" instead of "%" """ - return binary.operator == '%' and 'mod' or binary.operator - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label() - else: - return super(AccessCompiler, self).label_select_column(select, column, asfrom) - - function_rewrites = {'current_date': 'now', - 'current_timestamp': 'now', - 'length': 'len', - } - def visit_function(self, func): - """Access function names differ from the ANSI SQL names; rewrite common ones""" - func.name = self.function_rewrites.get(func.name, func.name) - return super(AccessCompiler, self).visit_function(func) - - def for_update_clause(self, select): - """FOR UPDATE is not supported by Access; silently ignore""" - return '' - - # Strip schema - def visit_table(self, table, asfrom=False, **kwargs): - if asfrom: - return self.preparer.quote(table.name, table.quote) - else: - return "" - - def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ - self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) - - def visit_extract(self, extract): - field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - - -class AccessSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - - # install a sequence if we have an implicit IDENTITY column - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys: - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): - column.sequence = schema.Sequence(column.name + '_seq') - - if not column.nullable: - colspec += " NOT NULL" - - if hasattr(column, 'sequence'): - column.table.has_sequence = column - colspec = self.preparer.format_column(column) + " counter" - else: - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - -class AccessSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - - self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False))) - self.execute() - -class AccessDefaultRunner(base.DefaultRunner): - pass - -class AccessIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = compiler.RESERVED_WORDS.copy() - reserved_words.update(['value', 'text']) - def __init__(self, dialect): - super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') - - -dialect = AccessDialect -dialect.poolclass = pool.SingletonThreadPool -dialect.statement_compiler = AccessCompiler -dialect.schemagenerator = AccessSchemaGenerator -dialect.schemadropper = AccessSchemaDropper -dialect.preparer = AccessIdentifierPreparer -dialect.defaultrunner = AccessDefaultRunner -dialect.execution_ctx_cls = AccessExecutionContext diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py deleted file mode 100644 index 8a8d02d4a..000000000 --- a/lib/sqlalchemy/databases/firebird.py +++ /dev/null @@ -1,768 +0,0 @@ -# firebird.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -Firebird backend -================ - -This module implements the Firebird backend, thru the kinterbasdb_ -DBAPI module. - -Firebird dialects ------------------ - -Firebird offers two distinct dialects_ (not to be confused with the -SA ``Dialect`` thing): - -dialect 1 - This is the old syntax and behaviour, inherited from Interbase pre-6.0. - -dialect 3 - This is the newer and supported syntax, introduced in Interbase 6.0. - -From the user point of view, the biggest change is in date/time -handling: under dialect 1, there's a single kind of field, ``DATE`` -with a synonim ``DATETIME``, that holds a `timestamp` value, that is a -date with hour, minute, second. Under dialect 3 there are three kinds, -a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the -day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``. - -The problem is that the dialect of a Firebird database is a property -of the database itself [#]_ (that is, any single database has been -created with one dialect or the other: there is no way to change the -after creation). SQLAlchemy has a single instance of the class that -controls all the connections to a particular kind of database, so it -cannot easily differentiate between the two modes, and in particular -it **cannot** simultaneously talk with two distinct Firebird databases -with different dialects. - -By default this module is biased toward dialect 3, but you can easily -tweak it to handle dialect 1 if needed:: - - from sqlalchemy import types as sqltypes - from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names - - # Adjust the mapping of the timestamp kind - ischema_names['TIMESTAMP'] = FBDate - colspecs[sqltypes.DateTime] = FBDate, - -Other aspects may be version-specific. You can use the ``server_version_info()`` method -on the ``FBDialect`` class to do whatever is needed:: - - from sqlalchemy.databases.firebird import FBCompiler - - if engine.dialect.server_version_info(connection) < (2,0): - # Change the name of the function ``length`` to use the UDF version - # instead of ``char_length`` - FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' - -Pooling connections -------------------- - -The default strategy used by SQLAlchemy to pool the database connections -in particular cases may raise an ``OperationalError`` with a message -`"object XYZ is in use"`. This happens on Firebird when there are two -connections to the database, one is using, or has used, a particular table -and the other tries to drop or alter the same table. To garantee DDL -operations success Firebird recommend doing them as the single connected user. - -In case your SA application effectively needs to do DDL operations while other -connections are active, the following setting may alleviate the problem:: - - from sqlalchemy import pool - from sqlalchemy.databases.firebird import dialect - - # Force SA to use a single connection per thread - dialect.poolclass = pool.SingletonThreadPool - -RETURNING support ------------------ - -Firebird 2.0 supports returning a result set from inserts, and 2.1 extends -that to deletes and updates. - -To use this pass the column/expression list to the ``firebird_returning`` -parameter when creating the queries:: - - raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), - firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall() - - -.. [#] Well, that is not the whole story, as the client may still ask - a different (lower) dialect... - -.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html -.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb -""" - - -import datetime, decimal, re - -from sqlalchemy import exc, schema, types as sqltypes, sql, util -from sqlalchemy.engine import base, default - - -_initialized_kb = False - - -class FBNumeric(sqltypes.Numeric): - """Handle ``NUMERIC(precision,scale)`` datatype.""" - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision, - 'scale' : self.scale } - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if self.asdecimal: - return None - else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - - -class FBFloat(sqltypes.Float): - """Handle ``FLOAT(precision)`` datatype.""" - - def get_col_spec(self): - if not self.precision: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class FBInteger(sqltypes.Integer): - """Handle ``INTEGER`` datatype.""" - - def get_col_spec(self): - return "INTEGER" - - -class FBSmallInteger(sqltypes.Smallinteger): - """Handle ``SMALLINT`` datatype.""" - - def get_col_spec(self): - return "SMALLINT" - - -class FBDateTime(sqltypes.DateTime): - """Handle ``TIMESTAMP`` datatype.""" - - def get_col_spec(self): - return "TIMESTAMP" - - def bind_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - return datetime.datetime(year=value.year, - month=value.month, - day=value.day) - return process - - -class FBDate(sqltypes.DateTime): - """Handle ``DATE`` datatype.""" - - def get_col_spec(self): - return "DATE" - - -class FBTime(sqltypes.Time): - """Handle ``TIME`` datatype.""" - - def get_col_spec(self): - return "TIME" - - -class FBText(sqltypes.Text): - """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob).""" - - def get_col_spec(self): - return "BLOB SUB_TYPE 1" - - -class FBString(sqltypes.String): - """Handle ``VARCHAR(length)`` datatype.""" - - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)s)" % {'length' : self.length} - else: - return "BLOB SUB_TYPE 1" - - -class FBChar(sqltypes.CHAR): - """Handle ``CHAR(length)`` datatype.""" - - def get_col_spec(self): - if self.length: - return "CHAR(%(length)s)" % {'length' : self.length} - else: - return "BLOB SUB_TYPE 1" - - -class FBBinary(sqltypes.Binary): - """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob).""" - - def get_col_spec(self): - return "BLOB SUB_TYPE 0" - - -class FBBoolean(sqltypes.Boolean): - """Handle boolean values as a ``SMALLINT`` datatype.""" - - def get_col_spec(self): - return "SMALLINT" - - -colspecs = { - sqltypes.Integer : FBInteger, - sqltypes.Smallinteger : FBSmallInteger, - sqltypes.Numeric : FBNumeric, - sqltypes.Float : FBFloat, - sqltypes.DateTime : FBDateTime, - sqltypes.Date : FBDate, - sqltypes.Time : FBTime, - sqltypes.String : FBString, - sqltypes.Binary : FBBinary, - sqltypes.Boolean : FBBoolean, - sqltypes.Text : FBText, - sqltypes.CHAR: FBChar, -} - - -ischema_names = { - 'SHORT': lambda r: FBSmallInteger(), - 'LONG': lambda r: FBInteger(), - 'QUAD': lambda r: FBFloat(), - 'FLOAT': lambda r: FBFloat(), - 'DATE': lambda r: FBDate(), - 'TIME': lambda r: FBTime(), - 'TEXT': lambda r: FBString(r['flen']), - 'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC() - 'DOUBLE': lambda r: FBFloat(), - 'TIMESTAMP': lambda r: FBDateTime(), - 'VARYING': lambda r: FBString(r['flen']), - 'CSTRING': lambda r: FBChar(r['flen']), - 'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary() - } - -RETURNING_KW_NAME = 'firebird_returning' - -class FBExecutionContext(default.DefaultExecutionContext): - pass - - -class FBDialect(default.DefaultDialect): - """Firebird dialect""" - name = 'firebird' - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - max_identifier_length = 31 - preexecute_pk_sequences = True - supports_pk_autoincrement = False - - def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - - self.type_conv = type_conv - self.concurrency_level = concurrency_level - - def dbapi(cls): - import kinterbasdb - return kinterbasdb - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] - opts.update(url.query) - - type_conv = opts.pop('type_conv', self.type_conv) - concurrency_level = opts.pop('concurrency_level', self.concurrency_level) - global _initialized_kb - if not _initialized_kb and self.dbapi is not None: - _initialized_kb = True - self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) - return ([], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def server_version_info(self, connection): - """Get the version of the Firebird server used by a connection. - - Returns a tuple of (`major`, `minor`, `build`), three integers - representing the version of the attached server. - """ - - # This is the simpler approach (the other uses the services api), - # that for backward compatibility reasons returns a string like - # LI-V6.3.3.12981 Firebird 2.0 - # where the first version is a fake one resembling the old - # Interbase signature. This is more than enough for our purposes, - # as this is mainly (only?) used by the testsuite. - - from re import match - - fbconn = connection.connection.connection - version = fbconn.server_version - m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) - if not m: - raise AssertionError("Could not determine version from string '%s'" % version) - return tuple([int(x) for x in m.group(5, 6, 4)]) - - def _normalize_name(self, name): - """Convert the name to lowercase if it is possible""" - - # Remove trailing spaces: FB uses a CHAR() type, - # that is padded with spaces - name = name and name.rstrip() - if name is None: - return None - elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.lower() - else: - return name - - def _denormalize_name(self, name): - """Revert a *normalized* name to its uppercase equivalent""" - - if name is None: - return None - elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.upper() - else: - return name - - def table_names(self, connection, schema): - """Return a list of *normalized* table names omitting system relations.""" - - s = """ - SELECT r.rdb$relation_name - FROM rdb$relations r - WHERE r.rdb$system_flag=0 - """ - return [self._normalize_name(row[0]) for row in connection.execute(s)] - - def has_table(self, connection, table_name, schema=None): - """Return ``True`` if the given table exists, ignoring the `schema`.""" - - tblqry = """ - SELECT 1 FROM rdb$database - WHERE EXISTS (SELECT rdb$relation_name - FROM rdb$relations - WHERE rdb$relation_name=?) - """ - c = connection.execute(tblqry, [self._denormalize_name(table_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False - - def has_sequence(self, connection, sequence_name): - """Return ``True`` if the given sequence (generator) exists.""" - - genqry = """ - SELECT 1 FROM rdb$database - WHERE EXISTS (SELECT rdb$generator_name - FROM rdb$generators - WHERE rdb$generator_name=?) - """ - c = connection.execute(genqry, [self._denormalize_name(sequence_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'Unable to complete network request to host' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - msg = str(e) - return ('Invalid connection state' in msg or - 'Invalid cursor state' in msg) - else: - return False - - def reflecttable(self, connection, table, include_columns): - # Query to extract the details of all the fields of the given table - tblqry = """ - SELECT DISTINCT r.rdb$field_name AS fname, - r.rdb$null_flag AS null_flag, - t.rdb$type_name AS ftype, - f.rdb$field_sub_type AS stype, - f.rdb$field_length AS flen, - f.rdb$field_precision AS fprec, - f.rdb$field_scale AS fscale, - COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault - FROM rdb$relation_fields r - JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name - JOIN rdb$types t ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' - WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? - ORDER BY r.rdb$field_position - """ - # Query to extract the PK/FK constrained fields of the given table - keyqry = """ - SELECT se.rdb$field_name AS fname - FROM rdb$relation_constraints rc - JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name - WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? - """ - # Query to extract the details of each UK/FK of the given table - fkqry = """ - SELECT rc.rdb$constraint_name AS cname, - cse.rdb$field_name AS fname, - ix2.rdb$relation_name AS targetrname, - se.rdb$field_name AS targetfname - FROM rdb$relation_constraints rc - JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name - JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key - JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name - JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position - WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? - ORDER BY se.rdb$index_name, se.rdb$field_position - """ - # Heuristic-query to determine the generator associated to a PK field - genqry = """ - SELECT trigdep.rdb$depended_on_name AS fgenerator - FROM rdb$dependencies tabdep - JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name - AND trigdep.rdb$depended_on_type=14 - AND trigdep.rdb$dependent_type=2) - JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name) - WHERE tabdep.rdb$depended_on_name=? - AND tabdep.rdb$depended_on_type=0 - AND trig.rdb$trigger_type=1 - AND tabdep.rdb$field_name=? - AND (SELECT count(*) - FROM rdb$dependencies trigdep2 - WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 - """ - - tablename = self._denormalize_name(table.name) - - # get primary key fields - c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()] - - # get all of the fields for this table - c = connection.execute(tblqry, [tablename]) - - found_table = False - while True: - row = c.fetchone() - if row is None: - break - found_table = True - - name = self._normalize_name(row['fname']) - if include_columns and name not in include_columns: - continue - args = [name] - - kw = {} - # get the data type - coltype = ischema_names.get(row['ftype'].rstrip()) - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (str(row['ftype']), name)) - coltype = sqltypes.NULLTYPE - else: - coltype = coltype(row) - args.append(coltype) - - # is it a primary key? - kw['primary_key'] = name in pkfields - - # is it nullable? - kw['nullable'] = not bool(row['null_flag']) - - # does it have a default value? - if row['fdefault'] is not None: - # the value comes down as "DEFAULT 'value'" - assert row['fdefault'].upper().startswith('DEFAULT '), row - defvalue = row['fdefault'][8:] - args.append(schema.DefaultClause(sql.text(defvalue))) - - col = schema.Column(*args, **kw) - if kw['primary_key']: - # if the PK is a single field, try to see if its linked to - # a sequence thru a trigger - if len(pkfields)==1: - genc = connection.execute(genqry, [tablename, row['fname']]) - genr = genc.fetchone() - if genr is not None: - col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator'])) - - table.append_column(col) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # get the foreign keys - c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = {} - while True: - row = c.fetchone() - if not row: - break - - cname = self._normalize_name(row['cname']) - try: - fk = fks[cname] - except KeyError: - fks[cname] = fk = ([], []) - rname = self._normalize_name(row['targetrname']) - schema.Table(rname, table.metadata, autoload=True, autoload_with=connection) - fname = self._normalize_name(row['fname']) - refspec = rname + '.' + self._normalize_name(row['targetfname']) - fk[0].append(fname) - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True)) - - def do_execute(self, cursor, statement, parameters, **kwargs): - # kinterbase does not accept a None, but wants an empty list - # when there are no arguments. - cursor.execute(statement, parameters or []) - - def do_rollback(self, connection): - # Use the retaining feature, that keeps the transaction going - connection.rollback(True) - - def do_commit(self, connection): - # Use the retaining feature, that keeps the transaction going - connection.commit(True) - - -def _substring(s, start, length=None): - "Helper function to handle Firebird 2 SUBSTRING builtin" - - if length is None: - return "SUBSTRING(%s FROM %s)" % (s, start) - else: - return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) - - -class FBCompiler(sql.compiler.DefaultCompiler): - """Firebird specific idiosincrasies""" - - # Firebird lacks a builtin modulo operator, but there is - # an equivalent function in the ib_udf library. - operators = sql.compiler.DefaultCompiler.operators.copy() - operators.update({ - sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) - }) - - def visit_alias(self, alias, asfrom=False, **kwargs): - # Override to not use the AS keyword which FB 1.5 does not like - if asfrom: - return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - else: - return self.process(alias.original, **kwargs) - - functions = sql.compiler.DefaultCompiler.functions.copy() - functions['substring'] = _substring - - def function_argspec(self, func): - if func.clauses: - return self.process(func.clause_expr) - else: - return "" - - def default_from(self): - return " FROM rdb$database" - - def visit_sequence(self, seq): - return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) - - def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just - before column list Firebird puts the limit and offset right - after the ``SELECT``... - """ - - result = "" - if select._limit: - result += "FIRST %d " % select._limit - if select._offset: - result +="SKIP %d " % select._offset - if select._distinct: - result += "DISTINCT " - return result - - def limit_clause(self, select): - """Already taken care of in the `get_select_precolumns` method.""" - - return "" - - LENGTH_FUNCTION_NAME = 'char_length' - def function_string(self, func): - """Substitute the ``length`` function. - - On newer FB there is a ``char_length`` function, while older - ones need the ``strlen`` UDF. - """ - - if func.name == 'length': - return self.LENGTH_FUNCTION_NAME + '%(expr)s' - return super(FBCompiler, self).function_string(func) - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs[RETURNING_KW_NAME] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, sql.expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) - for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + ', '.join(columns) - return text - - def visit_update(self, update_stmt): - text = super(FBCompiler, self).visit_update(update_stmt) - if RETURNING_KW_NAME in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(FBCompiler, self).visit_insert(insert_stmt) - if RETURNING_KW_NAME in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_delete(self, delete_stmt): - text = super(FBCompiler, self).visit_delete(delete_stmt) - if RETURNING_KW_NAME in delete_stmt.kwargs: - return self._append_returning(text, delete_stmt) - else: - return text - - -class FBSchemaGenerator(sql.compiler.SchemaGenerator): - """Firebird syntactic idiosincrasies""" - - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable or column.primary_key: - colspec += " NOT NULL" - - return colspec - - def visit_sequence(self, sequence): - """Generate a ``CREATE GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): - self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBSchemaDropper(sql.compiler.SchemaDropper): - """Firebird syntactic idiosincrasies""" - - def visit_sequence(self, sequence): - """Generate a ``DROP GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): - self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBDefaultRunner(base.DefaultRunner): - """Firebird specific idiosincrasies""" - - def visit_sequence(self, seq): - """Get the next value from the sequence using ``gen_id()``.""" - - return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ - self.dialect.identifier_preparer.format_sequence(seq)) - - -RESERVED_WORDS = set( - ["action", "active", "add", "admin", "after", "all", "alter", "and", "any", - "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename", - "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer", - "by", "cache", "cascade", "case", "cast", "char", "character", "character_length", - "char_length", "check", "check_point_len", "check_point_length", "close", "collate", - "collation", "column", "commit", "committed", "compiletime", "computed", "conditional", - "connect", "constraint", "containing", "continue", "count", "create", "cstring", - "current", "current_connection", "current_date", "current_role", "current_time", - "current_timestamp", "current_transaction", "current_user", "cursor", "database", - "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete", - "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct", - "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point", - "escape", "event", "exception", "execute", "exists", "exit", "extern", "external", - "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it", - "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto", - "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour", - "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input", - "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join", - "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile", - "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment", - "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month", - "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric", - "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option", - "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength", - "pages", "page_size", "parameter", "password", "plan", "position", "post_event", - "precision", "prepare", "primary", "privileges", "procedure", "protected", "public", - "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate", - "references", "release", "release", "reserv", "reserving", "restrict", "retain", - "return", "returning_values", "returns", "revoke", "right", "role", "rollback", - "row_count", "runtime", "savepoint", "schema", "second", "segment", "select", - "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint", - "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability", - "starting", "starts", "statement", "static", "statistics", "sub_type", "sum", - "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction", - "translate", "translation", "trigger", "trim", "type", "uncommitted", "union", - "unique", "update", "upper", "user", "using", "value", "values", "varchar", - "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when", - "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) - - -class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): - """Install Firebird specific reserved words.""" - - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) - - -dialect = FBDialect -dialect.statement_compiler = FBCompiler -dialect.schemagenerator = FBSchemaGenerator -dialect.schemadropper = FBSchemaDropper -dialect.defaultrunner = FBDefaultRunner -dialect.preparer = FBIdentifierPreparer -dialect.execution_ctx_cls = FBExecutionContext diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py deleted file mode 100644 index a7d4101cd..000000000 --- a/lib/sqlalchemy/databases/information_schema.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -information schema implementation. - -This module is deprecated and will not be present in this form in SQLAlchemy 0.6. - -""" -from sqlalchemy import util - -util.warn_deprecated("the information_schema module is deprecated.") - -import sqlalchemy.sql as sql -import sqlalchemy.exc as exc -from sqlalchemy import select, MetaData, Table, Column, String, Integer -from sqlalchemy.schema import DefaultClause, ForeignKeyConstraint - -ischema = MetaData() - -schemata = Table("schemata", ischema, - Column("catalog_name", String), - Column("schema_name", String), - Column("schema_owner", String), - schema="information_schema") - -tables = Table("tables", ischema, - Column("table_catalog", String), - Column("table_schema", String), - Column("table_name", String), - Column("table_type", String), - schema="information_schema") - -columns = Table("columns", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("is_nullable", Integer), - Column("data_type", String), - Column("ordinal_position", Integer), - Column("character_maximum_length", Integer), - Column("numeric_precision", Integer), - Column("numeric_scale", Integer), - Column("column_default", Integer), - Column("collation_name", String), - schema="information_schema") - -constraints = Table("table_constraints", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("constraint_name", String), - Column("constraint_type", String), - schema="information_schema") - -column_constraints = Table("constraint_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - schema="information_schema") - -pg_key_constraints = Table("key_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - Column("ordinal_position", Integer), - schema="information_schema") - -#mysql_key_constraints = Table("key_column_usage", ischema, -# Column("table_schema", String), -# Column("table_name", String), -# Column("column_name", String), -# Column("constraint_name", String), -# Column("referenced_table_schema", String), -# Column("referenced_table_name", String), -# Column("referenced_column_name", String), -# schema="information_schema") - -key_constraints = pg_key_constraints - -ref_constraints = Table("referential_constraints", ischema, - Column("constraint_catalog", String), - Column("constraint_schema", String), - Column("constraint_name", String), - Column("unique_constraint_catlog", String), - Column("unique_constraint_schema", String), - Column("unique_constraint_name", String), - Column("match_option", String), - Column("update_rule", String), - Column("delete_rule", String), - schema="information_schema") - - -def table_names(connection, schema): - s = select([tables.c.table_name], tables.c.table_schema==schema) - return [row[0] for row in connection.execute(s)] - - -def reflecttable(connection, table, include_columns, ischema_names): - key_constraints = pg_key_constraints - - if table.schema is not None: - current_schema = table.schema - else: - current_schema = connection.default_schema_name() - - s = select([columns], - sql.and_(columns.c.table_name==table.name, - columns.c.table_schema==current_schema), - order_by=[columns.c.ordinal_position]) - - c = connection.execute(s) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - #print "row! " + repr(row) - # continue - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', - row[columns.c.character_maximum_length], - row[columns.c.numeric_precision], - row[columns.c.numeric_scale], - row[columns.c.column_default] - ) - if include_columns and name not in include_columns: - continue - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = ischema_names[type] - #print "coltype " + repr(coltype) + " args " + repr(args) - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(Column(name, coltype, nullable=nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns - # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys - # wont reflect properly. dont see a way around this based on whats available from information_schema - s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position]) - s.append_column(column_constraints) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position] - c = connection.execute(s) - - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = ( - row[colmap[0]], - row[colmap[1]], - row[colmap[2]], - row[colmap[3]], - row[colmap[4]], - row[colmap[5]], - row[colmap[6]] - ) - #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) - if type == 'PRIMARY KEY': - table.primary_key.add(table.c[constrained_column]) - elif type == 'FOREIGN KEY': - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - if current_schema == referred_schema: - referred_schema = table.schema - if referred_schema is not None: - Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) - refspec = ".".join([referred_schema, referred_table, referred_column]) - else: - Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - refspec = ".".join([referred_table, referred_column]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py deleted file mode 100644 index 4476af3b9..000000000 --- a/lib/sqlalchemy/databases/informix.py +++ /dev/null @@ -1,493 +0,0 @@ -# informix.py -# Copyright (C) 2005,2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# coding: gbk -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -import datetime - -from sqlalchemy import sql, schema, exc, pool, util -from sqlalchemy.sql import compiler -from sqlalchemy.engine import default -from sqlalchemy import types as sqltypes - - -# for offset - -class informix_cursor(object): - def __init__( self , con ): - self.__cursor = con.cursor() - self.rowcount = 0 - - def offset( self , n ): - if n > 0: - self.fetchmany( n ) - self.rowcount = self.__cursor.rowcount - n - if self.rowcount < 0: - self.rowcount = 0 - else: - self.rowcount = self.__cursor.rowcount - - def execute( self , sql , params ): - if params is None or len( params ) == 0: - params = [] - - return self.__cursor.execute( sql , params ) - - def __getattr__( self , name ): - if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): - return getattr( self.__cursor , name ) - -class InfoNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return 'NUMERIC' - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class InfoInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class InfoSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class InfoDate(sqltypes.Date): - def get_col_spec( self ): - return "DATE" - -class InfoDateTime(sqltypes.DateTime ): - def get_col_spec(self): - return "DATETIME YEAR TO SECOND" - - def bind_processor(self, dialect): - def process(value): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - return process - -class InfoTime(sqltypes.Time ): - def get_col_spec(self): - return "DATETIME HOUR TO SECOND" - - def bind_processor(self, dialect): - def process(value): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - return process - - def result_processor(self, dialect): - def process(value): - if isinstance( value , datetime.datetime ): - return value.time() - else: - return value - return process - -class InfoText(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(255)" - -class InfoString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - - def bind_processor(self, dialect): - def process(value): - if value == '': - return None - else: - return value - return process - -class InfoChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class InfoBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTE" - -class InfoBoolean(sqltypes.Boolean): - default_type = 'NUM' - def get_col_spec(self): - return "SMALLINT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -colspecs = { - sqltypes.Integer : InfoInteger, - sqltypes.Smallinteger : InfoSmallInteger, - sqltypes.Numeric : InfoNumeric, - sqltypes.Float : InfoNumeric, - sqltypes.DateTime : InfoDateTime, - sqltypes.Date : InfoDate, - sqltypes.Time: InfoTime, - sqltypes.String : InfoString, - sqltypes.Binary : InfoBinary, - sqltypes.Boolean : InfoBoolean, - sqltypes.Text : InfoText, - sqltypes.CHAR: InfoChar, -} - - -ischema_names = { - 0 : InfoString, # CHAR - 1 : InfoSmallInteger, # SMALLINT - 2 : InfoInteger, # INT - 3 : InfoNumeric, # Float - 3 : InfoNumeric, # SmallFloat - 5 : InfoNumeric, # DECIMAL - 6 : InfoInteger, # Serial - 7 : InfoDate, # DATE - 8 : InfoNumeric, # MONEY - 10 : InfoDateTime, # DATETIME - 11 : InfoBinary, # BYTE - 12 : InfoText, # TEXT - 13 : InfoString, # VARCHAR - 15 : InfoString, # NCHAR - 16 : InfoString, # NVARCHAR - 17 : InfoInteger, # INT8 - 18 : InfoInteger, # Serial8 - 43 : InfoString, # LVARCHAR - -1 : InfoBinary, # BLOB - -1 : InfoText, # CLOB -} - - -class InfoExecutionContext(default.DefaultExecutionContext): - # cursor.sqlerrd - # 0 - estimated number of rows returned - # 1 - serial value after insert or ISAM error code - # 2 - number of rows processed - # 3 - estimated cost - # 4 - offset of the error into the SQL statement - # 5 - rowid after insert - def post_exec(self): - if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None: - self._last_inserted_ids = [self.cursor.sqlerrd[1]] - elif hasattr( self.compiled , 'offset' ): - self.cursor.offset( self.compiled.offset ) - super(InfoExecutionContext, self).post_exec() - - def create_cursor( self ): - return informix_cursor( self.connection.connection ) - -class InfoDialect(default.DefaultDialect): - name = 'informix' - default_paramstyle = 'qmark' - # for informix 7.31 - max_identifier_length = 18 - - def __init__(self, use_ansi=True, **kwargs): - self.use_ansi = use_ansi - default.DefaultDialect.__init__(self, **kwargs) - - def dbapi(cls): - import informixdb - return informixdb - dbapi = classmethod(dbapi) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - else: - return False - - def do_begin(self , connect ): - cu = connect.cursor() - cu.execute( 'SET LOCK MODE TO WAIT' ) - #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def create_connect_args(self, url): - if url.host: - dsn = '%s@%s' % ( url.database , url.host ) - else: - dsn = url.database - - if url.username: - opt = { 'user':url.username , 'password': url.password } - else: - opt = {} - - return ([dsn], opt) - - def table_names(self, connection, schema): - s = "select tabname from systables" - return [row[0] for row in connection.execute(s)] - - def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() ) - return bool( cursor.fetchone() is not None ) - - def reflecttable(self, connection, table, include_columns): - c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() ) - rows = c.fetchall() - if not rows : - raise exc.NoSuchTableError(table.name) - else: - if table.owner is not None: - if table.owner.lower() in [r[0] for r in rows]: - owner = table.owner.lower() - else: - raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name)) - else: - if len(rows)==1: - owner = rows[0][0] - else: - raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name) - - c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3 - where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? - and t3.tabid = t2.tabid and t3.colno = t1.colno - order by t1.colno""", table.name.lower(), owner ) - rows = c.fetchall() - - if not rows: - raise exc.NoSuchTableError(table.name) - - for name , colattr , collength , default , colno in rows: - name = name.lower() - if include_columns and name not in include_columns: - continue - - # in 7.31, coltype = 0x000 - # ^^-- column type - # ^-- 1 not null , 0 null - nullable , coltype = divmod( colattr , 256 ) - if coltype not in ( 0 , 13 ) and default: - default = default.split()[-1] - - if coltype == 0 or coltype == 13: # char , varchar - coltype = ischema_names.get(coltype, InfoString)(collength) - if default: - default = "'%s'" % default - elif coltype == 5: # decimal - precision , scale = ( collength & 0xFF00 ) >> 8 , collength & 0xFF - if scale == 255: - scale = 0 - coltype = InfoNumeric(precision, scale) - else: - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, name)) - coltype = sqltypes.NULLTYPE - - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs)) - - # FK - c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , - t4.colname as local_column , t7.tabname as remote_table , - t6.colname as remote_column - from sysconstraints as t1 , systables as t2 , - sysindexes as t3 , syscolumns as t4 , - sysreferences as t5 , syscolumns as t6 , systables as t7 , - sysconstraints as t8 , sysindexes as t9 - where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R' - and t3.tabid = t2.tabid and t3.idxname = t1.idxname - and t4.tabid = t2.tabid and t4.colno = t3.part1 - and t5.constrid = t1.constrid and t8.constrid = t5.primary - and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname - and t7.tabid = t5.ptabid""", table.name.lower(), owner ) - rows = c.fetchall() - fks = {} - for cons_name, cons_type, local_column, remote_table, remote_column in rows: - try: - fk = fks[cons_name] - except KeyError: - fk = ([], []) - fks[cons_name] = fk - refspec = ".".join([remote_table, remote_column]) - schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection) - if local_column not in fk[0]: - fk[0].append(local_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1] , None, link_to_name=True )) - - # PK - c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , - t4.colname as local_column - from sysconstraints as t1 , systables as t2 , - sysindexes as t3 , syscolumns as t4 - where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'P' - and t3.tabid = t2.tabid and t3.idxname = t1.idxname - and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower(), owner ) - rows = c.fetchall() - for cons_name, cons_type, local_column in rows: - table.primary_key.add( table.c[local_column] ) - -class InfoCompiler(compiler.DefaultCompiler): - """Info compiler modifies the lexical structure of Select statements to work under - non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - - def __init__(self, *args, **kwargs): - self.limit = 0 - self.offset = 0 - - compiler.DefaultCompiler.__init__( self , *args, **kwargs ) - - def default_from(self): - return " from systables where tabname = 'systables' " - - def get_select_precolumns( self , select ): - s = select._distinct and "DISTINCT " or "" - # only has limit - if select._limit: - off = select._offset or 0 - s += " FIRST %s " % ( select._limit + off ) - else: - s += "" - return s - - def visit_select(self, select): - if select._offset: - self.offset = select._offset - self.limit = select._limit or 0 - # the column in order by clause must in select too - - def __label( c ): - try: - return c._label.lower() - except: - return '' - - # TODO: dont modify the original select, generate a new one - a = [ __label(c) for c in select._raw_columns ] - for c in select._order_by_clause.clauses: - if ( __label(c) not in a ): - select.append_column( c ) - - return compiler.DefaultCompiler.visit_select(self, select) - - def limit_clause(self, select): - return "" - - def visit_function( self , func ): - if func.name.lower() == 'current_date': - return "today" - elif func.name.lower() == 'current_time': - return "CURRENT HOUR TO SECOND" - elif func.name.lower() in ( 'current_timestamp' , 'now' ): - return "CURRENT YEAR TO SECOND" - else: - return compiler.DefaultCompiler.visit_function( self , func ) - - def visit_clauselist(self, list, **kwargs): - return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) - -class InfoSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk: - colspec += " SERIAL" - self.has_serial = True - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - - return colspec - - def post_create_table(self, table): - if hasattr( self , 'has_serial' ): - del self.has_serial - return '' - - def visit_primary_key_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint) - constraint.name = name - - def visit_unique_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_unique_constraint(constraint) - constraint.name = name - - def visit_foreign_key_constraint( self , constraint ): - if constraint.name is not None: - constraint.use_alter = True - else: - super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint ) - - def define_foreign_key(self, constraint): - # for informix 7.31 not support constraint name - if constraint.use_alter: - name = constraint.name - constraint.name = None - self.append( "CONSTRAINT " ) - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - constraint.name = name - if name is not None: - self.append( " CONSTRAINT " + name ) - else: - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - - def visit_index(self, index): - if len( index.columns ) == 1 and index.columns[0].foreign_key: - return - super(InfoSchemaGenerator, self).visit_index(index) - -class InfoIdentifierPreparer(compiler.IdentifierPreparer): - def __init__(self, dialect): - super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") - - def _requires_quotes(self, value): - return False - -class InfoSchemaDropper(compiler.SchemaDropper): - def drop_foreignkey(self, constraint): - if constraint.name is not None: - super( InfoSchemaDropper , self ).drop_foreignkey( constraint ) - -dialect = InfoDialect -poolclass = pool.SingletonThreadPool -dialect.statement_compiler = InfoCompiler -dialect.schemagenerator = InfoSchemaGenerator -dialect.schemadropper = InfoSchemaDropper -dialect.preparer = InfoIdentifierPreparer -dialect.execution_ctx_cls = InfoExecutionContext
\ No newline at end of file diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py deleted file mode 100644 index 693295054..000000000 --- a/lib/sqlalchemy/databases/maxdb.py +++ /dev/null @@ -1,1099 +0,0 @@ -# maxdb.py -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the MaxDB database. - -TODO: More module docs! MaxDB support is currently experimental. - -Overview --------- - -The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 -and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. -The earlier version has severe ``LEFT JOIN`` limitations and will return -incorrect results from even very simple ORM queries. - -Only the native Python DB-API is currently supported. ODBC driver support -is a future enhancement. - -Connecting ----------- - -The username is case-sensitive. If you usually connect to the -database with sqlcli and other tools in lower case, you likely need to -use upper case for DB-API. - -Implementation Notes --------------------- - -Also check the DatabaseNotes page on the wiki for detailed information. - -With the 7.6.00.37 driver and Python 2.5, it seems that all DB-API -generated exceptions are broken and can cause Python to crash. - -For 'somecol.in_([])' to work, the IN operator's generation must be changed -to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a -bind parameter there, so that particular generation must inline the NULL value, -which depends on [ticket:807]. - -The DB-API is very picky about where bind params may be used in queries. - -Bind params for some functions (e.g. MOD) need type information supplied. -The dialect does not yet do this automatically. - -Max will occasionally throw up 'bad sql, compile again' exceptions for -perfectly valid SQL. The dialect does not currently handle these, more -research is needed. - -MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very -slightly different version of this dialect would be required to support -those versions, and can easily be added if there is demand. Some other -required components such as an Max-aware 'old oracle style' join compiler -(thetas with (+) outer indicators) are already done and available for -integration- email the devel list if you're interested in working on -this. - -""" -import datetime, itertools, re - -from sqlalchemy import exc, schema, sql, util -from sqlalchemy.sql import operators as sql_operators, expression as sql_expr -from sqlalchemy.sql import compiler, visitors -from sqlalchemy.engine import base as engine_base, default -from sqlalchemy import types as sqltypes - - -__all__ = [ - 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger', - 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp', - 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob', - ] - - -class _StringType(sqltypes.String): - _type = None - - def __init__(self, length=None, encoding=None, **kw): - super(_StringType, self).__init__(length=length, **kw) - self.encoding = encoding - - def get_col_spec(self): - if self.length is None: - spec = 'LONG' - else: - spec = '%s(%s)' % (self._type, self.length) - - if self.encoding is not None: - spec = ' '.join([spec, self.encoding.upper()]) - return spec - - def bind_processor(self, dialect): - if self.encoding == 'unicode': - return None - else: - def process(value): - if isinstance(value, unicode): - return value.encode(dialect.encoding) - else: - return value - return process - - def result_processor(self, dialect): - def process(value): - while True: - if value is None: - return None - elif isinstance(value, unicode): - return value - elif isinstance(value, str): - if self.convert_unicode or dialect.convert_unicode: - return value.decode(dialect.encoding) - else: - return value - elif hasattr(value, 'read'): - # some sort of LONG, snarf and retry - value = value.read(value.remainingLength()) - continue - else: - # unexpected type, return as-is - return value - return process - - -class MaxString(_StringType): - _type = 'VARCHAR' - - def __init__(self, *a, **kw): - super(MaxString, self).__init__(*a, **kw) - - -class MaxUnicode(_StringType): - _type = 'VARCHAR' - - def __init__(self, length=None, **kw): - super(MaxUnicode, self).__init__(length=length, encoding='unicode') - - -class MaxChar(_StringType): - _type = 'CHAR' - - -class MaxText(_StringType): - _type = 'LONG' - - def __init__(self, *a, **kw): - super(MaxText, self).__init__(*a, **kw) - - def get_col_spec(self): - spec = 'LONG' - if self.encoding is not None: - spec = ' '.join((spec, self.encoding)) - elif self.convert_unicode: - spec = ' '.join((spec, 'UNICODE')) - - return spec - - -class MaxInteger(sqltypes.Integer): - def get_col_spec(self): - return 'INTEGER' - - -class MaxSmallInteger(MaxInteger): - def get_col_spec(self): - return 'SMALLINT' - - -class MaxNumeric(sqltypes.Numeric): - """The FIXED (also NUMERIC, DECIMAL) data type.""" - - def __init__(self, precision=None, scale=None, **kw): - kw.setdefault('asdecimal', True) - super(MaxNumeric, self).__init__(scale=scale, precision=precision, - **kw) - - def bind_processor(self, dialect): - return None - - def get_col_spec(self): - if self.scale and self.precision: - return 'FIXED(%s, %s)' % (self.precision, self.scale) - elif self.precision: - return 'FIXED(%s)' % self.precision - else: - return 'INTEGER' - - -class MaxFloat(sqltypes.Float): - """The FLOAT data type.""" - - def get_col_spec(self): - if self.precision is None: - return 'FLOAT' - else: - return 'FLOAT(%s)' % (self.precision,) - - -class MaxTimestamp(sqltypes.DateTime): - def get_col_spec(self): - return 'TIMESTAMP' - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - elif isinstance(value, basestring): - return value - elif dialect.datetimeformat == 'internal': - ms = getattr(value, 'microsecond', 0) - return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) - elif dialect.datetimeformat == 'iso': - ms = getattr(value, 'microsecond', 0) - return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - return datetime.datetime( - *[int(v) - for v in (value[0:4], value[4:6], value[6:8], - value[8:10], value[10:12], value[12:14], - value[14:])]) - elif dialect.datetimeformat == 'iso': - return datetime.datetime( - *[int(v) - for v in (value[0:4], value[5:7], value[8:10], - value[11:13], value[14:16], value[17:19], - value[20:])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - -class MaxDate(sqltypes.Date): - def get_col_spec(self): - return 'DATE' - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - elif isinstance(value, basestring): - return value - elif dialect.datetimeformat == 'internal': - return value.strftime("%Y%m%d") - elif dialect.datetimeformat == 'iso': - return value.strftime("%Y-%m-%d") - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - return datetime.date( - *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) - elif dialect.datetimeformat == 'iso': - return datetime.date( - *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - -class MaxTime(sqltypes.Time): - def get_col_spec(self): - return 'TIME' - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - elif isinstance(value, basestring): - return value - elif dialect.datetimeformat == 'internal': - return value.strftime("%H%M%S") - elif dialect.datetimeformat == 'iso': - return value.strftime("%H-%M-%S") - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - t = datetime.time( - *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) - return t - elif dialect.datetimeformat == 'iso': - return datetime.time( - *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) - return process - - -class MaxBoolean(sqltypes.Boolean): - def get_col_spec(self): - return 'BOOLEAN' - - -class MaxBlob(sqltypes.Binary): - def get_col_spec(self): - return 'LONG BYTE' - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return str(value) - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return value.read(value.remainingLength()) - return process - - -colspecs = { - sqltypes.Integer: MaxInteger, - sqltypes.Smallinteger: MaxSmallInteger, - sqltypes.Numeric: MaxNumeric, - sqltypes.Float: MaxFloat, - sqltypes.DateTime: MaxTimestamp, - sqltypes.Date: MaxDate, - sqltypes.Time: MaxTime, - sqltypes.String: MaxString, - sqltypes.Binary: MaxBlob, - sqltypes.Boolean: MaxBoolean, - sqltypes.Text: MaxText, - sqltypes.CHAR: MaxChar, - sqltypes.TIMESTAMP: MaxTimestamp, - sqltypes.BLOB: MaxBlob, - sqltypes.Unicode: MaxUnicode, - } - -ischema_names = { - 'boolean': MaxBoolean, - 'char': MaxChar, - 'character': MaxChar, - 'date': MaxDate, - 'fixed': MaxNumeric, - 'float': MaxFloat, - 'int': MaxInteger, - 'integer': MaxInteger, - 'long binary': MaxBlob, - 'long unicode': MaxText, - 'long': MaxText, - 'long': MaxText, - 'smallint': MaxSmallInteger, - 'time': MaxTime, - 'timestamp': MaxTimestamp, - 'varchar': MaxString, - } - - -class MaxDBExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - # DB-API bug: if there were any functions as values, - # then do another select and pull CURRVAL from the - # autoincrement column's implicit sequence... ugh - if self.compiled.isinsert and not self.executemany: - table = self.compiled.statement.table - index, serial_col = _autoserial_column(table) - - if serial_col and (not self.compiled._safeserial or - not(self._last_inserted_ids) or - self._last_inserted_ids[index] in (None, 0)): - if table.schema: - sql = "SELECT %s.CURRVAL FROM DUAL" % ( - self.compiled.preparer.format_table(table)) - else: - sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( - self.compiled.preparer.format_table(table)) - - if self.connection.engine._should_log_info: - self.connection.engine.logger.info(sql) - rs = self.cursor.execute(sql) - id = rs.fetchone()[0] - - if self.connection.engine._should_log_debug: - self.connection.engine.logger.debug([id]) - if not self._last_inserted_ids: - # This shouldn't ever be > 1? Right? - self._last_inserted_ids = \ - [None] * len(table.primary_key.columns) - self._last_inserted_ids[index] = id - - super(MaxDBExecutionContext, self).post_exec() - - def get_result_proxy(self): - if self.cursor.description is not None: - for column in self.cursor.description: - if column[1] in ('Long Binary', 'Long', 'Long Unicode'): - return MaxDBResultProxy(self) - return engine_base.ResultProxy(self) - - -class MaxDBCachedColumnRow(engine_base.RowProxy): - """A RowProxy that only runs result_processors once per column.""" - - def __init__(self, parent, row): - super(MaxDBCachedColumnRow, self).__init__(parent, row) - self.columns = {} - self._row = row - self._parent = parent - - def _get_col(self, key): - if key not in self.columns: - self.columns[key] = self._parent._get_col(self._row, key) - return self.columns[key] - - def __iter__(self): - for i in xrange(len(self._row)): - yield self._get_col(i) - - def __repr__(self): - return repr(list(self)) - - def __eq__(self, other): - return ((other is self) or - (other == tuple([self._get_col(key) - for key in xrange(len(self._row))]))) - def __getitem__(self, key): - if isinstance(key, slice): - indices = key.indices(len(self._row)) - return tuple([self._get_col(i) for i in xrange(*indices)]) - else: - return self._get_col(key) - - def __getattr__(self, name): - try: - return self._get_col(name) - except KeyError: - raise AttributeError(name) - - -class MaxDBResultProxy(engine_base.ResultProxy): - _process_row = MaxDBCachedColumnRow - - -class MaxDBDialect(default.DefaultDialect): - name = 'maxdb' - supports_alter = True - supports_unicode_statements = True - max_identifier_length = 32 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - - # MaxDB-specific - datetimeformat = 'internal' - - def __init__(self, _raise_known_sql_errors=False, **kw): - super(MaxDBDialect, self).__init__(**kw) - self._raise_known = _raise_known_sql_errors - - if self.dbapi is None: - self.dbapi_type_map = {} - else: - self.dbapi_type_map = { - 'Long Binary': MaxBlob(), - 'Long byte_t': MaxBlob(), - 'Long Unicode': MaxText(), - 'Timestamp': MaxTimestamp(), - 'Date': MaxDate(), - 'Time': MaxTime(), - datetime.datetime: MaxTimestamp(), - datetime.date: MaxDate(), - datetime.time: MaxTime(), - } - - def dbapi(cls): - from sapdb import dbapi as _dbapi - return _dbapi - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - return [], opts - - def type_descriptor(self, typeobj): - if isinstance(typeobj, type): - typeobj = typeobj() - if isinstance(typeobj, sqltypes.Unicode): - return typeobj.adapt(MaxUnicode) - else: - return sqltypes.adapt_type(typeobj, colspecs) - - def do_execute(self, cursor, statement, parameters, context=None): - res = cursor.execute(statement, parameters) - if isinstance(res, int) and context is not None: - context._rowcount = res - - def do_release_savepoint(self, connection, name): - # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts - # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS - # BEGIN SQLSTATE: I7065" - # Note that ROLLBACK TO works fine. In theory, a RELEASE should - # just free up some transactional resources early, before the overall - # COMMIT/ROLLBACK so omitting it should be relatively ok. - pass - - def get_default_schema_name(self, connection): - try: - return self._default_schema_name - except AttributeError: - name = self.identifier_preparer._normalize_name( - connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) - self._default_schema_name = name - return name - - def has_table(self, connection, table_name, schema=None): - denormalize = self.identifier_preparer._denormalize_name - bind = [denormalize(table_name)] - if schema is None: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME=? AND" - " TABLES.SCHEMANAME=CURRENT_SCHEMA ") - else: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME = ? AND" - " TABLES.SCHEMANAME=? ") - bind.append(denormalize(schema)) - - rp = connection.execute(sql, bind) - found = bool(rp.fetchone()) - rp.close() - return found - - def table_names(self, connection, schema): - if schema is None: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=CURRENT_SCHEMA ") - rs = connection.execute(sql) - else: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=? ") - matchname = self.identifier_preparer._denormalize_name(schema) - rs = connection.execute(sql, matchname) - normalize = self.identifier_preparer._normalize_name - return [normalize(row[0]) for row in rs] - - def reflecttable(self, connection, table, include_columns): - denormalize = self.identifier_preparer._denormalize_name - normalize = self.identifier_preparer._normalize_name - - st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' - ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' - 'FROM COLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY POS') - - fk = ('SELECT COLUMNNAME, FKEYNAME, ' - ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' - ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' - ' THEN 1 ELSE 0 END) AS in_schema ' - 'FROM FOREIGNKEYCOLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY FKEYNAME ') - - params = [denormalize(table.name)] - if not table.schema: - st = st % 'CURRENT_SCHEMA' - fk = fk % 'CURRENT_SCHEMA' - else: - st = st % '?' - fk = fk % '?' - params.append(denormalize(table.schema)) - - rows = connection.execute(st, params).fetchall() - if not rows: - raise exc.NoSuchTableError(table.fullname) - - include_columns = set(include_columns or []) - - for row in rows: - (name, mode, col_type, encoding, length, scale, - nullable, constant_def, func_def) = row - - name = normalize(name) - - if include_columns and name not in include_columns: - continue - - type_args, type_kw = [], {} - if col_type == 'FIXED': - type_args = length, scale - # Convert FIXED(10) DEFAULT SERIAL to our Integer - if (scale == 0 and - func_def is not None and func_def.startswith('SERIAL')): - col_type = 'INTEGER' - type_args = length, - elif col_type in 'FLOAT': - type_args = length, - elif col_type in ('CHAR', 'VARCHAR'): - type_args = length, - type_kw['encoding'] = encoding - elif col_type == 'LONG': - type_kw['encoding'] = encoding - - try: - type_cls = ischema_names[col_type.lower()] - type_instance = type_cls(*type_args, **type_kw) - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (col_type, name)) - type_instance = sqltypes.NullType - - col_kw = {'autoincrement': False} - col_kw['nullable'] = (nullable == 'YES') - col_kw['primary_key'] = (mode == 'KEY') - - if func_def is not None: - if func_def.startswith('SERIAL'): - if col_kw['primary_key']: - # No special default- let the standard autoincrement - # support handle SERIAL pk columns. - col_kw['autoincrement'] = True - else: - # strip current numbering - col_kw['server_default'] = schema.DefaultClause( - sql.text('SERIAL')) - col_kw['autoincrement'] = True - else: - col_kw['server_default'] = schema.DefaultClause( - sql.text(func_def)) - elif constant_def is not None: - col_kw['server_default'] = schema.DefaultClause(sql.text( - "'%s'" % constant_def.replace("'", "''"))) - - table.append_column(schema.Column(name, type_instance, **col_kw)) - - fk_sets = itertools.groupby(connection.execute(fk, params), - lambda row: row.FKEYNAME) - for fkeyname, fkey in fk_sets: - fkey = list(fkey) - if include_columns: - key_cols = set([r.COLUMNNAME for r in fkey]) - if key_cols != include_columns: - continue - - columns, referants = [], [] - quote = self.identifier_preparer._maybe_quote_identifier - - for row in fkey: - columns.append(normalize(row.COLUMNNAME)) - if table.schema or not row.in_schema: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFSCHEMANAME', 'REFTABLENAME', - 'REFCOLUMNNAME')])) - else: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) - - constraint_kw = {'name': fkeyname.lower()} - if fkey[0].RULE is not None: - rule = fkey[0].RULE - if rule.startswith('DELETE '): - rule = rule[7:] - constraint_kw['ondelete'] = rule - - table_kw = {} - if table.schema or not row.in_schema: - table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) - - ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), - table_kw.get('schema')) - if ref_key not in table.metadata.tables: - schema.Table(normalize(fkey[0].REFTABLENAME), - table.metadata, - autoload=True, autoload_with=connection, - **table_kw) - - constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, - **constraint_kw) - table.append_constraint(constraint) - - def has_sequence(self, connection, name): - # [ticket:726] makes this schema-aware. - denormalize = self.identifier_preparer._denormalize_name - sql = ("SELECT sequence_name FROM SEQUENCES " - "WHERE SEQUENCE_NAME=? ") - - rp = connection.execute(sql, denormalize(name)) - found = bool(rp.fetchone()) - rp.close() - return found - - -class MaxDBCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) - - function_conversion = { - 'CURRENT_DATE': 'DATE', - 'CURRENT_TIME': 'TIME', - 'CURRENT_TIMESTAMP': 'TIMESTAMP', - } - - # These functions must be written without parens when called with no - # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' - bare_functions = set([ - 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', - 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', - 'UTCDATE', 'UTCDIFF']) - - def default_from(self): - return ' FROM DUAL' - - def for_update_clause(self, select): - clause = select.for_update - if clause is True: - return " WITH LOCK EXCLUSIVE" - elif clause is None: - return "" - elif clause == "read": - return " WITH LOCK" - elif clause == "ignore": - return " WITH LOCK (IGNORE) EXCLUSIVE" - elif clause == "nowait": - return " WITH LOCK (NOWAIT) EXCLUSIVE" - elif isinstance(clause, basestring): - return " WITH LOCK %s" % clause.upper() - elif not clause: - return "" - else: - return " WITH LOCK EXCLUSIVE" - - def apply_function_parens(self, func): - if func.name.upper() in self.bare_functions: - return len(func.clauses) > 0 - else: - return True - - def visit_function(self, fn, **kw): - transform = self.function_conversion.get(fn.name.upper(), None) - if transform: - fn = fn._clone() - fn.name = transform - return super(MaxDBCompiler, self).visit_function(fn, **kw) - - def visit_cast(self, cast, **kwargs): - # MaxDB only supports casts * to NUMERIC, * to VARCHAR or - # date/time to VARCHAR. Casts of LONGs will fail. - if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): - return "NUM(%s)" % self.process(cast.clause) - elif isinstance(cast.type, sqltypes.String): - return "CHR(%s)" % self.process(cast.clause) - else: - return self.process(cast.clause) - - def visit_sequence(self, sequence): - if sequence.optional: - return None - else: - return (self.dialect.identifier_preparer.format_sequence(sequence) + - ".NEXTVAL") - - class ColumnSnagger(visitors.ClauseVisitor): - def __init__(self): - self.count = 0 - self.column = None - def visit_column(self, column): - self.column = column - self.count += 1 - - def _find_labeled_columns(self, columns, use_labels=False): - labels = {} - for column in columns: - if isinstance(column, basestring): - continue - snagger = self.ColumnSnagger() - snagger.traverse(column) - if snagger.count == 1: - if isinstance(column, sql_expr._Label): - labels[unicode(snagger.column)] = column.name - elif use_labels: - labels[unicode(snagger.column)] = column._label - - return labels - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # ORDER BY clauses in DISTINCT queries must reference aliased - # inner columns by alias name, not true column name. - if order_by and getattr(select, '_distinct', False): - labels = self._find_labeled_columns(select.inner_columns, - select.use_labels) - if labels: - for needs_alias in labels.keys(): - r = re.compile(r'(^| )(%s)(,| |$)' % - re.escape(needs_alias)) - order_by = r.sub((r'\1%s\3' % labels[needs_alias]), - order_by) - - # No ORDER BY in subqueries. - if order_by: - if self.is_subquery(): - # It's safe to simply drop the ORDER BY if there is no - # LIMIT. Right? Other dialects seem to get away with - # dropping order. - if select._limit: - raise exc.InvalidRequestError( - "MaxDB does not support ORDER BY in subqueries") - else: - return "" - return " ORDER BY " + order_by - else: - return "" - - def get_select_precolumns(self, select): - # Convert a subquery's LIMIT to TOP - sql = select._distinct and 'DISTINCT ' or '' - if self.is_subquery() and select._limit: - if select._offset: - raise exc.InvalidRequestError( - 'MaxDB does not support LIMIT with an offset.') - sql += 'TOP %s ' % select._limit - return sql - - def limit_clause(self, select): - # The docs say offsets are supported with LIMIT. But they're not. - # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? - if self.is_subquery(): - # sub queries need TOP - return '' - elif select._offset: - raise exc.InvalidRequestError( - 'MaxDB does not support LIMIT with an offset.') - else: - return ' \n LIMIT %s' % (select._limit,) - - def visit_insert(self, insert): - self.isinsert = True - self._safeserial = True - - colparams = self._get_colparams(insert) - for value in (insert.parameters or {}).itervalues(): - if isinstance(value, sql_expr.Function): - self._safeserial = False - break - - return ''.join(('INSERT INTO ', - self.preparer.format_table(insert.table), - ' (', - ', '.join([self.preparer.format_column(c[0]) - for c in colparams]), - ') VALUES (', - ', '.join([c[1] for c in colparams]), - ')')) - - -class MaxDBDefaultRunner(engine_base.DefaultRunner): - def visit_sequence(self, seq): - if seq.optional: - return None - return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % ( - self.dialect.identifier_preparer.format_sequence(seq))) - - -class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', - 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', - 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', - 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', - 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', - 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', - 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', - 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', - 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', - 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', - 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', - 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', - 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', - 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', - 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', - 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', - 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', - 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', - 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', - 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', - 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', - 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', - 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', - 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', - 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', - 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', - 'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper', - 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', - 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', - 'where', 'with', 'year', 'zoned' ]) - - def _normalize_name(self, name): - if name is None: - return None - if name.isupper(): - lc_name = name.lower() - if not self._requires_quotes(lc_name): - return lc_name - return name - - def _denormalize_name(self, name): - if name is None: - return None - elif (name.islower() and - not self._requires_quotes(name)): - return name.upper() - else: - return name - - def _maybe_quote_identifier(self, name): - if self._requires_quotes(name): - return self.quote_identifier(name) - else: - return name - - -class MaxDBSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kw): - colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] - - if not column.nullable: - colspec.append('NOT NULL') - - default = column.default - default_str = self.get_column_default_string(column) - - # No DDL default for columns specified with non-optional sequence- - # this defaulting behavior is entirely client-side. (And as a - # consequence, non-reflectable.) - if (default and isinstance(default, schema.Sequence) and - not default.optional): - pass - # Regular default - elif default_str is not None: - colspec.append('DEFAULT %s' % default_str) - # Assign DEFAULT SERIAL heuristically - elif column.primary_key and column.autoincrement: - # For SERIAL on a non-primary key member, use - # DefaultClause(text('SERIAL')) - try: - first = [c for c in column.table.primary_key.columns - if (c.autoincrement and - (isinstance(c.type, sqltypes.Integer) or - (isinstance(c.type, MaxNumeric) and - c.type.precision)) and - not c.foreign_keys)].pop(0) - if column is first: - colspec.append('DEFAULT SERIAL') - except IndexError: - pass - return ' '.join(colspec) - - def get_column_default_string(self, column): - if isinstance(column.server_default, schema.DefaultClause): - if isinstance(column.default.arg, basestring): - if isinstance(column.type, sqltypes.Integer): - return str(column.default.arg) - else: - return "'%s'" % column.default.arg - else: - return unicode(self._compile(column.default.arg, None)) - else: - return None - - def visit_sequence(self, sequence): - """Creates a SEQUENCE. - - TODO: move to module doc? - - start - With an integer value, set the START WITH option. - - increment - An integer value to increment by. Default is the database default. - - maxdb_minvalue - maxdb_maxvalue - With an integer value, sets the corresponding sequence option. - - maxdb_no_minvalue - maxdb_no_maxvalue - Defaults to False. If true, sets the corresponding sequence option. - - maxdb_cycle - Defaults to False. If true, sets the CYCLE option. - - maxdb_cache - With an integer value, sets the CACHE option. - - maxdb_no_cache - Defaults to False. If true, sets NOCACHE. - """ - - if (not sequence.optional and - (not self.checkfirst or - not self.dialect.has_sequence(self.connection, sequence.name))): - - ddl = ['CREATE SEQUENCE', - self.preparer.format_sequence(sequence)] - - sequence.increment = 1 - - if sequence.increment is not None: - ddl.extend(('INCREMENT BY', str(sequence.increment))) - - if sequence.start is not None: - ddl.extend(('START WITH', str(sequence.start))) - - opts = dict([(pair[0][6:].lower(), pair[1]) - for pair in sequence.kwargs.items() - if pair[0].startswith('maxdb_')]) - - if 'maxvalue' in opts: - ddl.extend(('MAXVALUE', str(opts['maxvalue']))) - elif opts.get('no_maxvalue', False): - ddl.append('NOMAXVALUE') - if 'minvalue' in opts: - ddl.extend(('MINVALUE', str(opts['minvalue']))) - elif opts.get('no_minvalue', False): - ddl.append('NOMINVALUE') - - if opts.get('cycle', False): - ddl.append('CYCLE') - - if 'cache' in opts: - ddl.extend(('CACHE', str(opts['cache']))) - elif opts.get('no_cache', False): - ddl.append('NOCACHE') - - self.append(' '.join(ddl)) - self.execute() - - -class MaxDBSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if (not sequence.optional and - (not self.checkfirst or - self.dialect.has_sequence(self.connection, sequence.name))): - self.append("DROP SEQUENCE %s" % - self.preparer.format_sequence(sequence)) - self.execute() - - -def _autoserial_column(table): - """Finds the effective DEFAULT SERIAL column of a Table, if any.""" - - for index, col in enumerate(table.primary_key.columns): - if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and - col.autoincrement): - if isinstance(col.default, schema.Sequence): - if col.default.optional: - return index, col - elif (col.default is None or - (not isinstance(col.server_default, schema.DefaultClause))): - return index, col - - return None, None - -dialect = MaxDBDialect -dialect.preparer = MaxDBIdentifierPreparer -dialect.statement_compiler = MaxDBCompiler -dialect.schemagenerator = MaxDBSchemaGenerator -dialect.schemadropper = MaxDBSchemaDropper -dialect.defaultrunner = MaxDBDefaultRunner -dialect.execution_ctx_cls = MaxDBExecutionContext
\ No newline at end of file diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py deleted file mode 100644 index d963b7477..000000000 --- a/lib/sqlalchemy/databases/mssql.py +++ /dev/null @@ -1,1771 +0,0 @@ -# mssql.py - -"""Support for the Microsoft SQL Server database. - -Driver ------- - -The MSSQL dialect will work with three different available drivers: - -* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded - driver. - -* *pymssql* - http://pymssql.sourceforge.net/ - -* *adodbapi* - http://adodbapi.sourceforge.net/ - -Drivers are loaded in the order listed above based on availability. - -If you need to load a specific driver pass ``module_name`` when -creating the engine:: - - engine = create_engine('mssql://dsn', module_name='pymssql') - -``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and -``adodbapi``. - -Currently the pyodbc driver offers the greatest level of -compatibility. - -Connecting ----------- - -Connecting with create_engine() uses the standard URL approach of -``mssql://user:pass@host/dbname[?key=value&key=value...]``. - -If the database name is present, the tokens are converted to a -connection string with the specified values. If the database is not -present, then the host token is taken directly as the DSN name. - -Examples of pyodbc connection string URLs: - -* *mssql://mydsn* - connects using the specified DSN named ``mydsn``. - The connection string that is created will appear like:: - - dsn=mydsn;TrustedConnection=Yes - -* *mssql://user:pass@mydsn* - connects using the DSN named - ``mydsn`` passing in the ``UID`` and ``PWD`` information. The - connection string that is created will appear like:: - - dsn=mydsn;UID=user;PWD=pass - -* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects - using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` - information, plus the additional connection configuration option - ``LANGUAGE``. The connection string that is created will appear - like:: - - dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english - -* *mssql://user:pass@host/db* - connects using a connection string - dynamically created that would appear like:: - - DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass - -* *mssql://user:pass@host:123/db* - connects using a connection - string that is dynamically created, which also includes the port - information using the comma syntax. If your connection string - requires the port information to be passed as a ``port`` keyword - see the next example. This will create the following connection - string:: - - DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass - -* *mssql://user:pass@host/db?port=123* - connects using a connection - string that is dynamically created that includes the port - information as a separate ``port`` keyword. This will create the - following connection string:: - - DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 - -If you require a connection string that is outside the options -presented above, use the ``odbc_connect`` keyword to pass in a -urlencoded connection string. What gets passed in will be urldecoded -and passed directly. - -For example:: - - mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb - -would create the following connection string:: - - dsn=mydsn;Database=db - -Encoding your connection string can be easily accomplished through -the python shell. For example:: - - >>> import urllib - >>> urllib.quote_plus('dsn=mydsn;Database=db') - 'dsn%3Dmydsn%3BDatabase%3Ddb' - -Additional arguments which may be specified either as query string -arguments on the URL, or as keyword argument to -:func:`~sqlalchemy.create_engine()` are: - -* *auto_identity_insert* - enables support for IDENTITY inserts by - automatically turning IDENTITY INSERT ON and OFF as required. - Defaults to ``True``. - -* *query_timeout* - allows you to override the default query timeout. - Defaults to ``None``. This is only supported on pymssql. - -* *text_as_varchar* - if enabled this will treat all TEXT column - types as their equivalent VARCHAR(max) type. This is often used if - you need to compare a VARCHAR to a TEXT field, which is not - supported directly on MSSQL. Defaults to ``False``. - -* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY - should be used in place of the non-scoped version @@IDENTITY. - Defaults to ``False``. On pymssql this defaults to ``True``, and on - pyodbc this defaults to ``True`` if the version of pyodbc being - used supports it. - -* *has_window_funcs* - indicates whether or not window functions - (LIMIT and OFFSET) are supported on the version of MSSQL being - used. If you're running MSSQL 2005 or later turn this on to get - OFFSET support. Defaults to ``False``. - -* *max_identifier_length* - allows you to se the maximum length of - identfiers supported by the database. Defaults to 128. For pymssql - the default is 30. - -* *schema_name* - use to set the schema name. Defaults to ``dbo``. - -Auto Increment Behavior ------------------------ - -``IDENTITY`` columns are supported by using SQLAlchemy -``schema.Sequence()`` objects. In other words:: - - Table('test', mss_engine, - Column('id', Integer, - Sequence('blah',100,10), primary_key=True), - Column('name', String(20)) - ).create() - -would yield:: - - CREATE TABLE test ( - id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, - name VARCHAR(20) NULL, - ) - -Note that the ``start`` and ``increment`` values for sequences are -optional and will default to 1,1. - -* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for - ``INSERT`` s) - -* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on - ``INSERT`` - -Collation Support ------------------ - -MSSQL specific string types support a collation parameter that -creates a column-level specific collation for the column. The -collation parameter accepts a Windows Collation Name or a SQL -Collation Name. Supported types are MSChar, MSNChar, MSString, -MSNVarchar, MSText, and MSNText. For example:: - - Column('login', String(32, collation='Latin1_General_CI_AS')) - -will yield:: - - login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL - -LIMIT/OFFSET Support --------------------- - -MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is -supported directly through the ``TOP`` Transact SQL keyword:: - - select.limit - -will yield:: - - SELECT TOP n - -If the ``has_window_funcs`` flag is set then LIMIT with OFFSET -support is available through the ``ROW_NUMBER OVER`` construct. This -construct requires an ``ORDER BY`` to be specified as well and is -only available on MSSQL 2005 and later. - -Nullability ------------ -MSSQL has support for three levels of column nullability. The default -nullability allows nulls and is explicit in the CREATE TABLE -construct:: - - name VARCHAR(20) NULL - -If ``nullable=None`` is specified then no specification is made. In -other words the database's configured default is used. This will -render:: - - name VARCHAR(20) - -If ``nullable`` is ``True`` or ``False`` then the column will be -``NULL` or ``NOT NULL`` respectively. - -Date / Time Handling --------------------- -For MSSQL versions that support the ``DATE`` and ``TIME`` types -(MSSQL 2008+) the data type is used. For versions that do not -support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used -instead and the MSSQL dialect handles converting the results -properly. This means ``Date()`` and ``Time()`` are fully supported -on all versions of MSSQL. If you do not desire this behavior then -do not use the ``Date()`` or ``Time()`` types. - -Compatibility Levels --------------------- -MSSQL supports the notion of setting compatibility levels at the -database level. This allows, for instance, to run a database that -is compatibile with SQL2000 while running on a SQL2005 database -server. ``server_version_info`` will always retrun the database -server version information (in this case SQL2005) and not the -compatibiility level information. Because of this, if running under -a backwards compatibility mode SQAlchemy may attempt to use T-SQL -statements that are unable to be parsed by the database server. - -Known Issues ------------- - -* No support for more than one ``IDENTITY`` column per table - -* pymssql has problems with binary and unicode data that this module - does **not** work around - -""" -import datetime, decimal, inspect, operator, re, sys, urllib - -from sqlalchemy import sql, schema, exc, util -from sqlalchemy import Table, MetaData, Column, ForeignKey, String, Integer -from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions -from sqlalchemy.engine import default, base -from sqlalchemy import types as sqltypes -from decimal import Decimal as _python_Decimal - - -RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) - - -class _StringType(object): - """Base for MSSQL string types.""" - - def __init__(self, collation=None, **kwargs): - self.collation = kwargs.get('collate', collation) - - def _extend(self, spec): - """Extend a string-type declaration with standard SQL - COLLATE annotations. - """ - - if self.collation: - collation = 'COLLATE %s' % self.collation - else: - collation = None - - return ' '.join([c for c in (spec, collation) - if c is not None]) - - def __repr__(self): - attributes = inspect.getargspec(self.__init__)[0][1:] - attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - - params = {} - for attr in attributes: - val = getattr(self, attr) - if val is not None and val is not False: - params[attr] = val - - return "%s(%s)" % (self.__class__.__name__, - ', '.join(['%s=%r' % (k, params[k]) for k in params])) - - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, sqltypes.NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - -class MSNumeric(sqltypes.Numeric): - def result_processor(self, dialect): - if self.asdecimal: - def process(value): - if value is not None: - return _python_Decimal(str(value)) - else: - return value - return process - else: - def process(value): - return float(value) - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - # Not sure that this exception is needed - return value - - elif isinstance(value, decimal.Decimal): - if value.adjusted() < 0: - result = "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value._int])) - - else: - if 'E' in str(value): - result = "%s%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int]), - "0" * (value.adjusted() - (len(value._int)-1))) - else: - if (len(value._int) - 1) > value.adjusted(): - result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1]), - "".join([str(s) for s in value._int][value.adjusted() + 1:])) - else: - result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1])) - - return result - - else: - return value - - return process - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - -class MSFloat(sqltypes.Float): - def get_col_spec(self): - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class MSReal(MSFloat): - """A type for ``real`` numbers.""" - - def __init__(self): - """ - Construct a Real. - - """ - super(MSReal, self).__init__(precision=24) - - def adapt(self, impltype): - return impltype() - - def get_col_spec(self): - return "REAL" - - -class MSInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - - -class MSBigInteger(MSInteger): - def get_col_spec(self): - return "BIGINT" - - -class MSTinyInteger(MSInteger): - def get_col_spec(self): - return "TINYINT" - - -class MSSmallInteger(MSInteger): - def get_col_spec(self): - return "SMALLINT" - - -class _DateTimeType(object): - """Base for MSSQL datetime types.""" - - def bind_processor(self, dialect): - # if we receive just a date we can manipulate it - # into a datetime since the db-api may not do this. - def process(value): - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - return process - - -class MSDateTime(_DateTimeType, sqltypes.DateTime): - def get_col_spec(self): - return "DATETIME" - - -class MSDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - - -class MSTime(sqltypes.Time): - def __init__(self, precision=None, **kwargs): - self.precision = precision - super(MSTime, self).__init__() - - def get_col_spec(self): - if self.precision: - return "TIME(%s)" % self.precision - else: - return "TIME" - - -class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine): - def get_col_spec(self): - return "SMALLDATETIME" - - -class MSDateTime2(_DateTimeType, sqltypes.TypeEngine): - def __init__(self, precision=None, **kwargs): - self.precision = precision - - def get_col_spec(self): - if self.precision: - return "DATETIME2(%s)" % self.precision - else: - return "DATETIME2" - - -class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine): - def __init__(self, precision=None, **kwargs): - self.precision = precision - - def get_col_spec(self): - if self.precision: - return "DATETIMEOFFSET(%s)" % self.precision - else: - return "DATETIMEOFFSET" - - -class MSDateTimeAsDate(_DateTimeType, MSDate): - """ This is an implementation of the Date type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the date portion. - - """ - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - # If the DBAPI returns the value as datetime.datetime(), truncate - # it back to datetime.date() - if type(value) is datetime.datetime: - return value.date() - return value - return process - - -class MSDateTimeAsTime(MSTime): - """ This is an implementation of the Time type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the time portion. - - """ - - __zero_date = datetime.date(1900, 1, 1) - - def get_col_spec(self): - return "DATETIME" - - def bind_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif type(value) is datetime.time: - value = datetime.datetime.combine(self.__zero_date, value) - return value - return process - - def result_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - return value.time() - elif type(value) is datetime.date: - return datetime.time(0, 0, 0) - return value - return process - - -class MSDateTime_adodbapi(MSDateTime): - def result_processor(self, dialect): - def process(value): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - return process - - -class MSText(_StringType, sqltypes.Text): - """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" - - def __init__(self, *args, **kwargs): - """Construct a TEXT. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Text.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', False), - assert_unicode=kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("VARCHAR(max)") - else: - return self._extend("TEXT") - - -class MSNText(_StringType, sqltypes.UnicodeText): - """MSSQL NTEXT type, for variable-length unicode text up to 2^30 - characters.""" - - def __init__(self, *args, **kwargs): - """Construct a NTEXT. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.UnicodeText.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("NVARCHAR(max)") - else: - return self._extend("NTEXT") - - -class MSString(_StringType, sqltypes.String): - """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum - of 8,000 characters.""" - - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): - """Construct a VARCHAR. - - :param length: Optinal, maximum data length, in characters. - - :param convert_unicode: defaults to False. If True, convert - ``unicode`` data sent to the database to a ``str`` - bytestring, and convert bytestrings coming back from the - database into ``unicode``. - - Bytestrings are encoded using the dialect's - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which - defaults to `utf-8`. - - If False, may be overridden by - :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. - - :param assert_unicode: - - If None (the default), no assertion will take place unless - overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. - - If 'warn', will issue a runtime warning if a ``str`` - instance is used as a bind value. - - If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%s)" % self.length) - else: - return self._extend("VARCHAR") - - -class MSNVarchar(_StringType, sqltypes.Unicode): - """MSSQL NVARCHAR type. - - For variable-length unicode character data up to 4,000 characters.""" - - def __init__(self, length=None, **kwargs): - """Construct a NVARCHAR. - - :param length: Optional, Maximum data length, in characters. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Unicode.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def adapt(self, impltype): - return impltype(length=self.length, - convert_unicode=self.convert_unicode, - assert_unicode=self.assert_unicode, - collation=self.collation) - - def get_col_spec(self): - if self.length: - return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NVARCHAR") - - -class MSChar(_StringType, sqltypes.CHAR): - """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum - of 8,000 characters.""" - - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): - """Construct a CHAR. - - :param length: Optinal, maximum data length, in characters. - - :param convert_unicode: defaults to False. If True, convert - ``unicode`` data sent to the database to a ``str`` - bytestring, and convert bytestrings coming back from the - database into ``unicode``. - - Bytestrings are encoded using the dialect's - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which - defaults to `utf-8`. - - If False, may be overridden by - :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. - - :param assert_unicode: - - If None (the default), no assertion will take place unless - overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. - - If 'warn', will issue a runtime warning if a ``str`` - instance is used as a bind value. - - If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("CHAR(%s)" % self.length) - else: - return self._extend("CHAR") - - -class MSNChar(_StringType, sqltypes.NCHAR): - """MSSQL NCHAR type. - - For fixed-length unicode character data up to 4,000 characters.""" - - def __init__(self, length=None, **kwargs): - """Construct an NCHAR. - - :param length: Optional, Maximum data length, in characters. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.NCHAR.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.length: - return self._extend("NCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NCHAR") - - -class MSGenericBinary(sqltypes.Binary): - """The Binary type assumes that a Binary specification without a length - is an unbound Binary type whereas one with a length specification results - in a fixed length Binary type. - - If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type. - - """ - - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "IMAGE" - - -class MSBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "BINARY" - - -class MSVarBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "VARBINARY(%s)" % self.length - else: - return "VARBINARY" - - -class MSImage(MSGenericBinary): - def get_col_spec(self): - return "IMAGE" - - -class MSBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - - -class MSTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - - -class MSMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" - - -class MSSmallMoney(MSMoney): - def get_col_spec(self): - return "SMALLMONEY" - - -class MSUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" - - -class MSVariant(sqltypes.TypeEngine): - def get_col_spec(self): - return "SQL_VARIANT" - -ischema = MetaData() - -schemata = Table("SCHEMATA", ischema, - Column("CATALOG_NAME", String, key="catalog_name"), - Column("SCHEMA_NAME", String, key="schema_name"), - Column("SCHEMA_OWNER", String, key="schema_owner"), - schema="INFORMATION_SCHEMA") - -tables = Table("TABLES", ischema, - Column("TABLE_CATALOG", String, key="table_catalog"), - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("TABLE_TYPE", String, key="table_type"), - schema="INFORMATION_SCHEMA") - -columns = Table("COLUMNS", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("IS_NULLABLE", Integer, key="is_nullable"), - Column("DATA_TYPE", String, key="data_type"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), - Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), - Column("NUMERIC_SCALE", Integer, key="numeric_scale"), - Column("COLUMN_DEFAULT", Integer, key="column_default"), - Column("COLLATION_NAME", String, key="collation_name"), - schema="INFORMATION_SCHEMA") - -constraints = Table("TABLE_CONSTRAINTS", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("CONSTRAINT_TYPE", String, key="constraint_type"), - schema="INFORMATION_SCHEMA") - -column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - schema="INFORMATION_SCHEMA") - -key_constraints = Table("KEY_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - schema="INFORMATION_SCHEMA") - -ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, - Column("CONSTRAINT_CATALOG", String, key="constraint_catalog"), - Column("CONSTRAINT_SCHEMA", String, key="constraint_schema"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("UNIQUE_CONSTRAINT_CATLOG", String, key="unique_constraint_catalog"), - Column("UNIQUE_CONSTRAINT_SCHEMA", String, key="unique_constraint_schema"), - Column("UNIQUE_CONSTRAINT_NAME", String, key="unique_constraint_name"), - Column("MATCH_OPTION", String, key="match_option"), - Column("UPDATE_RULE", String, key="update_rule"), - Column("DELETE_RULE", String, key="delete_rule"), - schema="INFORMATION_SCHEMA") - -def _has_implicit_sequence(column): - return column.primary_key and \ - column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and \ - not column.foreign_keys and \ - ( - column.default is None or - ( - isinstance(column.default, schema.Sequence) and - column.default.optional) - ) - -def _table_sequence_column(tbl): - if not hasattr(tbl, '_ms_has_sequence'): - tbl._ms_has_sequence = None - for column in tbl.c: - if getattr(column, 'sequence', False) or _has_implicit_sequence(column): - tbl._ms_has_sequence = column - break - return tbl._ms_has_sequence - -class MSSQLExecutionContext(default.DefaultExecutionContext): - IINSERT = False - HASIDENT = False - - def pre_exec(self): - """Activate IDENTITY_INSERT if needed.""" - - if self.compiled.isinsert: - tbl = self.compiled.statement.table - seq_column = _table_sequence_column(tbl) - self.HASIDENT = bool(seq_column) - if self.dialect.auto_identity_insert and self.HASIDENT: - self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0] - else: - self.IINSERT = False - - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - - def handle_dbapi_exception(self, e): - if self.IINSERT: - try: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - except: - pass - - def post_exec(self): - """Disable IDENTITY_INSERT if enabled.""" - - if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT: - if not self._last_inserted_ids or self._last_inserted_ids[0] is None: - if self.dialect.use_scope_identity: - self.cursor.execute("SELECT scope_identity() AS lastrowid") - else: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] - - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - - -class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): - def pre_exec(self): - """where appropriate, issue "select scope_identity()" in the same statement""" - super(MSSQLExecutionContext_pyodbc, self).pre_exec() - if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \ - and len(self.parameters) == 1 and self.dialect.use_scope_identity: - self.statement += "; select scope_identity()" - - def post_exec(self): - if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany: - import pyodbc - # Fetch the last inserted id from the manipulated statement - # We may have to skip over a number of result sets with no data (due to triggers, etc.) - while True: - try: - row = self.cursor.fetchone() - break - except pyodbc.Error, e: - self.cursor.nextset() - self._last_inserted_ids = [int(row[0])] - else: - super(MSSQLExecutionContext_pyodbc, self).post_exec() - -class MSSQLDialect(default.DefaultDialect): - name = 'mssql' - supports_default_values = True - supports_empty_insert = False - auto_identity_insert = True - execution_ctx_cls = MSSQLExecutionContext - text_as_varchar = False - use_scope_identity = False - has_window_funcs = False - max_identifier_length = 128 - schema_name = "dbo" - - colspecs = { - sqltypes.Unicode : MSNVarchar, - sqltypes.Integer : MSInteger, - sqltypes.Smallinteger: MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.Time : MSTime, - sqltypes.String : MSString, - sqltypes.Binary : MSGenericBinary, - sqltypes.Boolean : MSBoolean, - sqltypes.Text : MSText, - sqltypes.UnicodeText : MSNText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, - sqltypes.TIMESTAMP: MSTimeStamp, - } - - ischema_names = { - 'int' : MSInteger, - 'bigint': MSBigInteger, - 'smallint' : MSSmallInteger, - 'tinyint' : MSTinyInteger, - 'varchar' : MSString, - 'nvarchar' : MSNVarchar, - 'char' : MSChar, - 'nchar' : MSNChar, - 'text' : MSText, - 'ntext' : MSNText, - 'decimal' : MSNumeric, - 'numeric' : MSNumeric, - 'float' : MSFloat, - 'datetime' : MSDateTime, - 'datetime2' : MSDateTime2, - 'datetimeoffset' : MSDateTimeOffset, - 'date': MSDate, - 'time': MSTime, - 'smalldatetime' : MSSmallDateTime, - 'binary' : MSBinary, - 'varbinary' : MSVarBinary, - 'bit': MSBoolean, - 'real' : MSFloat, - 'image' : MSImage, - 'timestamp': MSTimeStamp, - 'money': MSMoney, - 'smallmoney': MSSmallMoney, - 'uniqueidentifier': MSUniqueIdentifier, - 'sql_variant': MSVariant, - } - - def __new__(cls, *args, **kwargs): - if cls is not MSSQLDialect: - # this gets called with the dialect specific class - return super(MSSQLDialect, cls).__new__(cls) - dbapi = kwargs.get('dbapi', None) - if dbapi: - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(**kwargs) - else: - return object.__new__(cls) - - def __init__(self, - auto_identity_insert=True, query_timeout=None, - text_as_varchar=False, use_scope_identity=False, - has_window_funcs=False, max_identifier_length=None, - schema_name="dbo", **opts): - self.auto_identity_insert = bool(auto_identity_insert) - self.query_timeout = int(query_timeout or 0) - self.schema_name = schema_name - - # to-do: the options below should use server version introspection to set themselves on connection - self.text_as_varchar = bool(text_as_varchar) - self.use_scope_identity = bool(use_scope_identity) - self.has_window_funcs = bool(has_window_funcs) - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length - super(MSSQLDialect, self).__init__(**opts) - - @classmethod - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) - else: - for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]: - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi') - - @base.connection_memoize(('mssql', 'server_version_info')) - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(9, 0, 1399)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - return connection.dialect._server_version_info(connection.connection) - - def _server_version_info(self, dbapi_con): - """Return a tuple of the database's version number.""" - raise NotImplementedError() - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - if 'auto_identity_insert' in opts: - self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert'))) - if 'query_timeout' in opts: - self.query_timeout = int(opts.pop('query_timeout')) - if 'text_as_varchar' in opts: - self.text_as_varchar = bool(int(opts.pop('text_as_varchar'))) - if 'use_scope_identity' in opts: - self.use_scope_identity = bool(int(opts.pop('use_scope_identity'))) - if 'has_window_funcs' in opts: - self.has_window_funcs = bool(int(opts.pop('has_window_funcs'))) - return self.make_connect_string(opts, url.query) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - # Some types need to know about the dialect - if isinstance(newobj, (MSText, MSNText)): - newobj.dialect = self - return newobj - - def do_savepoint(self, connection, name): - util.warn("Savepoint support in mssql is experimental and may lead to data loss.") - connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") - connection.execute("SAVE TRANSACTION %s" % name) - - def do_release_savepoint(self, connection, name): - pass - - @base.connection_memoize(('dialect', 'default_schema_name')) - def get_default_schema_name(self, connection): - query = "SELECT user_name() as user_name;" - user_name = connection.scalar(sql.text(query)) - if user_name is not None: - # now, get the default schema - query = """ - SELECT default_schema_name FROM - sys.database_principals - WHERE name = :user_name - AND type = 'S' - """ - try: - default_schema_name = connection.scalar(sql.text(query), - user_name=user_name) - if default_schema_name is not None: - return default_schema_name - except: - pass - return self.schema_name - - def table_names(self, connection, schema): - s = select([tables.c.table_name], tables.c.table_schema==schema) - return [row[0] for row in connection.execute(s)] - - - def has_table(self, connection, tablename, schema=None): - - current_schema = schema or self.get_default_schema_name(connection) - s = sql.select([columns], - current_schema - and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) - or columns.c.table_name==tablename, - ) - - c = connection.execute(s) - row = c.fetchone() - return row is not None - - def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.get_default_schema_name(connection) - - s = sql.select([columns], - current_schema - and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) - or columns.c.table_name==table.name, - order_by=[columns.c.ordinal_position]) - - c = connection.execute(s) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', - row[columns.c.character_maximum_length], - row[columns.c.numeric_precision], - row[columns.c.numeric_scale], - row[columns.c.column_default], - row[columns.c.collation_name] - ) - if include_columns and name not in include_columns: - continue - - coltype = self.ischema_names.get(type, None) - - kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): - kwargs['length'] = charlen - if collation: - kwargs['collation'] = collation - if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): - kwargs.pop('length') - - if issubclass(coltype, sqltypes.Numeric): - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec - - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) - coltype = sqltypes.NULLTYPE - - coltype = coltype(**kwargs) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema)) - ic = None - while True: - row = cursor.fetchone() - if row is None: - break - col_name, type_name = row[3], row[5] - if type_name.endswith("identity") and col_name in table.c: - ic = table.c[col_name] - ic.autoincrement = True - # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute - ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1) - # MSSQL: only one identity per table allowed - cursor.close() - break - if not ic is None: - try: - cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname) - row = cursor.fetchone() - cursor.close() - if not row is None: - ic.sequence.start = int(row[0]) - ic.sequence.increment = int(row[1]) - except: - # ignoring it, works just like before - pass - - # Add constraints - RR = ref_constraints - TC = constraints - C = key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column - R = key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column - - # Primary key constraints - s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name, - C.c.table_name == table.name, - C.c.table_schema == (table.schema or current_schema))) - c = connection.execute(s) - for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name] and row[0] in table.c: - table.primary_key.add(table.c[row[0]]) - - # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], - sql.and_(C.c.table_name == table.name, - C.c.table_schema == (table.schema or current_schema), - C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == RR.c.unique_constraint_name, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by = [RR.c.constraint_name, R.c.ordinal_position]) - rows = connection.execute(s).fetchall() - - def _gen_fkref(table, rschema, rtbl, rcol): - if rschema == current_schema and not table.schema: - return '.'.join([rtbl, rcol]) - else: - return '.'.join([rschema, rtbl, rcol]) - - # group rows by constraint ID, to handle multi-column FKs - fknm, scols, rcols = (None, [], []) - for r in rows: - scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r - # if the reflected schema is the default schema then don't set it because this will - # play into the metadata key causing duplicates. - if rschema == current_schema and not table.schema: - schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection) - else: - schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection) - if rfknm != fknm: - if fknm: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) - fknm, scols, rcols = (rfknm, [], []) - if not scol in scols: - scols.append(scol) - if not (rschema, rtbl, rcol) in rcols: - rcols.append((rschema, rtbl, rcol)) - - if fknm and scols: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) - - -class MSSQLDialect_pymssql(MSSQLDialect): - supports_sane_rowcount = False - max_identifier_length = 30 - - @classmethod - def import_dbapi(cls): - import pymssql as module - # pymmsql doesn't have a Binary method. we use string - # TODO: monkeypatching here is less than ideal - module.Binary = lambda st: str(st) - try: - module.version_info = tuple(map(int, module.__version__.split('.'))) - except: - module.version_info = (0, 0, 0) - return module - - def __init__(self, **params): - super(MSSQLDialect_pymssql, self).__init__(**params) - self.use_scope_identity = True - - # pymssql understands only ascii - if self.convert_unicode: - util.warn("pymssql does not support unicode") - self.encoding = params.get('encoding', 'ascii') - - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - def create_connect_args(self, url): - r = super(MSSQLDialect_pymssql, self).create_connect_args(url) - if hasattr(self, 'query_timeout'): - if self.dbapi.version_info > (0, 8, 0): - r[1]['timeout'] = self.query_timeout - else: - self.dbapi._mssql.set_query_timeout(self.query_timeout) - return r - - def make_connect_string(self, keys, query): - if keys.get('port'): - # pymssql expects port as host:port, not a separate arg - keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) - del keys['port'] - return [[], keys] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) - - def do_begin(self, connection): - pass - - -class MSSQLDialect_pyodbc(MSSQLDialect): - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - # PyODBC unicode is broken on UCS-4 builds - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = supports_unicode - execution_ctx_cls = MSSQLExecutionContext_pyodbc - - def __init__(self, description_encoding='latin-1', **params): - super(MSSQLDialect_pyodbc, self).__init__(**params) - self.description_encoding = description_encoding - - if self.server_version_info < (10,): - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - # FIXME: scope_identity sniff should look at server version, not the ODBC driver - # whether use_scope_identity will work depends on the version of pyodbc - try: - import pyodbc - self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset') - except: - pass - - @classmethod - def import_dbapi(cls): - import pyodbc as module - return module - - def make_connect_string(self, keys, query): - if 'max_identifier_length' in keys: - self.max_identifier_length = int(keys.pop('max_identifier_length')) - - if 'odbc_connect' in keys: - connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] - else: - dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) - if dsn_connection: - connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] - else: - port = '' - if 'port' in keys and not 'port' in query: - port = ',%d' % int(keys.pop('port')) - - connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'), - 'Server=%s%s' % (keys.pop('host', ''), port), - 'Database=%s' % keys.pop('database', '') ] - - user = keys.pop("user", None) - if user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.pop('password', '')) - else: - connectors.append("TrustedConnection=Yes") - - # if set to 'Yes', the ODBC layer will try to automagically convert - # textual data from your database encoding to your client encoding - # This should obviously be set to 'No' if you query a cp1253 encoded - # database from a latin1 client... - if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) - - connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) - - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.ProgrammingError): - return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) - elif isinstance(e, self.dbapi.Error): - return '[08S01]' in str(e) - else: - return False - - - def _server_version_info(self, dbapi_con): - """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) - -class MSSQLDialect_adodbapi(MSSQLDialect): - supports_sane_rowcount = True - supports_sane_multi_rowcount = True - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = True - - @classmethod - def import_dbapi(cls): - import adodbapi as module - return module - - colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.DateTime] = MSDateTime_adodbapi - - ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['datetime'] = MSDateTime_adodbapi - - def make_connect_string(self, keys, query): - connectors = ["Provider=SQLOLEDB"] - if 'port' in keys: - connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) - else: - connectors.append ("Data Source=%s" % keys.get("host")) - connectors.append ("Initial Catalog=%s" % keys.get("database")) - user = keys.get("user") - if user: - connectors.append("User Id=%s" % user) - connectors.append("Password=%s" % keys.get("password", "")) - else: - connectors.append("Integrated Security=SSPI") - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) - - -dialect_mapping = { - 'pymssql': MSSQLDialect_pymssql, - 'pyodbc': MSSQLDialect_pyodbc, - 'adodbapi': MSSQLDialect_adodbapi - } - - -class MSSQLCompiler(compiler.DefaultCompiler): - operators = compiler.OPERATORS.copy() - operators.update({ - sql_operators.concat_op: '+', - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - }) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.current_date: 'GETDATE()', - 'length': lambda x: "LEN(%s)" % x, - sql_functions.char_length: lambda x: "LEN(%s)" % x - } - ) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) - - def __init__(self, *args, **kwargs): - super(MSSQLCompiler, self).__init__(*args, **kwargs) - self.tablealiases = {} - - def get_select_precolumns(self, select): - """ MS-SQL puts TOP, it's version of LIMIT here """ - if select._distinct or select._limit: - s = select._distinct and "DISTINCT " or "" - - if select._limit: - if not select._offset: - s += "TOP %s " % (select._limit,) - else: - if not self.dialect.has_window_funcs: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') - return s - return compiler.DefaultCompiler.get_select_precolumns(self, select) - - def limit_clause(self, select): - # Limit in mssql is after the select keyword - return "" - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``row_number()`` criterion. - - """ - if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset: - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.process(select._order_by_clause) - if not orderby: - raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') - - _offset = select._offset - _limit = select._limit - select._mssql_visit = True - select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() - - limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) - limitselect.append_whereclause("mssql_rn>%d" % _offset) - if _limit is not None: - limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) - return self.process(limitselect, iswrapper=True, **kwargs) - else: - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) - - def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: - if table not in self.tablealiases: - self.tablealiases[table] = table.alias() - return self.tablealiases[table] - else: - return None - - def visit_table(self, table, mssql_aliased=False, **kwargs): - if mssql_aliased: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - # alias schema-qualified tables - alias = self._schema_aliased_table(table) - if alias is not None: - return self.process(alias, mssql_aliased=True, **kwargs) - else: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - def visit_alias(self, alias, **kwargs): - # translate for schema-qualified table aliases - self.tablealiases[alias.original] = alias - kwargs['mssql_aliased'] = True - return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - - def visit_extract(self, extract): - field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - - def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) - - def visit_column(self, column, result_map=None, **kwargs): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or self.is_subquery(): - # translate for schema-qualified table aliases - t = self._schema_aliased_table(column.table) - if t is not None: - converted = expression._corresponding_column_or_error(t, column) - - if result_map is not None: - result_map[column.name.lower()] = (column.name, (column, ), column.type) - - return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) - - return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) - - def visit_binary(self, binary, **kwargs): - """Move bind parameters to the right-hand side of an operator, where - possible. - - """ - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ - and not isinstance(binary.right, expression._BindParamClause): - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) - else: - if (binary.operator is operator.eq or binary.operator is operator.ne) and ( - (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ - (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ - isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): - op = binary.operator == operator.eq and "IN" or "NOT IN" - return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - - def visit_insert(self, insert_stmt): - insert_select = False - if insert_stmt.parameters: - insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)] - if insert_select: - self.isinsert = True - colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: - raise exc.CompileError( - "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) SELECT %s" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) - else: - return super(MSSQLCompiler, self).visit_insert(insert_stmt) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class MSSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if column.nullable is not None: - if not column.nullable or column.primary_key: - colspec += " NOT NULL" - else: - colspec += " NULL" - - if not column.table: - raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") - - seq_col = _table_sequence_column(column.table) - - # install a IDENTITY Sequence if we have an implicit IDENTITY column - if seq_col is column: - sequence = getattr(column, 'sequence', None) - if sequence: - start, increment = sequence.start or 1, sequence.increment or 1 - else: - start, increment = 1, 1 - colspec += " IDENTITY(%s,%s)" % (start, increment) - else: - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - -class MSSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') - - def _escape_identifier(self, value): - #TODO: determine MSSQL's escaping rules - return value - - def quote_schema(self, schema, force=True): - """Prepare a quoted table and schema name.""" - result = '.'.join([self.quote(x, force) for x in schema.split('.')]) - return result - -dialect = MSSQLDialect -dialect.statement_compiler = MSSQLCompiler -dialect.schemagenerator = MSSQLSchemaGenerator -dialect.schemadropper = MSSQLSchemaDropper -dialect.preparer = MSSQLIdentifierPreparer - diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py deleted file mode 100644 index 92f533633..000000000 --- a/lib/sqlalchemy/databases/mxODBC.py +++ /dev/null @@ -1,60 +0,0 @@ -# mxODBC.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -A wrapper for a mx.ODBC.Windows DB-API connection. - -Makes sure the mx module is configured to return datetime objects instead -of mx.DateTime.DateTime objects. -""" - -from mx.ODBC.Windows import * - - -class Cursor: - def __init__(self, cursor): - self.cursor = cursor - - def __getattr__(self, attr): - res = getattr(self.cursor, attr) - return res - - def execute(self, *args, **kwargs): - res = self.cursor.execute(*args, **kwargs) - return res - - -class Connection: - def myErrorHandler(self, connection, cursor, errorclass, errorvalue): - err0, err1, err2, err3 = errorvalue - #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)]) - if int(err1) == 109: - # Ignore "Null value eliminated in aggregate function", this is not an error - return - raise errorclass, errorvalue - - def __init__(self, conn): - self.conn = conn - # install a mx ODBC error handler - self.conn.errorhandler = self.myErrorHandler - - def __getattr__(self, attr): - res = getattr(self.conn, attr) - return res - - def cursor(self, *args, **kwargs): - res = Cursor(self.conn.cursor(*args, **kwargs)) - return res - - -# override 'connect' call -def connect(*args, **kwargs): - import mx.ODBC.Windows - conn = mx.ODBC.Windows.Connect(*args, **kwargs) - conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT - return Connection(conn) -Connect = connect diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py deleted file mode 100644 index ba6b026ea..000000000 --- a/lib/sqlalchemy/databases/mysql.py +++ /dev/null @@ -1,2732 +0,0 @@ -# -*- fill-column: 78 -*- -# mysql.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the MySQL database. - -Overview --------- - -For normal SQLAlchemy usage, importing this module is unnecessary. It will be -loaded on-demand when a MySQL connection is needed. The generic column types -like :class:`~sqlalchemy.String` and :class:`~sqlalchemy.Integer` will -automatically be adapted to the optimal matching MySQL column type. - -But if you would like to use one of the MySQL-specific or enhanced column -types when creating tables with your :class:`~sqlalchemy.Table` definitions, -then you will need to import them from this module:: - - from sqlalchemy.databases import mysql - - Table('mytable', metadata, - Column('id', Integer, primary_key=True), - Column('ittybittyblob', mysql.MSTinyBlob), - Column('biggy', mysql.MSBigInteger(unsigned=True))) - -All standard MySQL column types are supported. The OpenGIS types are -available for use via table reflection but have no special support or mapping -to Python classes. If you're using these types and have opinions about how -OpenGIS can be smartly integrated into SQLAlchemy please join the mailing -list! - -Supported Versions and Features -------------------------------- - -SQLAlchemy supports 6 major MySQL versions: 3.23, 4.0, 4.1, 5.0, 5.1 and 6.0, -with capabilities increasing with more modern servers. - -Versions 4.1 and higher support the basic SQL functionality that SQLAlchemy -uses in the ORM and SQL expressions. These versions pass the applicable tests -in the suite 100%. No heroic measures are taken to work around major missing -SQL features- if your server version does not support sub-selects, for -example, they won't work in SQLAlchemy either. - -Currently, the only DB-API driver supported is `MySQL-Python` (also referred to -as `MySQLdb`). Either 1.2.1 or 1.2.2 are recommended. The alpha, beta and -gamma releases of 1.2.1 and 1.2.2 should be avoided. Support for Jython and -IronPython is planned. - -===================================== =============== -Feature Minimum Version -===================================== =============== -sqlalchemy.orm 4.1.1 -Table Reflection 3.23.x -DDL Generation 4.1.1 -utf8/Full Unicode Connections 4.1.1 -Transactions 3.23.15 -Two-Phase Transactions 5.0.3 -Nested Transactions 5.0.3 -===================================== =============== - -See the official MySQL documentation for detailed information about features -supported in any given server release. - -Character Sets --------------- - -Many MySQL server installations default to a ``latin1`` encoding for client -connections. All data sent through the connection will be converted into -``latin1``, even if you have ``utf8`` or another character set on your tables -and columns. With versions 4.1 and higher, you can change the connection -character set either through server configuration or by including the -``charset`` parameter in the URL used for ``create_engine``. The ``charset`` -option is passed through to MySQL-Python and has the side-effect of also -enabling ``use_unicode`` in the driver by default. For regular encoded -strings, also pass ``use_unicode=0`` in the connection arguments:: - - # set client encoding to utf8; all strings come back as unicode - create_engine('mysql:///mydb?charset=utf8') - - # set client encoding to utf8; all strings come back as utf8 str - create_engine('mysql:///mydb?charset=utf8&use_unicode=0') - -Storage Engines ---------------- - -Most MySQL server installations have a default table type of ``MyISAM``, a -non-transactional table type. During a transaction, non-transactional storage -engines do not participate and continue to store table changes in autocommit -mode. For fully atomic transactions, all participating tables must use a -transactional engine such as ``InnoDB``, ``Falcon``, ``SolidDB``, `PBXT`, etc. - -Storage engines can be elected when creating tables in SQLAlchemy by supplying -a ``mysql_engine='whatever'`` to the ``Table`` constructor. Any MySQL table -creation option can be specified in this syntax:: - - Table('mytable', metadata, - Column('data', String(32)), - mysql_engine='InnoDB', - mysql_charset='utf8' - ) - -Keys ----- - -Not all MySQL storage engines support foreign keys. For ``MyISAM`` and -similar engines, the information loaded by table reflection will not include -foreign keys. For these tables, you may supply a -:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: - - Table('mytable', metadata, - ForeignKeyConstraint(['other_id'], ['othertable.other_id']), - autoload=True - ) - -When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT``` on -an integer primary key column:: - - >>> t = Table('mytable', metadata, - ... Column('mytable_id', Integer, primary_key=True) - ... ) - >>> t.create() - CREATE TABLE mytable ( - id INTEGER NOT NULL AUTO_INCREMENT, - PRIMARY KEY (id) - ) - -You can disable this behavior by supplying ``autoincrement=False`` to the -:class:`~sqlalchemy.Column`. This flag can also be used to enable -auto-increment on a secondary column in a multi-column key for some storage -engines:: - - Table('mytable', metadata, - Column('gid', Integer, primary_key=True, autoincrement=False), - Column('id', Integer, primary_key=True) - ) - -SQL Mode --------- - -MySQL SQL modes are supported. Modes that enable ``ANSI_QUOTES`` (such as -``ANSI``) require an engine option to modify SQLAlchemy's quoting style. -When using an ANSI-quoting mode, supply ``use_ansiquotes=True`` when -creating your ``Engine``:: - - create_engine('mysql://localhost/test', use_ansiquotes=True) - -This is an engine-wide option and is not toggleable on a per-connection basis. -SQLAlchemy does not presume to ``SET sql_mode`` for you with this option. For -the best performance, set the quoting style server-wide in ``my.cnf`` or by -supplying ``--sql-mode`` to ``mysqld``. You can also use a -:class:`sqlalchemy.pool.Pool` listener hook to issue a ``SET SESSION -sql_mode='...'`` on connect to configure each connection. - -If you do not specify ``use_ansiquotes``, the regular MySQL quoting style is -used by default. - -If you do issue a ``SET sql_mode`` through SQLAlchemy, the dialect must be -updated if the quoting style is changed. Again, this change will affect all -connections:: - - connection.execute('SET sql_mode="ansi"') - connection.dialect.use_ansiquotes = True - -MySQL SQL Extensions --------------------- - -Many of the MySQL SQL extensions are handled through SQLAlchemy's generic -function and operator support:: - - table.select(table.c.password==func.md5('plaintext')) - table.select(table.c.username.op('regexp')('^[a-d]')) - -And of course any valid MySQL statement can be executed as a string as well. - -Some limited direct support for MySQL extensions to SQL is currently -available. - - * SELECT pragma:: - - select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) - - * UPDATE with LIMIT:: - - update(..., mysql_limit=10) - -Troubleshooting ---------------- - -If you have problems that seem server related, first check that you are -using the most recent stable MySQL-Python package available. The Database -Notes page on the wiki at http://www.sqlalchemy.org is a good resource for -timely information affecting MySQL in SQLAlchemy. - -""" - -import datetime, decimal, inspect, re, sys -from array import array as _array - -from sqlalchemy import exc, log, schema, sql, util -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy.sql import functions as sql_functions -from sqlalchemy.sql import compiler - -from sqlalchemy.engine import base as engine_base, default -from sqlalchemy import types as sqltypes - - -__all__ = ( - 'MSBigInteger', 'MSMediumInteger', 'MSBinary', 'MSBit', 'MSBlob', 'MSBoolean', - 'MSChar', 'MSDate', 'MSDateTime', 'MSDecimal', 'MSDouble', - 'MSEnum', 'MSFloat', 'MSInteger', 'MSLongBlob', 'MSLongText', - 'MSMediumBlob', 'MSMediumText', 'MSNChar', 'MSNVarChar', - 'MSNumeric', 'MSSet', 'MSSmallInteger', 'MSString', 'MSText', - 'MSTime', 'MSTimeStamp', 'MSTinyBlob', 'MSTinyInteger', - 'MSTinyText', 'MSVarBinary', 'MSYear' ) - - -RESERVED_WORDS = set( - ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', - 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', - 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', - 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', - 'current_user', 'cursor', 'database', 'databases', 'day_hour', - 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', - 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', - 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', - 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', - 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having', - 'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if', - 'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive', - 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', 'integer', - 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', 'kill', - 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', 'load', - 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', 'longtext', - 'loop', 'low_priority', 'master_ssl_verify_server_cert', 'match', - 'mediumblob', 'mediumint', 'mediumtext', 'middleint', - 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', - 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', - 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', - 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', - 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', - 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', - 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', - 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', - 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', - 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', - 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', - 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', - 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', - 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', - 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 - 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 - 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 'read_only', 'read_write', # 5.1 - ]) - -AUTOCOMMIT_RE = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', - re.I | re.UNICODE) -SET_RE = re.compile( - r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', - re.I | re.UNICODE) - - -class _NumericType(object): - """Base for MySQL numeric types.""" - - def __init__(self, kw): - self.unsigned = kw.pop('unsigned', False) - self.zerofill = kw.pop('zerofill', False) - - def _extend(self, spec): - "Extend a numeric-type declaration with MySQL specific extensions." - - if self.unsigned: - spec += ' UNSIGNED' - if self.zerofill: - spec += ' ZEROFILL' - return spec - - -class _StringType(object): - """Base for MySQL string types.""" - - def __init__(self, charset=None, collation=None, - ascii=False, unicode=False, binary=False, - national=False, **kwargs): - self.charset = charset - # allow collate= or collation= - self.collation = kwargs.get('collate', collation) - self.ascii = ascii - self.unicode = unicode - self.binary = binary - self.national = national - - def _extend(self, spec): - """Extend a string-type declaration with standard SQL CHARACTER SET / - COLLATE annotations and MySQL specific extensions. - """ - - if self.charset: - charset = 'CHARACTER SET %s' % self.charset - elif self.ascii: - charset = 'ASCII' - elif self.unicode: - charset = 'UNICODE' - else: - charset = None - - if self.collation: - collation = 'COLLATE %s' % self.collation - elif self.binary: - collation = 'BINARY' - else: - collation = None - - if self.national: - # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) - - def __repr__(self): - attributes = inspect.getargspec(self.__init__)[0][1:] - attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - - params = {} - for attr in attributes: - val = getattr(self, attr) - if val is not None and val is not False: - params[attr] = val - - return "%s(%s)" % (self.__class__.__name__, - ', '.join(['%s=%r' % (k, params[k]) for k in params])) - - -class MSNumeric(sqltypes.Numeric, _NumericType): - """MySQL NUMERIC type.""" - - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): - """Construct a NUMERIC. - - :param precision: Total digits in this number. If scale and precision - are both None, values are stored to limits allowed by the server. - - :param scale: The number of digits after the decimal point. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - _NumericType.__init__(self, kw) - sqltypes.Numeric.__init__(self, precision, scale, asdecimal=asdecimal, **kw) - - def get_col_spec(self): - if self.precision is None: - return self._extend("NUMERIC") - else: - return self._extend("NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if not self.asdecimal: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - else: - return None - - -class MSDecimal(MSNumeric): - """MySQL DECIMAL type.""" - - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): - """Construct a DECIMAL. - - :param precision: Total digits in this number. If scale and precision - are both None, values are stored to limits allowed by the server. - - :param scale: The number of digits after the decimal point. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - super(MSDecimal, self).__init__(precision, scale, asdecimal=asdecimal, **kw) - - def get_col_spec(self): - if self.precision is None: - return self._extend("DECIMAL") - elif self.scale is None: - return self._extend("DECIMAL(%(precision)s)" % {'precision': self.precision}) - else: - return self._extend("DECIMAL(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - - -class MSDouble(sqltypes.Float, _NumericType): - """MySQL DOUBLE type.""" - - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): - """Construct a DOUBLE. - - :param precision: Total digits in this number. If scale and precision - are both None, values are stored to limits allowed by the server. - - :param scale: The number of digits after the decimal point. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - if ((precision is None and scale is not None) or - (precision is not None and scale is None)): - raise exc.ArgumentError( - "You must specify both precision and scale or omit " - "both altogether.") - - _NumericType.__init__(self, kw) - sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw) - self.scale = scale - self.precision = precision - - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("DOUBLE(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('DOUBLE') - - -class MSReal(MSDouble): - """MySQL REAL type.""" - - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): - """Construct a REAL. - - :param precision: Total digits in this number. If scale and precision - are both None, values are stored to limits allowed by the server. - - :param scale: The number of digits after the decimal point. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - MSDouble.__init__(self, precision, scale, asdecimal, **kw) - - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("REAL(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('REAL') - - -class MSFloat(sqltypes.Float, _NumericType): - """MySQL FLOAT type.""" - - def __init__(self, precision=None, scale=None, asdecimal=False, **kw): - """Construct a FLOAT. - - :param precision: Total digits in this number. If scale and precision - are both None, values are stored to limits allowed by the server. - - :param scale: The number of digits after the decimal point. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - _NumericType.__init__(self, kw) - sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw) - self.scale = scale - self.precision = precision - - def get_col_spec(self): - if self.scale is not None and self.precision is not None: - return self._extend("FLOAT(%s, %s)" % (self.precision, self.scale)) - elif self.precision is not None: - return self._extend("FLOAT(%s)" % (self.precision,)) - else: - return self._extend("FLOAT") - - def bind_processor(self, dialect): - return None - - -class MSInteger(sqltypes.Integer, _NumericType): - """MySQL INTEGER type.""" - - def __init__(self, display_width=None, **kw): - """Construct an INTEGER. - - :param display_width: Optional, maximum display width for this number. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - if 'length' in kw: - util.warn_deprecated("'length' is deprecated for MSInteger and subclasses. Use 'display_width'.") - self.display_width = kw.pop('length') - else: - self.display_width = display_width - _NumericType.__init__(self, kw) - sqltypes.Integer.__init__(self, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("INTEGER(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("INTEGER") - - -class MSBigInteger(MSInteger): - """MySQL BIGINTEGER type.""" - - def __init__(self, display_width=None, **kw): - """Construct a BIGINTEGER. - - :param display_width: Optional, maximum display width for this number. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - super(MSBigInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("BIGINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("BIGINT") - - -class MSMediumInteger(MSInteger): - """MySQL MEDIUMINTEGER type.""" - - def __init__(self, display_width=None, **kw): - """Construct a MEDIUMINTEGER - - :param display_width: Optional, maximum display width for this number. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - super(MSMediumInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("MEDIUMINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("MEDIUMINT") - - - -class MSTinyInteger(MSInteger): - """MySQL TINYINT type.""" - - def __init__(self, display_width=None, **kw): - """Construct a TINYINT. - - Note: following the usual MySQL conventions, TINYINT(1) columns - reflected during Table(..., autoload=True) are treated as - Boolean columns. - - :param display_width: Optional, maximum display width for this number. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - super(MSTinyInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("TINYINT(%s)" % self.display_width) - else: - return self._extend("TINYINT") - - -class MSSmallInteger(sqltypes.Smallinteger, MSInteger): - """MySQL SMALLINTEGER type.""" - - def __init__(self, display_width=None, **kw): - """Construct a SMALLINTEGER. - - :param display_width: Optional, maximum display width for this number. - - :param unsigned: a boolean, optional. - - :param zerofill: Optional. If true, values will be stored as strings - left-padded with zeros. Note that this does not effect the values - returned by the underlying database API, which continue to be - numeric. - - """ - self.display_width = display_width - _NumericType.__init__(self, kw) - sqltypes.SmallInteger.__init__(self, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("SMALLINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("SMALLINT") - - -class MSBit(sqltypes.TypeEngine): - """MySQL BIT type. - - This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for - MyISAM, MEMORY, InnoDB and BDB. For older versions, use a MSTinyInteger() - type. - - """ - - def __init__(self, length=None): - """Construct a BIT. - - :param length: Optional, number of bits. - - """ - self.length = length - - def result_processor(self, dialect): - """Convert a MySQL's 64 bit, variable length binary string to a long.""" - def process(value): - if value is not None: - v = 0L - for i in map(ord, value): - v = v << 8 | i - value = v - return value - return process - - def get_col_spec(self): - if self.length is not None: - return "BIT(%s)" % self.length - else: - return "BIT" - - -class MSDateTime(sqltypes.DateTime): - """MySQL DATETIME type.""" - - def get_col_spec(self): - return "DATETIME" - - -class MSDate(sqltypes.Date): - """MySQL DATE type.""" - - def get_col_spec(self): - return "DATE" - - -class MSTime(sqltypes.Time): - """MySQL TIME type.""" - - def get_col_spec(self): - return "TIME" - - def result_processor(self, dialect): - def process(value): - # convert from a timedelta value - if value is not None: - return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) - else: - return None - return process - -class MSTimeStamp(sqltypes.TIMESTAMP): - """MySQL TIMESTAMP type. - - To signal the orm to automatically re-select modified rows to retrieve the - updated timestamp, add a ``server_default`` to your - :class:`~sqlalchemy.Column` specification:: - - from sqlalchemy.databases import mysql - Column('updated', mysql.MSTimeStamp, - server_default=sql.text('CURRENT_TIMESTAMP') - ) - - The full range of MySQL 4.1+ TIMESTAMP defaults can be specified in - the the default:: - - server_default=sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP') - - """ - - def get_col_spec(self): - return "TIMESTAMP" - - -class MSYear(sqltypes.TypeEngine): - """MySQL YEAR type, for single byte storage of years 1901-2155.""" - - def __init__(self, display_width=None): - self.display_width = display_width - - def get_col_spec(self): - if self.display_width is None: - return "YEAR" - else: - return "YEAR(%s)" % self.display_width - -class MSText(_StringType, sqltypes.Text): - """MySQL TEXT type, for text up to 2^16 characters.""" - - def __init__(self, length=None, **kwargs): - """Construct a TEXT. - - :param length: Optional, if provided the server may optimize storage - by substituting the smallest TEXT type sufficient to store - ``length`` characters. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param national: Optional. If true, use the server's configured - national character set. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Text.__init__(self, length, - kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.length: - return self._extend("TEXT(%d)" % self.length) - else: - return self._extend("TEXT") - - -class MSTinyText(MSText): - """MySQL TINYTEXT type, for text up to 2^8 characters.""" - - def __init__(self, **kwargs): - """Construct a TINYTEXT. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param national: Optional. If true, use the server's configured - national character set. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - - super(MSTinyText, self).__init__(**kwargs) - - def get_col_spec(self): - return self._extend("TINYTEXT") - - -class MSMediumText(MSText): - """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" - - def __init__(self, **kwargs): - """Construct a MEDIUMTEXT. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param national: Optional. If true, use the server's configured - national character set. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - super(MSMediumText, self).__init__(**kwargs) - - def get_col_spec(self): - return self._extend("MEDIUMTEXT") - - -class MSLongText(MSText): - """MySQL LONGTEXT type, for text up to 2^32 characters.""" - - def __init__(self, **kwargs): - """Construct a LONGTEXT. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param national: Optional. If true, use the server's configured - national character set. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - super(MSLongText, self).__init__(**kwargs) - - def get_col_spec(self): - return self._extend("LONGTEXT") - - -class MSString(_StringType, sqltypes.String): - """MySQL VARCHAR type, for variable-length character data.""" - - def __init__(self, length=None, **kwargs): - """Construct a VARCHAR. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param national: Optional. If true, use the server's configured - national character set. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length, - kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%d)" % self.length) - else: - return self._extend("VARCHAR") - - -class MSChar(_StringType, sqltypes.CHAR): - """MySQL CHAR type, for fixed-length character data.""" - - def __init__(self, length, **kwargs): - """Construct an NCHAR. - - :param length: Maximum data length, in characters. - - :param binary: Optional, use the default binary collation for the - national character set. This does not affect the type of data - stored, use a BINARY type for binary data. - - :param collation: Optional, request a particular collation. Must be - compatible with the national character set. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length, - kwargs.get('convert_unicode', False)) - - def get_col_spec(self): - return self._extend("CHAR(%(length)s)" % {'length' : self.length}) - - -class MSNVarChar(_StringType, sqltypes.String): - """MySQL NVARCHAR type. - - For variable-length character data in the server's configured national - character set. - """ - - def __init__(self, length=None, **kwargs): - """Construct an NVARCHAR. - - :param length: Maximum data length, in characters. - - :param binary: Optional, use the default binary collation for the - national character set. This does not affect the type of data - stored, use a BINARY type for binary data. - - :param collation: Optional, request a particular collation. Must be - compatible with the national character set. - - """ - kwargs['national'] = True - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length, - kwargs.get('convert_unicode', False)) - - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL VARCHAR" instead - # of "NVARCHAR". - return self._extend("VARCHAR(%(length)s)" % {'length': self.length}) - - -class MSNChar(_StringType, sqltypes.CHAR): - """MySQL NCHAR type. - - For fixed-length character data in the server's configured national - character set. - """ - - def __init__(self, length=None, **kwargs): - """Construct an NCHAR. Arguments are: - - :param length: Maximum data length, in characters. - - :param binary: Optional, use the default binary collation for the - national character set. This does not affect the type of data - stored, use a BINARY type for binary data. - - :param collation: Optional, request a particular collation. Must be - compatible with the national character set. - - """ - kwargs['national'] = True - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length, - kwargs.get('convert_unicode', False)) - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". - return self._extend("CHAR(%(length)s)" % {'length': self.length}) - - -class _BinaryType(sqltypes.Binary): - """Base for MySQL binary types.""" - - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process - -class MSVarBinary(_BinaryType): - """MySQL VARBINARY type, for variable length binary data.""" - - def __init__(self, length=None, **kw): - """Construct a VARBINARY. Arguments are: - - :param length: Maximum data length, in characters. - - """ - super(MSVarBinary, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "VARBINARY(%d)" % self.length - else: - return "BLOB" - - -class MSBinary(_BinaryType): - """MySQL BINARY type, for fixed length binary data""" - - def __init__(self, length=None, **kw): - """Construct a BINARY. - - This is a fixed length type, and short values will be right-padded - with a server-version-specific pad value. - - :param length: Maximum data length, in bytes. If length is not - specified, this will generate a BLOB. This usage is deprecated. - - """ - super(MSBinary, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "BINARY(%d)" % self.length - else: - return "BLOB" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process - -class MSBlob(_BinaryType): - """MySQL BLOB type, for binary data up to 2^16 bytes""" - - def __init__(self, length=None, **kw): - """Construct a BLOB. Arguments are: - - :param length: Optional, if provided the server may optimize storage - by substituting the smallest TEXT type sufficient to store - ``length`` characters. - - """ - super(MSBlob, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process - - def __repr__(self): - return "%s()" % self.__class__.__name__ - - -class MSTinyBlob(MSBlob): - """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" - - def get_col_spec(self): - return "TINYBLOB" - - -class MSMediumBlob(MSBlob): - """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" - - def get_col_spec(self): - return "MEDIUMBLOB" - - -class MSLongBlob(MSBlob): - """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" - - def get_col_spec(self): - return "LONGBLOB" - - -class MSEnum(MSString): - """MySQL ENUM type.""" - - def __init__(self, *enums, **kw): - """Construct an ENUM. - - Example: - - Column('myenum', MSEnum("foo", "bar", "baz")) - - Arguments are: - - :param enums: The range of valid values for this ENUM. Values will be - quoted when generating the schema according to the quoting flag (see - below). - - :param strict: Defaults to False: ensure that a given value is in this - ENUM's range of permissible values when inserting or updating rows. - Note that MySQL will not raise a fatal error if you attempt to store - an out of range value- an alternate value will be stored instead. - (See MySQL ENUM documentation.) - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - :param quoting: Defaults to 'auto': automatically determine enum value - quoting. If all enum values are surrounded by the same quoting - character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. - - 'quoted': values in enums are already quoted, they will be used - directly when generating the schema. - - 'unquoted': values in enums are not quoted, they will be escaped and - surrounded by single quotes when generating the schema. - - Previous versions of this type always required manually quoted - values to be supplied; future versions will always quote the string - literals for you. This is a transitional option. - - """ - self.quoting = kw.pop('quoting', 'auto') - - if self.quoting == 'auto': - # What quoting character are we using? - q = None - for e in enums: - if len(e) == 0: - self.quoting = 'unquoted' - break - elif q is None: - q = e[0] - - if e[0] != q or e[-1] != q: - self.quoting = 'unquoted' - break - else: - self.quoting = 'quoted' - - if self.quoting == 'quoted': - util.warn_pending_deprecation( - 'Manually quoting ENUM value literals is deprecated. Supply ' - 'unquoted values and use the quoting= option in cases of ' - 'ambiguity.') - strip_enums = [] - for a in enums: - if a[0:1] == '"' or a[0:1] == "'": - # strip enclosing quotes and unquote interior - a = a[1:-1].replace(a[0] * 2, a[0]) - strip_enums.append(a) - self.enums = strip_enums - else: - self.enums = list(enums) - - self.strict = kw.pop('strict', False) - length = max([len(v) for v in self.enums] + [0]) - super(MSEnum, self).__init__(length, **kw) - - def bind_processor(self, dialect): - super_convert = super(MSEnum, self).bind_processor(dialect) - def process(value): - if self.strict and value is not None and value not in self.enums: - raise exc.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) - if super_convert: - return super_convert(value) - else: - return value - return process - - def get_col_spec(self): - quoted_enums = [] - for e in self.enums: - quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend("ENUM(%s)" % ",".join(quoted_enums)) - -class MSSet(MSString): - """MySQL SET type.""" - - def __init__(self, *values, **kw): - """Construct a SET. - - Example:: - - Column('myset', MSSet("'foo'", "'bar'", "'baz'")) - - Arguments are: - - :param values: The range of valid values for this SET. Values will be - used exactly as they appear when generating schemas. Strings must - be quoted, as in the example above. Single-quotes are suggested for - ANSI compatibility and are required for portability to servers with - ANSI_QUOTES enabled. - - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. - - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. - - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. - - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. - - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. - - """ - self.__ddl_values = values - - strip_values = [] - for a in values: - if a[0:1] == '"' or a[0:1] == "'": - # strip enclosing quotes and unquote interior - a = a[1:-1].replace(a[0] * 2, a[0]) - strip_values.append(a) - - self.values = strip_values - length = max([len(v) for v in strip_values] + [0]) - super(MSSet, self).__init__(length, **kw) - - def result_processor(self, dialect): - def process(value): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, util.set_types): - # ..some versions convert '' to an empty set - if not value: - value.add('') - # ..some return sets.Set, even for pythons that have __builtin__.set - if not isinstance(value, set): - value = set(value) - return value - # ...and some versions return strings - if value is not None: - return set(value.split(',')) - else: - return value - return process - - def bind_processor(self, dialect): - super_convert = super(MSSet, self).bind_processor(dialect) - def process(value): - if value is None or isinstance(value, (int, long, basestring)): - pass - else: - if None in value: - value = set(value) - value.remove(None) - value.add('') - value = ','.join(value) - if super_convert: - return super_convert(value) - else: - return value - return process - - def get_col_spec(self): - return self._extend("SET(%s)" % ",".join(self.__ddl_values)) - - -class MSBoolean(sqltypes.Boolean): - """MySQL BOOLEAN type.""" - - def get_col_spec(self): - return "BOOL" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -colspecs = { - sqltypes.Integer: MSInteger, - sqltypes.Smallinteger: MSSmallInteger, - sqltypes.Numeric: MSNumeric, - sqltypes.Float: MSFloat, - sqltypes.DateTime: MSDateTime, - sqltypes.Date: MSDate, - sqltypes.Time: MSTime, - sqltypes.String: MSString, - sqltypes.Binary: MSBlob, - sqltypes.Boolean: MSBoolean, - sqltypes.Text: MSText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, - sqltypes.TIMESTAMP: MSTimeStamp, - sqltypes.BLOB: MSBlob, - MSDouble: MSDouble, - MSReal: MSReal, - _BinaryType: _BinaryType, -} - -# Everything 3.23 through 5.1 excepting OpenGIS types. -ischema_names = { - 'bigint': MSBigInteger, - 'binary': MSBinary, - 'bit': MSBit, - 'blob': MSBlob, - 'boolean':MSBoolean, - 'char': MSChar, - 'date': MSDate, - 'datetime': MSDateTime, - 'decimal': MSDecimal, - 'double': MSDouble, - 'enum': MSEnum, - 'fixed': MSDecimal, - 'float': MSFloat, - 'int': MSInteger, - 'integer': MSInteger, - 'longblob': MSLongBlob, - 'longtext': MSLongText, - 'mediumblob': MSMediumBlob, - 'mediumint': MSMediumInteger, - 'mediumtext': MSMediumText, - 'nchar': MSNChar, - 'nvarchar': MSNVarChar, - 'numeric': MSNumeric, - 'set': MSSet, - 'smallint': MSSmallInteger, - 'text': MSText, - 'time': MSTime, - 'timestamp': MSTimeStamp, - 'tinyblob': MSTinyBlob, - 'tinyint': MSTinyInteger, - 'tinytext': MSTinyText, - 'varbinary': MSVarBinary, - 'varchar': MSString, - 'year': MSYear, -} - - -class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.compiled.isinsert and not self.executemany: - if (not len(self._last_inserted_ids) or - self._last_inserted_ids[0] is None): - self._last_inserted_ids = ([self.cursor.lastrowid] + - self._last_inserted_ids[1:]) - elif (not self.isupdate and not self.should_autocommit and - self.statement and SET_RE.match(self.statement)): - # This misses if a user forces autocommit on text('SET NAMES'), - # which is probably a programming error anyhow. - self.connection.info.pop(('mysql', 'charset'), None) - - def should_autocommit_text(self, statement): - return AUTOCOMMIT_RE.match(statement) - - -class MySQLDialect(default.DefaultDialect): - """Details of the MySQL dialect. Not used directly in application code.""" - name = 'mysql' - supports_alter = True - supports_unicode_statements = False - # identifiers are 64, however aliases can be 255... - max_identifier_length = 255 - supports_sane_rowcount = True - default_paramstyle = 'format' - - def __init__(self, use_ansiquotes=None, **kwargs): - self.use_ansiquotes = use_ansiquotes - default.DefaultDialect.__init__(self, **kwargs) - - def dbapi(cls): - import MySQLdb as mysql - return mysql - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') - opts.update(url.query) - - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) - # Note: using either of the below will cause all strings to be returned - # as Unicode, both in raw SQL operations and with column types like - # String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) - - # Rich values 'cursorclass' and 'conv' are not supported via - # query string. - - ssl = {} - for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: - if key in opts: - ssl[key[4:]] = opts[key] - util.coerce_kw_type(ssl, key[4:], str) - del opts[key] - if ssl: - opts['ssl'] = ssl - - # FOUND_ROWS must be set in CLIENT_FLAGS to enable - # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) - if self.dbapi is not None: - try: - import MySQLdb.constants.CLIENT as CLIENT_FLAGS - client_flag |= CLIENT_FLAGS.FOUND_ROWS - except: - pass - opts['client_flag'] = client_flag - return [[], opts] - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def do_executemany(self, cursor, statement, parameters, context=None): - rowcount = cursor.executemany(statement, parameters) - if context is not None: - context._rowcount = rowcount - - def supports_unicode_statements(self): - return True - - def do_commit(self, connection): - """Execute a COMMIT.""" - - # COMMIT/ROLLBACK were introduced in 3.23.15. - # Yes, we have at least one user who has to talk to these old versions! - # - # Ignore commit/rollback if support isn't present, otherwise even basic - # operations via autocommit fail. - try: - connection.commit() - except: - if self._server_version_info(connection) < (3, 23, 15): - args = sys.exc_info()[1].args - if args and args[0] == 1064: - return - raise - - def do_rollback(self, connection): - """Execute a ROLLBACK.""" - - try: - connection.rollback() - except: - if self._server_version_info(connection) < (3, 23, 15): - args = sys.exc_info()[1].args - if args and args[0] == 1064: - return - raise - - def do_begin_twophase(self, connection, xid): - connection.execute("XA BEGIN %s", xid) - - def do_prepare_twophase(self, connection, xid): - connection.execute("XA END %s", xid) - connection.execute("XA PREPARE %s", xid) - - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): - if not is_prepared: - connection.execute("XA END %s", xid) - connection.execute("XA ROLLBACK %s", xid) - - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): - if not is_prepared: - self.do_prepare_twophase(connection, xid) - connection.execute("XA COMMIT %s", xid) - - def do_recover_twophase(self, connection): - resultset = connection.execute("XA RECOVER") - return [row['data'][0:row['gtrid_length']] for row in resultset] - - def do_ping(self, connection): - connection.ping() - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return e.args[0] in (2006, 2013, 2014, 2045, 2055) - elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get - return "(0, '')" in str(e) - else: - return False - - def get_default_schema_name(self, connection): - return connection.execute('SELECT DATABASE()').scalar() - get_default_schema_name = engine_base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def table_names(self, connection, schema): - """Return a Unicode SHOW TABLES from a given schema.""" - - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) - rp = connection.execute("SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) - return [row[0] for row in _compat_fetchall(rp, charset=charset)] - - def has_table(self, connection, table_name, schema=None): - # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly - # on macosx (and maybe win?) with multibyte table names. - # - # TODO: if this is not a problem on win, make the strategy swappable - # based on platform. DESCRIBE is slower. - - # [ticket:726] - # full_name = self.identifier_preparer.format_table(table, - # use_schema=True) - - self._autoset_identifier_style(connection) - - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) - - st = "DESCRIBE %s" % full_name - rs = None - try: - try: - rs = connection.execute(st) - have = rs.rowcount > 0 - rs.close() - return have - except exc.SQLError, e: - if e.orig.args[0] == 1146: - return False - raise - finally: - if rs: - rs.close() - - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(5, 0, 44)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - - return self._server_version_info(connection.connection.connection) - server_version_info = engine_base.connection_memoize( - ('mysql', 'server_version_info'))(server_version_info) - - def _server_version_info(self, dbapi_con): - """Convert a MySQL-python server_info string into a tuple.""" - - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.get_server_info()): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) - - def reflecttable(self, connection, table, include_columns): - """Load column definitions from the server.""" - - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) - - try: - reflector = self.reflector - except AttributeError: - preparer = self.identifier_preparer - if (self.server_version_info(connection) < (4, 1) and - self.use_ansiquotes): - # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 - preparer = MySQLIdentifierPreparer(self) - - self.reflector = reflector = MySQLSchemaReflector(preparer) - - sql = self._show_create_table(connection, table, charset) - if sql.startswith('CREATE ALGORITHM'): - # Adapt views to something table-like. - columns = self._describe_table(connection, table, charset) - sql = reflector._describe_to_create(table, columns) - - self._adjust_casing(connection, table) - - return reflector.reflect(connection, table, sql, charset, - only=include_columns) - - def _adjust_casing(self, connection, table, charset=None): - """Adjust Table name to the server case sensitivity, if needed.""" - - casing = self._detect_casing(connection) - - # For winxx database hosts. TODO: is this really needed? - if casing == 1 and table.name != table.name.lower(): - table.name = table.name.lower() - lc_alias = schema._get_table_key(table.name, table.schema) - table.metadata.tables[lc_alias] = table - - - def _detect_charset(self, connection): - """Sniff out the character set in use for connection results.""" - - # Allow user override, won't sniff if force_charset is set. - if ('mysql', 'force_charset') in connection.info: - return connection.info[('mysql', 'force_charset')] - - # Note: MySQL-python 1.2.1c7 seems to ignore changes made - # on a connection via set_character_set() - if self.server_version_info(connection) < (4, 1, 0): - try: - return connection.connection.character_set_name() - except AttributeError: - # < 1.2.1 final MySQL-python drivers have no charset support. - # a query is needed. - pass - - # Prefer 'character_set_results' for the current connection over the - # value in the driver. SET NAMES or individual variable SETs will - # change the charset without updating the driver's view of the world. - # - # If it's decided that issuing that sort of SQL leaves you SOL, then - # this can prefer the driver value. - rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") - opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) - - if 'character_set_results' in opts: - return opts['character_set_results'] - try: - return connection.connection.character_set_name() - except AttributeError: - # Still no charset on < 1.2.1 final... - if 'character_set' in opts: - return opts['character_set'] - else: - util.warn( - "Could not detect the connection character set with this " - "combination of MySQL server and MySQL-python. " - "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") - return 'latin1' - _detect_charset = engine_base.connection_memoize( - ('mysql', 'charset'))(_detect_charset) - - - def _detect_casing(self, connection): - """Sniff out identifier case sensitivity. - - Cached per-connection. This value can not change without a server - restart. - - """ - # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - - charset = self._detect_charset(connection) - row = _compat_fetchone(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) - if not row: - cs = 0 - else: - # 4.0.15 returns OFF or ON according to [ticket:489] - # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': - cs = 0 - elif row[1] == 'ON': - cs = 1 - else: - cs = int(row[1]) - row.close() - return cs - _detect_casing = engine_base.connection_memoize( - ('mysql', 'lower_case_table_names'))(_detect_casing) - - def _detect_collations(self, connection): - """Pull the active COLLATIONS list from the server. - - Cached per-connection. - """ - - collations = {} - if self.server_version_info(connection) < (4, 1, 0): - pass - else: - charset = self._detect_charset(connection) - rs = connection.execute('SHOW COLLATION') - for row in _compat_fetchall(rs, charset): - collations[row[0]] = row[1] - return collations - _detect_collations = engine_base.connection_memoize( - ('mysql', 'collations'))(_detect_collations) - - def use_ansiquotes(self, useansi): - self._use_ansiquotes = useansi - if useansi: - self.preparer = MySQLANSIIdentifierPreparer - else: - self.preparer = MySQLIdentifierPreparer - # icky - if hasattr(self, 'identifier_preparer'): - self.identifier_preparer = self.preparer(self) - if hasattr(self, 'reflector'): - del self.reflector - - use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes, - doc="True if ANSI_QUOTES is in effect.") - - def _autoset_identifier_style(self, connection, charset=None): - """Detect and adjust for the ANSI_QUOTES sql mode. - - If the dialect's use_ansiquotes is unset, query the server's sql mode - and reset the identifier style. - - Note that this currently *only* runs during reflection. Ideally this - would run the first time a connection pool connects to the database, - but the infrastructure for that is not yet in place. - """ - - if self.use_ansiquotes is not None: - return - - row = _compat_fetchone( - connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=charset) - if not row: - mode = '' - else: - mode = row[1] or '' - # 4.0 - if mode.isdigit(): - mode_no = int(mode) - mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' - - self.use_ansiquotes = 'ANSI_QUOTES' in mode - - def _show_create_table(self, connection, table, charset=None, - full_name=None): - """Run SHOW CREATE TABLE for a ``Table``.""" - - if full_name is None: - full_name = self.identifier_preparer.format_table(table) - st = "SHOW CREATE TABLE %s" % full_name - - rp = None - try: - try: - rp = connection.execute(st) - except exc.SQLError, e: - if e.orig.args[0] == 1146: - raise exc.NoSuchTableError(full_name) - else: - raise - row = _compat_fetchone(rp, charset=charset) - if not row: - raise exc.NoSuchTableError(full_name) - return row[1].strip() - finally: - if rp: - rp.close() - - return sql - - def _describe_table(self, connection, table, charset=None, - full_name=None): - """Run DESCRIBE for a ``Table`` and return processed rows.""" - - if full_name is None: - full_name = self.identifier_preparer.format_table(table) - st = "DESCRIBE %s" % full_name - - rp, rows = None, None - try: - try: - rp = connection.execute(st) - except exc.SQLError, e: - if e.orig.args[0] == 1146: - raise exc.NoSuchTableError(full_name) - else: - raise - rows = _compat_fetchall(rp, charset=charset) - finally: - if rp: - rp.close() - return rows - -class _MySQLPythonRowProxy(object): - """Return consistent column values for all versions of MySQL-python. - - Smooth over data type issues (esp. with alpha driver versions) and - normalize strings as Unicode regardless of user-configured driver - encoding settings. - """ - - # Some MySQL-python versions can return some columns as - # sets.Set(['value']) (seriously) but thankfully that doesn't - # seem to come up in DDL queries. - - def __init__(self, rowproxy, charset): - self.rowproxy = rowproxy - self.charset = charset - def __getitem__(self, index): - item = self.rowproxy[index] - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - def __getattr__(self, attr): - item = getattr(self.rowproxy, attr) - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - - -class MySQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update({ - sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), - sql_operators.mod: '%%', - sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) - }) - functions = compiler.DefaultCompiler.functions.copy() - functions.update ({ - sql_functions.random: 'rand%(expr)s', - "utc_timestamp":"UTC_TIMESTAMP" - }) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'milliseconds': 'millisecond', - }) - - def visit_typeclause(self, typeclause): - type_ = typeclause.type.dialect_impl(self.dialect) - if isinstance(type_, MSInteger): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' - else: - return 'SIGNED INTEGER' - elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - elif isinstance(type_, MSText): - return 'CHAR' - elif (isinstance(type_, _StringType) and not - isinstance(type_, (MSEnum, MSSet))): - if getattr(type_, 'length'): - return 'CHAR(%s)' % type_.length - else: - return 'CHAR' - elif isinstance(type_, _BinaryType): - return 'BINARY' - elif isinstance(type_, MSNumeric): - return type_.get_col_spec().replace('NUMERIC', 'DECIMAL') - elif isinstance(type_, MSTimeStamp): - return 'DATETIME' - elif isinstance(type_, (MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - else: - return None - - def visit_cast(self, cast, **kwargs): - # No cast until 4, no decimals until 5. - type_ = self.process(cast.typeclause) - if type_ is None: - return self.process(cast.clause) - - return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) - - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def get_select_precolumns(self, select): - if isinstance(select._distinct, basestring): - return select._distinct.upper() + " " - elif select._distinct: - return "DISTINCT " - else: - return "" - - def visit_join(self, join, asfrom=False, **kwargs): - # 'JOIN ... ON ...' for inner joins isn't available until 4.0. - # Apparently < 3.23.17 requires theta joins for inner joins - # (but not outer). Not generating these currently, but - # support can be added, preferably after dialects are - # refactored to be version-sensitive. - return ''.join( - (self.process(join.left, asfrom=True), - (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), - self.process(join.right, asfrom=True), - " ON ", - self.process(join.onclause))) - - def for_update_clause(self, select): - if select.for_update == 'read': - return ' LOCK IN SHARE MODE' - else: - return super(MySQLCompiler, self).for_update_clause(select) - - def limit_clause(self, select): - # MySQL supports: - # LIMIT <limit> - # LIMIT <offset>, <limit> - # and in server versions > 3.3: - # LIMIT <limit> OFFSET <offset> - # The latter is more readable for offsets but we're stuck with the - # former until we can refine dialects by server revision. - - limit, offset = select._limit, select._offset - - if (limit, offset) == (None, None): - return '' - elif offset is not None: - # As suggested by the MySQL docs, need to apply an - # artificial limit if one wasn't provided - if limit is None: - limit = 18446744073709551615 - return ' \n LIMIT %s, %s' % (offset, limit) - else: - # No offset provided, so just use the limit - return ' \n LIMIT %s' % (limit,) - - def visit_update(self, update_stmt): - self.stack.append({'from': set([update_stmt.table])}) - - self.isupdate = True - colparams = self._get_colparams(update_stmt) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) - - if update_stmt._whereclause: - text += " WHERE " + self.process(update_stmt._whereclause) - - limit = update_stmt.kwargs.get('mysql_limit', None) - if limit: - text += " LIMIT %s" % limit - - self.stack.pop(-1) - - return text - -# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. -# Starting with MySQL 4.1.2, these indexes are created automatically. -# In older versions, the indexes must be created explicitly or the -# creation of foreign key constraints fails." - -class MySQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - """Builds column DDL.""" - - colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] - - default = self.get_column_default_string(column) - if default is not None: - colspec.append('DEFAULT ' + default) - - if not column.nullable: - colspec.append('NOT NULL') - - if column.primary_key and column.autoincrement: - try: - first = [c for c in column.table.primary_key.columns - if (c.autoincrement and - isinstance(c.type, sqltypes.Integer) and - not c.foreign_keys)].pop(0) - if column is first: - colspec.append('AUTO_INCREMENT') - except IndexError: - pass - - return ' '.join(colspec) - - def post_create_table(self, table): - """Build table-level CREATE options like ENGINE and COLLATE.""" - - table_opts = [] - for k in table.kwargs: - if k.startswith('mysql_'): - opt = k[6:].upper() - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' - - table_opts.append(joiner.join((opt, table.kwargs[k]))) - return ' '.join(table_opts) - - -class MySQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s ON %s" % - (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), - self.preparer.format_table(index.table))) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % - (self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - -class MySQLSchemaReflector(object): - """Parses SHOW CREATE TABLE output.""" - - def __init__(self, identifier_preparer): - """Construct a MySQLSchemaReflector. - - identifier_preparer - An ANSIIdentifierPreparer type, used to determine the identifier - quoting style in effect. - """ - - self.preparer = identifier_preparer - self._prep_regexes() - - def reflect(self, connection, table, show_create, charset, only=None): - """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''. - - show_create - Unicode output of SHOW CREATE TABLE - - table - A ''Table'', to be loaded with Columns, Indexes, etc. - table.name will be set if not already - - charset - FIXME, some constructed values (like column defaults) - currently can't be Unicode. ''charset'' will convert them - into the connection character set. - - only - An optional sequence of column names. If provided, only - these columns will be reflected, and any keys or constraints - that include columns outside this set will also be omitted. - That means that if ``only`` includes only one column in a - 2 part primary key, the entire primary key will be omitted. - """ - - keys, constraints = [], [] - - if only: - only = set(only) - - for line in re.split(r'\r?\n', show_create): - if line.startswith(' ' + self.preparer.initial_quote): - self._add_column(table, line, charset, only) - # a regular table options line - elif line.startswith(') '): - self._set_options(table, line) - # an ANSI-mode table options line - elif line == ')': - pass - elif line.startswith('CREATE '): - self._set_name(table, line) - # Not present in real reflection, but may be if loading from a file. - elif not line: - pass - else: - type_, spec = self.parse_constraints(line) - if type_ is None: - util.warn("Unknown schema content: %r" % line) - elif type_ == 'key': - keys.append(spec) - elif type_ == 'constraint': - constraints.append(spec) - else: - pass - - self._set_keys(table, keys, only) - self._set_constraints(table, constraints, connection, only) - - def _set_name(self, table, line): - """Override a Table name with the reflected name. - - table - A ``Table`` - - line - The first line of SHOW CREATE TABLE output. - """ - - # Don't override by default. - if table.name is None: - table.name = self.parse_name(line) - - def _add_column(self, table, line, charset, only=None): - spec = self.parse_column(line) - if not spec: - util.warn("Unknown column definition %r" % line) - return - if not spec['full']: - util.warn("Incomplete reflection of column definition %r" % line) - - name, type_, args, notnull = \ - spec['name'], spec['coltype'], spec['arg'], spec['notnull'] - - if only and name not in only: - self.logger.info("Omitting reflected column %s.%s" % - (table.name, name)) - return - - # Convention says that TINYINT(1) columns == BOOLEAN - if type_ == 'tinyint' and args == '1': - type_ = 'boolean' - args = None - - try: - col_type = ischema_names[type_] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (type_, name)) - col_type = sqltypes.NullType - - # Column type positional arguments eg. varchar(32) - if args is None or args == '': - type_args = [] - elif args[0] == "'" and args[-1] == "'": - type_args = self._re_csv_str.findall(args) - else: - type_args = [int(v) for v in self._re_csv_int.findall(args)] - - # Column type keyword options - type_kw = {} - for kw in ('unsigned', 'zerofill'): - if spec.get(kw, False): - type_kw[kw] = True - for kw in ('charset', 'collate'): - if spec.get(kw, False): - type_kw[kw] = spec[kw] - - if type_ == 'enum': - type_kw['quoting'] = 'quoted' - - type_instance = col_type(*type_args, **type_kw) - - col_args, col_kw = [], {} - - # NOT NULL - if spec.get('notnull', False): - col_kw['nullable'] = False - - # AUTO_INCREMENT - if spec.get('autoincr', False): - col_kw['autoincrement'] = True - elif issubclass(col_type, sqltypes.Integer): - col_kw['autoincrement'] = False - - # DEFAULT - default = spec.get('default', None) - if default is not None and default != 'NULL': - # Defaults should be in the native charset for the moment - default = default.encode(charset) - if type_ == 'timestamp': - # can't be NULL for TIMESTAMPs - if (default[0], default[-1]) != ("'", "'"): - default = sql.text(default) - else: - default = default[1:-1] - col_args.append(schema.DefaultClause(default)) - - table.append_column(schema.Column(name, type_instance, - *col_args, **col_kw)) - - def _set_keys(self, table, keys, only): - """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``. - - Most of the information gets dropped here- more is reflected than - the schema objects can currently represent. - - table - A ``Table`` - - keys - A sequence of key specifications produced by `constraints` - - only - Optional `set` of column names. If provided, keys covering - columns not in this set will be omitted. - """ - - for spec in keys: - flavor = spec['type'] - col_names = [s[0] for s in spec['columns']] - - if only and not set(col_names).issubset(only): - if flavor is None: - flavor = 'index' - self.logger.info( - "Omitting %s KEY for (%s), key covers ommitted columns." % - (flavor, ', '.join(col_names))) - continue - - constraint = False - if flavor == 'PRIMARY': - key = schema.PrimaryKeyConstraint() - constraint = True - elif flavor == 'UNIQUE': - key = schema.Index(spec['name'], unique=True) - elif flavor in (None, 'FULLTEXT', 'SPATIAL'): - key = schema.Index(spec['name']) - else: - self.logger.info( - "Converting unknown KEY type %s to a plain KEY" % flavor) - key = schema.Index(spec['name']) - - for col in [table.c[name] for name in col_names]: - key.append_column(col) - - if constraint: - table.append_constraint(key) - - def _set_constraints(self, table, constraints, connection, only): - """Apply constraints to a ``Table``.""" - - default_schema = None - - for spec in constraints: - # only FOREIGN KEYs - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and spec['table'][-2] or table.schema - - if not ref_schema: - if default_schema is None: - default_schema = connection.dialect.get_default_schema_name( - connection) - if table.schema == default_schema: - ref_schema = table.schema - - loc_names = spec['local'] - if only and not set(loc_names).issubset(only): - self.logger.info( - "Omitting FOREIGN KEY for (%s), key covers ommitted " - "columns." % (', '.join(loc_names))) - continue - - ref_key = schema._get_table_key(ref_name, ref_schema) - if ref_key in table.metadata.tables: - ref_table = table.metadata.tables[ref_key] - else: - ref_table = schema.Table( - ref_name, table.metadata, schema=ref_schema, - autoload=True, autoload_with=connection) - - ref_names = spec['foreign'] - - if ref_schema: - refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names] - else: - refspec = [".".join([ref_name, column]) for column in ref_names] - - con_kw = {} - for opt in ('name', 'onupdate', 'ondelete'): - if spec.get(opt, False): - con_kw[opt] = spec[opt] - - key = schema.ForeignKeyConstraint(loc_names, refspec, link_to_name=True, **con_kw) - table.append_constraint(key) - - def _set_options(self, table, line): - """Apply safe reflected table options to a ``Table``. - - table - A ``Table`` - - line - The final line of SHOW CREATE TABLE output. - """ - - options = self.parse_table_options(line) - for nope in ('auto_increment', 'data_directory', 'index_directory'): - options.pop(nope, None) - - for opt, val in options.items(): - table.kwargs['mysql_%s' % opt] = val - - def _prep_regexes(self): - """Pre-compile regular expressions.""" - - self._re_columns = [] - self._pr_options = [] - self._re_options_util = {} - - _final = self.preparer.final_quote - - quotes = dict(zip(('iq', 'fq', 'esc_fq'), - [re.escape(s) for s in - (self.preparer.initial_quote, - _final, - self.preparer._escape_identifier(_final))])) - - self._pr_name = _pr_compile( - r'^CREATE (?:\w+ +)?TABLE +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, - self.preparer._unescape_identifier) - - # `col`,`col2`(32),`col3`(15) DESC - # - # Note: ASC and DESC aren't reflected, so we'll punt... - self._re_keyexprs = _re_compile( - r'(?:' - r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' - r'(?:\((\d+)\))?(?=\,|$))+' % quotes) - - # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' - self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') - - # 123 or 123,456 - self._re_csv_int = _re_compile(r'\d+') - - - # `colname` <type> [type opts] - # (NOT NULL | NULL) - # DEFAULT ('value' | CURRENT_TIMESTAMP...) - # COMMENT 'comment' - # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) - # STORAGE (DISK|MEMORY) - self._re_column = _re_compile( - r' ' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'(?P<coltype>\w+)' - r'(?:\((?P<arg>(?:\d+|\d+,\d+|' - r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' - r'(?: +(?P<unsigned>UNSIGNED))?' - r'(?: +(?P<zerofill>ZEROFILL))?' - r'(?: +CHARACTER SET +(?P<charset>\w+))?' - r'(?: +COLLATE +(P<collate>\w+))?' - r'(?: +(?P<notnull>NOT NULL))?' - r'(?: +DEFAULT +(?P<default>' - r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+)' - r'(?:ON UPDATE \w+)?' - r'))?' - r'(?: +(?P<autoincr>AUTO_INCREMENT))?' - r'(?: +COMMENT +(P<comment>(?:\x27\x27|[^\x27])+))?' - r'(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?' - r'(?: +STORAGE +(?P<storage>\w+))?' - r'(?: +(?P<extra>.*))?' - r',?$' - % quotes - ) - - # Fallback, try to parse as little as possible - self._re_column_loose = _re_compile( - r' ' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'(?P<coltype>\w+)' - r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' - r'.*?(?P<notnull>NOT NULL)?' - % quotes - ) - - # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? - # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) - # KEY_BLOCK_SIZE size | WITH PARSER name - self._re_key = _re_compile( - r' ' - r'(?:(?P<type>\S+) )?KEY' - r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' - r'(?: +USING +(?P<using_pre>\S+))?' - r' +\((?P<columns>.+?)\)' - r'(?: +USING +(?P<using_post>\S+))?' - r'(?: +KEY_BLOCK_SIZE +(?P<keyblock>\S+))?' - r'(?: +WITH PARSER +(?P<parser>\S+))?' - r',?$' - % quotes - ) - - # CONSTRAINT `name` FOREIGN KEY (`local_col`) - # REFERENCES `remote` (`remote_col`) - # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE - # ON DELETE CASCADE ON UPDATE RESTRICT - # - # unique constraints come back as KEYs - kw = quotes.copy() - kw['on'] = 'RESTRICT|CASCASDE|SET NULL|NOACTION' - self._re_constraint = _re_compile( - r' ' - r'CONSTRAINT +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'FOREIGN KEY +' - r'\((?P<local>[^\)]+?)\) REFERENCES +' - r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' - r'\((?P<foreign>[^\)]+?)\)' - r'(?: +(?P<match>MATCH \w+))?' - r'(?: +ON DELETE (?P<ondelete>%(on)s))?' - r'(?: +ON UPDATE (?P<onupdate>%(on)s))?' - % kw - ) - - # PARTITION - # - # punt! - self._re_partition = _re_compile( - r' ' - r'(?:SUB)?PARTITION') - - # Table-level options (COLLATE, ENGINE, etc.) - for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', - 'AVG_ROW_LENGTH', 'CHARACTER SET', - 'DEFAULT CHARSET', 'CHECKSUM', - 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', - 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', - 'KEY_BLOCK_SIZE'): - self._add_option_word(option) - - for option in (('COMMENT', 'DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'PASSWORD', 'CONNECTION')): - self._add_option_string(option) - - self._add_option_regex('UNION', r'\([^\)]+\)') - self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') - self._add_option_regex('RAID_TYPE', - r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') - self._re_options_util['='] = _re_compile(r'\s*=\s*$') - - def _add_option_string(self, directive): - regex = (r'(?P<directive>%s\s*(?:=\s*)?)' - r'(?:\x27.(?P<val>.*?)\x27(?!\x27)\x27)' % - re.escape(directive)) - self._pr_options.append( - _pr_compile(regex, lambda v: v.replace("''", "'"))) - - def _add_option_word(self, directive): - regex = (r'(?P<directive>%s\s*(?:=\s*)?)' - r'(?P<val>\w+)' % re.escape(directive)) - self._pr_options.append(_pr_compile(regex)) - - def _add_option_regex(self, directive, regex): - regex = (r'(?P<directive>%s\s*(?:=\s*)?)' - r'(?P<val>%s)' % (re.escape(directive), regex)) - self._pr_options.append(_pr_compile(regex)) - - - def parse_name(self, line): - """Extract the table name. - - line - The first line of SHOW CREATE TABLE - """ - - regex, cleanup = self._pr_name - m = regex.match(line) - if not m: - return None - return cleanup(m.group('name')) - - def parse_column(self, line): - """Extract column details. - - Falls back to a 'minimal support' variant if full parse fails. - - line - Any column-bearing line from SHOW CREATE TABLE - """ - - m = self._re_column.match(line) - if m: - spec = m.groupdict() - spec['full'] = True - return spec - m = self._re_column_loose.match(line) - if m: - spec = m.groupdict() - spec['full'] = False - return spec - return None - - def parse_constraints(self, line): - """Parse a KEY or CONSTRAINT line. - - line - A line of SHOW CREATE TABLE output - """ - - # KEY - m = self._re_key.match(line) - if m: - spec = m.groupdict() - # convert columns into name, length pairs - spec['columns'] = self._parse_keyexprs(spec['columns']) - return 'key', spec - - # CONSTRAINT - m = self._re_constraint.match(line) - if m: - spec = m.groupdict() - spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) - spec['local'] = [c[0] - for c in self._parse_keyexprs(spec['local'])] - spec['foreign'] = [c[0] - for c in self._parse_keyexprs(spec['foreign'])] - return 'constraint', spec - - # PARTITION and SUBPARTITION - m = self._re_partition.match(line) - if m: - # Punt! - return 'partition', line - - # No match. - return (None, line) - - def parse_table_options(self, line): - """Build a dictionary of all reflected table-level options. - - line - The final line of SHOW CREATE TABLE output. - """ - - options = {} - - if not line or line == ')': - return options - - r_eq_trim = self._re_options_util['='] - - for regex, cleanup in self._pr_options: - m = regex.search(line) - if not m: - continue - directive, value = m.group('directive'), m.group('val') - directive = r_eq_trim.sub('', directive).lower() - if cleanup: - value = cleanup(value) - options[directive] = value - - return options - - def _describe_to_create(self, table, columns): - """Re-format DESCRIBE output as a SHOW CREATE TABLE string. - - DESCRIBE is a much simpler reflection and is sufficient for - reflecting views for runtime use. This method formats DDL - for columns only- keys are omitted. - - `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. - SHOW FULL COLUMNS FROM rows must be rearranged for use with - this function. - """ - - buffer = [] - for row in columns: - (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] - - line = [' '] - line.append(self.preparer.quote_identifier(name)) - line.append(col_type) - if not nullable: - line.append('NOT NULL') - if default: - if 'auto_increment' in default: - pass - elif (col_type.startswith('timestamp') and - default.startswith('C')): - line.append('DEFAULT') - line.append(default) - elif default == 'NULL': - line.append('DEFAULT') - line.append(default) - else: - line.append('DEFAULT') - line.append("'%s'" % default.replace("'", "''")) - if extra: - line.append(extra) - - buffer.append(' '.join(line)) - - return ''.join([('CREATE TABLE %s (\n' % - self.preparer.quote_identifier(table.name)), - ',\n'.join(buffer), - '\n) ']) - - def _parse_keyexprs(self, identifiers): - """Unpack '"col"(2),"col" ASC'-ish strings into components.""" - - return self._re_keyexprs.findall(identifiers) - -log.class_logger(MySQLSchemaReflector) - - -class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): - """MySQL-specific schema identifier configuration.""" - - reserved_words = RESERVED_WORDS - - def __init__(self, dialect, **kw): - super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw) - - def _quote_free_identifiers(self, *ids): - """Unilaterally identifier-quote any number of strings.""" - - return tuple([self.quote_identifier(i) for i in ids if i is not None]) - - -class MySQLIdentifierPreparer(_MySQLIdentifierPreparer): - """Traditional MySQL-specific schema identifier configuration.""" - - def __init__(self, dialect): - super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`") - - def _escape_identifier(self, value): - return value.replace('`', '``') - - def _unescape_identifier(self, value): - return value.replace('``', '`') - - -class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer): - """ANSI_QUOTES MySQL schema identifier configuration.""" - - pass - - -def _compat_fetchall(rp, charset=None): - """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" - - return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] - -def _compat_fetchone(rp, charset=None): - """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" - - return _MySQLPythonRowProxy(rp.fetchone(), charset) - -def _pr_compile(regex, cleanup=None): - """Prepare a 2-tuple of compiled regex and callable.""" - - return (_re_compile(regex), cleanup) - -def _re_compile(regex): - """Compile a string to regex, I and UNICODE.""" - - return re.compile(regex, re.I | re.UNICODE) - -dialect = MySQLDialect -dialect.statement_compiler = MySQLCompiler -dialect.schemagenerator = MySQLSchemaGenerator -dialect.schemadropper = MySQLSchemaDropper -dialect.execution_ctx_cls = MySQLExecutionContext diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py deleted file mode 100644 index 852cab448..000000000 --- a/lib/sqlalchemy/databases/oracle.py +++ /dev/null @@ -1,904 +0,0 @@ -# oracle.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the Oracle database. - -Oracle version 8 through current (11g at the time of this writing) are supported. - -Driver ------- - -The Oracle dialect uses the cx_oracle driver, available at -http://cx-oracle.sourceforge.net/ . The dialect has several behaviors -which are specifically tailored towards compatibility with this module. - -Connecting ----------- - -Connecting with create_engine() uses the standard URL approach of -``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the -host, port, and dbname tokens are converted to a TNS name using the cx_oracle -:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. - -Additional arguments which may be specified either as query string arguments on the -URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: - -* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. - -* *auto_convert_lobs* - defaults to True, see the section on LOB objects. - -* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. - This is required for LOB datatypes but can be disabled to reduce overhead. Defaults - to ``True``. - -* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an - integer value. This value is only available as a URL query string argument. - -* *threaded* - enable multithreaded access to cx_oracle connections. Defaults - to ``True``. Note that this is the opposite default of cx_oracle itself. - -* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults - to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. - -* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET. - -Auto Increment Behavior ------------------------ - -SQLAlchemy Table objects which include integer primary keys are usually assumed to have -"autoincrementing" behavior, meaning they can generate their own primary key values upon -INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences -to produce these values. With the Oracle dialect, *a sequence must always be explicitly -specified to enable autoincrement*. This is divergent with the majority of documentation -examples which assume the usage of an autoincrement-capable database. To specify sequences, -use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: - - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq'), primary_key=True), - Column(...), ... - ) - -This step is also required when using table reflection, i.e. autoload=True:: - - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq'), primary_key=True), - autoload=True - ) - -LOB Objects ------------ - -cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set -is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, -SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, -the LOB object requires an active cursor association, meaning if you were to fetch many rows -at once such that cx_oracle had to go back to the database and fetch a new batch of rows, -the LOB objects in the already-fetched rows are now unreadable and will raise an error. -SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. -The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy -defaults to 50 (cx_oracle normally defaults this to one). - -Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to -"normalize" the results to look more like other DBAPIs. - -The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place -for all statement executions, even plain string-based statements for which SQLA has no awareness -of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases -without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" -of LOB objects, can be disabled using auto_convert_lobs=False. - -LIMIT/OFFSET Support --------------------- - -Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy -used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses -a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from -http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the -"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt -this was stepping into the bounds of optimization that is better left on the DBA side, but this -prefix can be added by enabling the optimize_limits=True flag on create_engine(). - -Two Phase Transaction Support ------------------------------ - -Two Phase transactions are implemented using XA transactions. Success has been reported of them -working successfully but this should be regarded as an experimental feature. - -Oracle 8 Compatibility ----------------------- - -When using Oracle 8, a "use_ansi=False" flag is available which converts all -JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN -makes use of Oracle's (+) operator. - -Synonym/DBLINK Reflection -------------------------- - -When using reflection with Table objects, the dialect can optionally search for tables -indicated by synonyms that reference DBLINK-ed tables by passing the flag -oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK -is not in use this flag should be left off. - -""" - -import datetime, random, re - -from sqlalchemy import util, sql, schema, log -from sqlalchemy.engine import default, base -from sqlalchemy.sql import compiler, visitors, expression -from sqlalchemy.sql import operators as sql_operators, functions as sql_functions -from sqlalchemy import types as sqltypes - - -class OracleNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class OracleInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class OracleSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class OracleDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - def process(value): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() - return process - -class OracleDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "DATE" - - def result_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year, value.month, - value.day,value.hour, value.minute, value.second) - return process - -# Note: -# Oracle DATE == DATETIME -# Oracle does not allow milliseconds in DATE -# Oracle does not support TIME columns - -# only if cx_oracle contains TIMESTAMP -class OracleTimestamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - - def get_dbapi_type(self, dialect): - return dialect.TIMESTAMP - - def result_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year, value.month, - value.day,value.hour, value.minute, value.second) - return process - -class OracleString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class OracleNVarchar(sqltypes.Unicode, OracleString): - def get_col_spec(self): - return "NVARCHAR2(%(length)s)" % {'length' : self.length} - -class OracleText(sqltypes.Text): - def get_dbapi_type(self, dbapi): - return dbapi.CLOB - - def get_col_spec(self): - return "CLOB" - - def result_processor(self, dialect): - super_process = super(OracleText, self).result_processor(dialect) - if not dialect.auto_convert_lobs: - return super_process - lob = dialect.dbapi.LOB - def process(value): - if isinstance(value, lob): - if super_process: - return super_process(value.read()) - else: - return value.read() - else: - if super_process: - return super_process(value) - else: - return value - return process - - -class OracleChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class OracleBinary(sqltypes.Binary): - def get_dbapi_type(self, dbapi): - return dbapi.BLOB - - def get_col_spec(self): - return "BLOB" - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if not dialect.auto_convert_lobs: - return None - lob = dialect.dbapi.LOB - def process(value): - if isinstance(value, lob): - return value.read() - else: - return value - return process - -class OracleRaw(OracleBinary): - def get_col_spec(self): - return "RAW(%(length)s)" % {'length' : self.length} - -class OracleBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "SMALLINT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -colspecs = { - sqltypes.Integer : OracleInteger, - sqltypes.Smallinteger : OracleSmallInteger, - sqltypes.Numeric : OracleNumeric, - sqltypes.Float : OracleNumeric, - sqltypes.DateTime : OracleDateTime, - sqltypes.Date : OracleDate, - sqltypes.String : OracleString, - sqltypes.Binary : OracleBinary, - sqltypes.Boolean : OracleBoolean, - sqltypes.Text : OracleText, - sqltypes.TIMESTAMP : OracleTimestamp, - sqltypes.CHAR: OracleChar, -} - -ischema_names = { - 'VARCHAR2' : OracleString, - 'NVARCHAR2' : OracleNVarchar, - 'CHAR' : OracleString, - 'DATE' : OracleDateTime, - 'DATETIME' : OracleDateTime, - 'NUMBER' : OracleNumeric, - 'BLOB' : OracleBinary, - 'BFILE' : OracleBinary, - 'CLOB' : OracleText, - 'TIMESTAMP' : OracleTimestamp, - 'RAW' : OracleRaw, - 'FLOAT' : OracleNumeric, - 'DOUBLE PRECISION' : OracleNumeric, - 'LONG' : OracleText, -} - -class OracleExecutionContext(default.DefaultExecutionContext): - def pre_exec(self): - super(OracleExecutionContext, self).pre_exec() - if self.dialect.auto_setinputsizes: - self.set_input_sizes() - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for key in self.compiled.binds: - bindparam = self.compiled.binds[key] - name = self.compiled.bind_names[bindparam] - value = self.compiled_parameters[0][name] - if bindparam.isoutparam: - dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if not hasattr(self, 'out_parameters'): - self.out_parameters = {} - self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[0][name] = self.out_parameters[name] - - def create_cursor(self): - c = self._connection.connection.cursor() - if self.dialect.arraysize: - c.cursor.arraysize = self.dialect.arraysize - return c - - def get_result_proxy(self): - if hasattr(self, 'out_parameters'): - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for bind, name in self.compiled.bind_names.iteritems(): - if name in self.out_parameters: - type = bind.type - result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) - if result_processor is not None: - self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) - else: - self.out_parameters[name] = self.out_parameters[name].getvalue() - else: - for k in self.out_parameters: - self.out_parameters[k] = self.out_parameters[k].getvalue() - - if self.cursor.description is not None: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) - - return base.ResultProxy(self) - -class OracleDialect(default.DefaultDialect): - name = 'oracle' - supports_alter = True - supports_unicode_statements = False - max_identifier_length = 30 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - supports_pk_autoincrement = False - default_paramstyle = 'named' - - def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - self.use_ansi = use_ansi - self.threaded = threaded - self.arraysize = arraysize - self.allow_twophase = allow_twophase - self.optimize_limits = optimize_limits - self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) - self.auto_setinputsizes = auto_setinputsizes - self.auto_convert_lobs = auto_convert_lobs - if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: - self.dbapi_type_map = {} - self.ORACLE_BINARY_TYPES = [] - else: - # only use this for LOB objects. using it for strings, dates - # etc. leads to a little too much magic, reflection doesn't know if it should - # expect encoded strings or unicodes, etc. - self.dbapi_type_map = { - self.dbapi.CLOB: OracleText(), - self.dbapi.BLOB: OracleBinary(), - self.dbapi.BINARY: OracleRaw(), - } - self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] - - def dbapi(cls): - import cx_Oracle - return cx_Oracle - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - dialect_opts = dict(url.query) - for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', - 'threaded', 'allow_twophase'): - if opt in dialect_opts: - util.coerce_kw_type(dialect_opts, opt, bool) - setattr(self, opt, dialect_opts[opt]) - - if url.database: - # if we have a database, then we have a remote host - port = url.port - if port: - port = int(port) - else: - port = 1521 - dsn = self.dbapi.makedsn(url.host, port, url.database) - else: - # we have a local tnsname - dsn = url.host - - opts = dict( - user=url.username, - password=url.password, - dsn=dsn, - threaded=self.threaded, - twophase=self.allow_twophase, - ) - if 'mode' in url.query: - opts['mode'] = url.query['mode'] - if isinstance(opts['mode'], basestring): - mode = opts['mode'].upper() - if mode == 'SYSDBA': - opts['mode'] = self.dbapi.SYSDBA - elif mode == 'SYSOPER': - opts['mode'] = self.dbapi.SYSOPER - else: - util.coerce_kw_type(opts, 'mode', int) - # Can't set 'handle' or 'pool' via URL query args, use connect_args - - return ([], opts) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.InterfaceError): - return "not connected" in str(e) - else: - return "ORA-03114" in str(e) or "ORA-03113" in str(e) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def create_xid(self): - """create a two-phase transaction ID. - - this id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). its format is unspecified.""" - - id = random.randint(0, 2 ** 128) - return (0x1234, "%032x" % id, "%032x" % 9) - - def do_release_savepoint(self, connection, name): - # Oracle does not support RELEASE SAVEPOINT - pass - - def do_begin_twophase(self, connection, xid): - connection.connection.begin(*xid) - - def do_prepare_twophase(self, connection, xid): - connection.connection.prepare() - - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_rollback(connection.connection) - - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_commit(connection.connection) - - def do_recover_twophase(self, connection): - pass - - def has_table(self, connection, table_name, schema=None): - if not schema: - schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)}) - return cursor.fetchone() is not None - - def has_sequence(self, connection, sequence_name, schema=None): - if not schema: - schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)}) - return cursor.fetchone() is not None - - def _normalize_name(self, name): - if name is None: - return None - elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)): - return name.lower().decode(self.encoding) - else: - return name.decode(self.encoding) - - def _denormalize_name(self, name): - if name is None: - return None - elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.upper().encode(self.encoding) - else: - return name.encode(self.encoding) - - def get_default_schema_name(self, connection): - return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def table_names(self, connection, schema): - # note that table_names() isnt loading DBLINKed or synonym'ed tables - if schema is None: - s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')" - cursor = connection.execute(s) - else: - s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner" - cursor = connection.execute(s, {'owner': self._denormalize_name(schema)}) - return [self._normalize_name(row[0]) for row in cursor] - - def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): - """search for a local synonym matching the given desired owner/name. - - if desired_owner is None, attempts to locate a distinct owner. - - returns the actual name, owner, dblink name, and synonym name if found. - """ - - sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME - from ALL_SYNONYMS WHERE """ - - clauses = [] - params = {} - if desired_synonym: - clauses.append("SYNONYM_NAME=:synonym_name") - params['synonym_name'] = desired_synonym - if desired_owner: - clauses.append("TABLE_OWNER=:desired_owner") - params['desired_owner'] = desired_owner - if desired_table: - clauses.append("TABLE_NAME=:tname") - params['tname'] = desired_table - - sql += " AND ".join(clauses) - - result = connection.execute(sql, **params) - if desired_owner: - row = result.fetchone() - if row: - return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] - else: - return None, None, None, None - else: - rows = result.fetchall() - if len(rows) > 1: - raise AssertionError("There are multiple tables visible to the schema, you must specify owner") - elif len(rows) == 1: - row = rows[0] - return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] - else: - return None, None, None, None - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - - resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) - - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) - else: - actual_name, owner, dblink, synonym = None, None, None, None - - if not actual_name: - actual_name = self._denormalize_name(table.name) - if not dblink: - dblink = '' - if not owner: - owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection)) - - c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner}) - - while True: - row = c.fetchone() - if row is None: - break - - (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) - - if include_columns and colname not in include_columns: - continue - - # INTEGER if the scale is 0 and precision is null - # NUMBER if the scale and precision are both null - # NUMBER(9,2) if the precision is 9 and the scale is 2 - # NUMBER(3) if the precision is 3 and scale is 0 - #length is ignored except for CHAR and VARCHAR2 - if coltype == 'NUMBER' : - if precision is None and scale is None: - coltype = OracleNumeric - elif precision is None and scale == 0 : - coltype = OracleInteger - else : - coltype = OracleNumeric(precision, scale) - elif coltype=='CHAR' or coltype=='VARCHAR2': - coltype = ischema_names.get(coltype, OracleString)(length) - else: - coltype = re.sub(r'\(\d+\)', '', coltype) - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, colname)) - coltype = sqltypes.NULLTYPE - - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) - - if not table.columns: - raise AssertionError("Couldn't find any column information for table %s" % actual_name) - - c = connection.execute("""SELECT - ac.constraint_name, - ac.constraint_type, - loc.column_name AS local_column, - rem.table_name AS remote_table, - rem.column_name AS remote_column, - rem.owner AS remote_owner - FROM all_constraints%(dblink)s ac, - all_cons_columns%(dblink)s loc, - all_cons_columns%(dblink)s rem - WHERE ac.table_name = :table_name - AND ac.constraint_type IN ('R','P') - AND ac.owner = :owner - AND ac.owner = loc.owner - AND ac.constraint_name = loc.constraint_name - AND ac.r_owner = rem.owner(+) - AND ac.r_constraint_name = rem.constraint_name(+) - -- order multiple primary keys correctly - ORDER BY ac.constraint_name, loc.position, rem.position""" - % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner}) - - fks = {} - while True: - row = c.fetchone() - if row is None: - break - #print "ROW:" , row - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) - if cons_type == 'P': - table.primary_key.add(table.c[local_column]) - elif cons_type == 'R': - try: - fk = fks[cons_name] - except KeyError: - fk = ([], []) - fks[cons_name] = fk - if remote_table is None: - # ticket 363 - util.warn( - ("Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?") % {'dblink':dblink}) - continue - - if resolve_synonyms: - ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table)) - if ref_synonym: - remote_table = self._normalize_name(ref_synonym) - remote_owner = self._normalize_name(ref_remote_owner) - - if not table.schema and self._denormalize_name(remote_owner) == owner: - refspec = ".".join([remote_table, remote_column]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - else: - refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - - if local_column not in fk[0]: - fk[0].append(local_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True)) - - -class _OuterJoinColumn(sql.ClauseElement): - __visit_name__ = 'outer_join_column' - - def __init__(self, column): - self.column = column - -class OracleCompiler(compiler.DefaultCompiler): - """Oracle compiler modifies the lexical structure of Select - statements to work under non-ANSI configured Oracle databases, if - the use_ansi flag is False. - """ - - operators = compiler.DefaultCompiler.operators.copy() - operators.update( - { - sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - } - ) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now : 'CURRENT_TIMESTAMP' - } - ) - - def __init__(self, *args, **kwargs): - super(OracleCompiler, self).__init__(*args, **kwargs) - self.__wheres = {} - - def default_from(self): - """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. - - The Oracle compiler tacks a "FROM DUAL" to the statement. - """ - - return " FROM DUAL" - - def apply_function_parens(self, func): - return len(func.clauses) > 0 - - def visit_join(self, join, **kwargs): - if self.dialect.use_ansi: - return compiler.DefaultCompiler.visit_join(self, join, **kwargs) - else: - return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) - - def _get_nonansi_join_whereclause(self, froms): - clauses = [] - - def visit_join(join): - if join.isouter: - def visit_binary(binary): - if binary.operator == sql_operators.eq: - if binary.left.table is join.right: - binary.left = _OuterJoinColumn(binary.left) - elif binary.right.table is join.right: - binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) - else: - clauses.append(join.onclause) - - for f in froms: - visitors.traverse(f, {}, {'join':visit_join}) - return sql.and_(*clauses) - - def visit_outer_join_column(self, vc): - return self.process(vc.column) + "(+)" - - def visit_sequence(self, seq): - return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - - def visit_alias(self, alias, asfrom=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - if asfrom: - alias_name = isinstance(alias.name, expression._generated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name - - return self.process(alias.original, asfrom=True, **kwargs) + " " +\ - self.preparer.format_alias(alias, alias_name) - else: - return self.process(alias.original, **kwargs) - - def _TODO_visit_compound_select(self, select): - """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" - pass - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``rownum`` criterion. - """ - - if not getattr(select, '_oracle_visit', None): - if not self.dialect.use_ansi: - if self.stack and 'from' in self.stack[-1]: - existingfroms = self.stack[-1]['from'] - else: - existingfroms = None - - froms = select._get_display_froms(existingfroms) - whereclause = self._get_nonansi_join_whereclause(froms) - if whereclause: - select = select.where(whereclause) - select._oracle_visit = True - - if select._limit is not None or select._offset is not None: - # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html - # - # Generalized form of an Oracle pagination query: - # select ... from ( - # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( - # select distinct ... where ... order by ... - # ) where ROWNUM <= :limit+:offset - # ) where ora_rn > :offset - # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 - - # TODO: use annotations instead of clone + attr set ? - select = select._generate() - select._oracle_visit = True - - # Wrap the middle select and add the hint - limitselect = sql.select([c for c in select.c]) - if select._limit and self.dialect.optimize_limits: - limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) - - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - # If needed, add the limiting clause - if select._limit is not None: - max_row = select._limit - if select._offset is not None: - max_row += select._offset - limitselect.append_whereclause( - sql.literal_column("ROWNUM")<=max_row) - - # If needed, add the ora_rn, and wrap again with offset. - if select._offset is None: - select = limitselect - else: - limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - offsetselect = sql.select( - [c for c in limitselect.c if c.key!='ora_rn']) - offsetselect._oracle_visit = True - offsetselect._is_wrapper = True - - offsetselect.append_whereclause( - sql.literal_column("ora_rn")>select._offset) - - select = offsetselect - - kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) - - def limit_clause(self, select): - return "" - - def for_update_clause(self, select): - if select.for_update == "nowait": - return " FOR UPDATE NOWAIT" - else: - return super(OracleCompiler, self).for_update_clause(select) - - -class OracleSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleDefaultRunner(base.DefaultRunner): - def visit_sequence(self, seq): - return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) - -class OracleIdentifierPreparer(compiler.IdentifierPreparer): - def format_savepoint(self, savepoint): - name = re.sub(r'^_+', '', savepoint.ident) - return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) - - -dialect = OracleDialect -dialect.statement_compiler = OracleCompiler -dialect.schemagenerator = OracleSchemaGenerator -dialect.schemadropper = OracleSchemaDropper -dialect.preparer = OracleIdentifierPreparer -dialect.defaultrunner = OracleDefaultRunner -dialect.execution_ctx_cls = OracleExecutionContext diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py deleted file mode 100644 index 154d971e3..000000000 --- a/lib/sqlalchemy/databases/postgres.py +++ /dev/null @@ -1,889 +0,0 @@ -# postgres.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the PostgreSQL database. - -Driver ------- - -The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . -The dialect has several behaviors which are specifically tailored towards compatibility -with this module. - -Note that psycopg1 is **not** supported. - -Connecting ----------- - -URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`. - -PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: - -* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support - this feature. What this essentially means from a psycopg2 point of view is that the cursor is - created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows - are not immediately pre-fetched and buffered after statement execution, but are instead left - on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` - uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows - at a time are fetched over the wire to reduce conversational overhead. - -Sequences/SERIAL ----------------- - -PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating -new primary key values for integer-based primary key columns. When creating tables, -SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, -which generates a sequence corresponding to the column and associated with it based on -a naming convention. - -To specify a specific named sequence to be used for primary key generation, use the -:func:`~sqlalchemy.schema.Sequence` construct:: - - Table('sometable', metadata, - Column('id', Integer, Sequence('some_id_seq'), primary_key=True) - ) - -Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of -having the "last insert identifier" available, the sequence is executed independently -beforehand and the new value is retrieved, to be used in the subsequent insert. Note -that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using -"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior -is used. - -PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports -as well. A future release of SQLA will use this feature by default in lieu of -sequence pre-execution in order to retrieve new primary key values, when available. - -INSERT/UPDATE...RETURNING -------------------------- - -The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, -but must be explicitly enabled on a per-statement basis:: - - # INSERT..RETURNING - result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\ - values(name='foo') - print result.fetchall() - - # UPDATE..RETURNING - result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\ - where(table.c.name=='foo').values(name='bar') - print result.fetchall() - -Indexes -------- - -PostgreSQL supports partial indexes. To create them pass a postgres_where -option to the Index constructor:: - - Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) - -Transactions ------------- - -The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations. - - -""" - -import decimal, random, re, string - -from sqlalchemy import sql, schema, exc, util -from sqlalchemy.engine import base, default -from sqlalchemy.sql import compiler, expression -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import types as sqltypes - - -class PGInet(sqltypes.TypeEngine): - def get_col_spec(self): - return "INET" - -class PGCidr(sqltypes.TypeEngine): - def get_col_spec(self): - return "CIDR" - -class PGMacAddr(sqltypes.TypeEngine): - def get_col_spec(self): - return "MACADDR" - -class PGNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if self.asdecimal: - return None - else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - -class PGFloat(sqltypes.Float): - def get_col_spec(self): - if not self.precision: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class PGInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class PGSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class PGBigInteger(PGInteger): - def get_col_spec(self): - return "BIGINT" - -class PGDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PGDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - -class PGTime(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PGInterval(sqltypes.TypeEngine): - def get_col_spec(self): - return "INTERVAL" - -class PGText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class PGString(sqltypes.String): - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)d)" % {'length' : self.length} - else: - return "VARCHAR" - -class PGChar(sqltypes.CHAR): - def get_col_spec(self): - if self.length: - return "CHAR(%(length)d)" % {'length' : self.length} - else: - return "CHAR" - -class PGBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTEA" - -class PGBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" - -class PGBit(sqltypes.TypeEngine): - def get_col_spec(self): - return "BIT" - -class PGUuid(sqltypes.TypeEngine): - def get_col_spec(self): - return "UUID" - -class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): - def __init__(self, item_type, mutable=True): - if isinstance(item_type, type): - item_type = item_type() - self.item_type = item_type - self.mutable = mutable - - def copy_value(self, value): - if value is None: - return None - elif self.mutable: - return list(value) - else: - return value - - def compare_values(self, x, y): - return x == y - - def is_mutable(self): - return self.mutable - - def dialect_impl(self, dialect, **kwargs): - impl = self.__class__.__new__(self.__class__) - impl.__dict__.update(self.__dict__) - impl.item_type = self.item_type.dialect_impl(dialect) - return impl - - def bind_processor(self, dialect): - item_proc = self.item_type.bind_processor(dialect) - def process(value): - if value is None: - return value - def convert_item(item): - if isinstance(item, (list, tuple)): - return [convert_item(child) for child in item] - else: - if item_proc: - return item_proc(item) - else: - return item - return [convert_item(item) for item in value] - return process - - def result_processor(self, dialect): - item_proc = self.item_type.result_processor(dialect) - def process(value): - if value is None: - return value - def convert_item(item): - if isinstance(item, list): - return [convert_item(child) for child in item] - else: - if item_proc: - return item_proc(item) - else: - return item - return [convert_item(item) for item in value] - return process - def get_col_spec(self): - return self.item_type.get_col_spec() + '[]' - -colspecs = { - sqltypes.Integer : PGInteger, - sqltypes.Smallinteger : PGSmallInteger, - sqltypes.Numeric : PGNumeric, - sqltypes.Float : PGFloat, - sqltypes.DateTime : PGDateTime, - sqltypes.Date : PGDate, - sqltypes.Time : PGTime, - sqltypes.String : PGString, - sqltypes.Binary : PGBinary, - sqltypes.Boolean : PGBoolean, - sqltypes.Text : PGText, - sqltypes.CHAR: PGChar, -} - -ischema_names = { - 'integer' : PGInteger, - 'bigint' : PGBigInteger, - 'smallint' : PGSmallInteger, - 'character varying' : PGString, - 'character' : PGChar, - '"char"' : PGChar, - 'name': PGChar, - 'text' : PGText, - 'numeric' : PGNumeric, - 'float' : PGFloat, - 'real' : PGFloat, - 'inet': PGInet, - 'cidr': PGCidr, - 'uuid':PGUuid, - 'bit':PGBit, - 'macaddr': PGMacAddr, - 'double precision' : PGFloat, - 'timestamp' : PGDateTime, - 'timestamp with time zone' : PGDateTime, - 'timestamp without time zone' : PGDateTime, - 'time with time zone' : PGTime, - 'time without time zone' : PGTime, - 'date' : PGDate, - 'time': PGTime, - 'bytea' : PGBinary, - 'boolean' : PGBoolean, - 'interval':PGInterval, -} - -# TODO: filter out 'FOR UPDATE' statements -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) - -class PGExecutionContext(default.DefaultExecutionContext): - def create_cursor(self): - # TODO: coverage for server side cursors + select.for_update() - is_server_side = \ - self.dialect.server_side_cursors and \ - ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) - and not getattr(self.compiled.statement, 'for_update', False)) \ - or \ - ( - (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) - ) - - self.__is_server_side = is_server_side - if is_server_side: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) - return self._connection.connection.cursor(ident) - else: - return self._connection.connection.cursor() - - def get_result_proxy(self): - if self.__is_server_side: - return base.BufferedRowResultProxy(self) - else: - return base.ResultProxy(self) - -class PGDialect(default.DefaultDialect): - name = 'postgres' - supports_alter = True - supports_unicode_statements = False - max_identifier_length = 63 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - supports_pk_autoincrement = False - default_paramstyle = 'pyformat' - supports_default_values = True - supports_empty_insert = False - - def __init__(self, server_side_cursors=False, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - self.server_side_cursors = server_side_cursors - - def dbapi(cls): - import psycopg2 as psycopg - return psycopg - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) - opts.update(url.query) - return ([], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def do_begin_twophase(self, connection, xid): - self.do_begin(connection.connection) - - def do_prepare_twophase(self, connection, xid): - connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)])) - - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): - if is_prepared: - if recover: - #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions - # Must find out a way how to make the dbapi not open a transaction. - connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) - connection.execute(sql.text("BEGIN")) - self.do_rollback(connection.connection) - else: - self.do_rollback(connection.connection) - - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): - if is_prepared: - if recover: - connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) - connection.execute(sql.text("BEGIN")) - self.do_rollback(connection.connection) - else: - self.do_commit(connection.connection) - - def do_recover_twophase(self, connection): - resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) - return [row[0] for row in resultset] - - def get_default_schema_name(self, connection): - return connection.scalar("select current_schema()", None) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def last_inserted_ids(self): - if self.context.last_inserted_ids is None: - raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled") - else: - return self.context.last_inserted_ids - - def has_table(self, connection, table_name, schema=None): - # seems like case gets folded in pg_class... - if schema is None: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)}); - else: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema}); - return bool( not not cursor.rowcount ) - - def has_sequence(self, connection, sequence_name): - cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)}) - return bool(not not cursor.rowcount) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) or 'cursor already closed' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - # yes, it really says "losed", not "closed" - return "losed the connection unexpectedly" in str(e) - else: - return False - - def table_names(self, connection, schema): - s = """ - SELECT relname - FROM pg_class c - WHERE relkind = 'r' - AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) - """ % locals() - return [row[0].decode(self.encoding) for row in connection.execute(s)] - - def server_version_info(self, connection): - v = connection.execute("select version()").scalar() - m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) - if not m: - raise AssertionError("Could not determine version from string '%s'" % v) - return tuple([int(x) for x in m.group(1, 2, 3)]) - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is not None: - schema_where_clause = "n.nspname = :schema" - schemaname = table.schema - if isinstance(schemaname, str): - schemaname = schemaname.decode(self.encoding) - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - schemaname = None - - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) - AS DEFAULT, - a.attnotnull, a.attnum, a.attrelid as table_oid - FROM pg_catalog.pg_attribute a - WHERE a.attrelid = ( - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') - ) AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % schema_where_clause - - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}) - tablename = table.name - if isinstance(tablename, str): - tablename = tablename.decode(self.encoding) - c = connection.execute(s, table_name=tablename, schema=schemaname) - rows = c.fetchall() - - if not rows: - raise exc.NoSuchTableError(table.name) - - domains = self._load_domains(connection) - - for name, format_type, default, notnull, attnum, table_oid in rows: - if include_columns and name not in include_columns: - continue - - ## strip (30) from character varying(30) - attype = re.search('([^\([]+)', format_type).group(1) - nullable = not notnull - is_array = format_type.endswith('[]') - - try: - charlen = re.search('\(([\d,]+)\)', format_type).group(1) - except: - charlen = False - - numericprec = False - numericscale = False - if attype == 'numeric': - if charlen is False: - numericprec, numericscale = (None, None) - else: - numericprec, numericscale = charlen.split(',') - charlen = False - if attype == 'double precision': - numericprec, numericscale = (53, False) - charlen = False - if attype == 'integer': - numericprec, numericscale = (32, 0) - charlen = False - - args = [] - for a in (charlen, numericprec, numericscale): - if a is None: - args.append(None) - elif a is not False: - args.append(int(a)) - - kwargs = {} - if attype == 'timestamp with time zone': - kwargs['timezone'] = True - elif attype == 'timestamp without time zone': - kwargs['timezone'] = False - - coltype = None - if attype in ischema_names: - coltype = ischema_names[attype] - else: - if attype in domains: - domain = domains[attype] - if domain['attype'] in ischema_names: - # A table can't override whether the domain is nullable. - nullable = domain['nullable'] - - if domain['default'] and not default: - # It can, however, override the default value, but can't set it to null. - default = domain['default'] - coltype = ischema_names[domain['attype']] - - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = PGArray(coltype) - else: - util.warn("Did not recognize type '%s' of column '%s'" % - (attype, name)) - coltype = sqltypes.NULLTYPE - - colargs = [] - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - # the default is related to a Sequence - sch = table.schema - if '.' not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / "quote schema" - default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - - - # Primary keys - PK_SQL = """ - SELECT attname FROM pg_attribute - WHERE attrelid = ( - SELECT indexrelid FROM pg_index i - WHERE i.indrelid = :table - AND i.indisprimary = 't') - ORDER BY attnum - """ - t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - for row in c.fetchall(): - pk = row[0] - if pk in table.c: - col = table.c[pk] - table.primary_key.add(col) - if col.default is None: - col.autoincrement = False - - # Foreign keys - FK_SQL = """ - SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table AND r.contype = 'f' - ORDER BY 1 - """ - - t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - for conname, condef in c.fetchall(): - m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() - (constrained_columns, referred_schema, referred_table, referred_columns) = m - constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] - if referred_schema: - referred_schema = preparer._unquote_identifier(referred_schema) - elif table.schema is not None and table.schema == self.get_default_schema_name(connection): - # no schema (i.e. its the default schema), and the table we're - # reflecting has the default schema explicit, then use that. - # i.e. try to use the user's conventions - referred_schema = table.schema - referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] - - refspec = [] - if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_schema, referred_table, column])) - else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_table, column])) - - table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) - - # Indexes - IDX_SQL = """ - SELECT c.relname, i.indisunique, i.indexprs, i.indpred, - a.attname - FROM pg_index i, pg_class c, pg_attribute a - WHERE i.indrelid = :table AND i.indexrelid = c.oid - AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' - ORDER BY c.relname, a.attnum - """ - t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - indexes = {} - sv_idx_name = None - for row in c.fetchall(): - idx_name, unique, expr, prd, col = row - - if expr: - if not idx_name == sv_idx_name: - util.warn( - "Skipped unsupported reflection of expression-based index %s" - % idx_name) - sv_idx_name = idx_name - continue - if prd and not idx_name == sv_idx_name: - util.warn( - "Predicate of partial index %s ignored during reflection" - % idx_name) - sv_idx_name = idx_name - - if not indexes.has_key(idx_name): - indexes[idx_name] = [unique, []] - indexes[idx_name][1].append(col) - - for name, (unique, columns) in indexes.items(): - schema.Index(name, *[table.columns[c] for c in columns], - **dict(unique=unique)) - - - - def _load_domains(self, connection): - ## Load data types for domains: - SQL_DOMAINS = """ - SELECT t.typname as "name", - pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", - not t.typnotnull as "nullable", - t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid - WHERE t.typtype = 'd' - """ - - s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(s) - - domains = {} - for domain in c.fetchall(): - ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', domain['attype']).group(1) - if domain['visible']: - # 'visible' just means whether or not the domain is in a - # schema that's on the search path -- or not overriden by - # a schema with higher presedence. If it's not visible, - # it will be prefixed with the schema-name when it's used. - name = domain['name'] - else: - name = "%s.%s" % (domain['schema'], domain['name']) - - domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']} - - return domains - - -class PGCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update( - { - sql_operators.mod : '%%', - sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), - } - ) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - 'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x), - } - ) - - def visit_sequence(self, seq): - if seq.optional: - return None - else: - return "nextval('%s')" % self.preparer.format_sequence(seq) - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT ALL" - text += " OFFSET " + str(select._offset) - return text - - def get_select_precolumns(self, select): - if select._distinct: - if isinstance(select._distinct, bool): - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): - return "DISTINCT ON (" + ', '.join( - [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] - )+ ") " - else: - return "DISTINCT ON (" + unicode(select._distinct) + ") " - else: - return "" - - def for_update_clause(self, select): - if select.for_update == 'nowait': - return " FOR UPDATE NOWAIT" - else: - return super(PGCompiler, self).for_update_clause(select) - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs['postgres_returning'] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + string.join(columns, ', ') - return text - - def visit_update(self, update_stmt): - text = super(PGCompiler, self).visit_update(update_stmt) - if 'postgres_returning' in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(PGCompiler, self).visit_insert(insert_stmt) - if 'postgres_returning' in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_extract(self, extract, **kwargs): - field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s::timestamp)" % ( - field, self.process(extract.expr)) - - -class PGSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - if isinstance(column.type, PGBigInteger): - colspec += " BIGSERIAL" - else: - colspec += " SERIAL" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - string.join([preparer.format_column(c) for c in index.columns], ', '))) - whereclause = index.kwargs.get('postgres_where', None) - if whereclause is not None: - compiler = self._compile(whereclause, None) - # this might belong to the compiler class - inlined_clause = str(compiler) % dict( - [(key,bind.value) for key,bind in compiler.binds.iteritems()]) - self.append(" WHERE " + inlined_clause) - self.execute() - -class PGSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class PGDefaultRunner(base.DefaultRunner): - def __init__(self, context): - base.DefaultRunner.__init__(self, context) - # craete cursor which won't conflict with a server-side cursor - self.cursor = context._connection.connection.cursor() - - def get_column_default(self, column, isinsert=True): - if column.primary_key: - # pre-execute passive defaults on primary keys - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): - return self.execute_string("select %s" % column.server_default.arg) - elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - sch = column.table.schema - # TODO: this has to build into the Sequence object so we can get the quoting - # logic from it - if sch is not None: - exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) - else: - exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.execute_string(exc.encode(self.dialect.encoding)) - - return super(PGDefaultRunner, self).get_column_default(column) - - def visit_sequence(self, seq): - if not seq.optional: - return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) - else: - return None - -class PGIdentifierPreparer(compiler.IdentifierPreparer): - def _unquote_identifier(self, value): - if value[0] == self.initial_quote: - value = value[1:-1].replace('""','"') - return value - -dialect = PGDialect -dialect.statement_compiler = PGCompiler -dialect.schemagenerator = PGSchemaGenerator -dialect.schemadropper = PGSchemaDropper -dialect.preparer = PGIdentifierPreparer -dialect.defaultrunner = PGDefaultRunner -dialect.execution_ctx_cls = PGExecutionContext diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py deleted file mode 100644 index 8952b2b1d..000000000 --- a/lib/sqlalchemy/databases/sqlite.py +++ /dev/null @@ -1,646 +0,0 @@ -# sqlite.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the SQLite database. - -Driver ------- - -When using Python 2.5 and above, the built in ``sqlite3`` driver is -already installed and no additional installation is needed. Otherwise, -the ``pysqlite2`` driver needs to be present. This is the same driver as -``sqlite3``, just with a different name. - -The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` -is loaded. This allows an explicitly installed pysqlite driver to take -precedence over the built in one. As with all dialects, a specific -DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control -this explicitly:: - - from sqlite3 import dbapi2 as sqlite - e = create_engine('sqlite:///file.db', module=sqlite) - -Full documentation on pysqlite is available at: -`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_ - -Connect Strings ---------------- - -The file specification for the SQLite database is taken as the "database" portion of -the URL. Note that the format of a url is:: - - driver://user:pass@host/database - -This means that the actual filename to be used starts with the characters to the -**right** of the third slash. So connecting to a relative filepath looks like:: - - # relative path - e = create_engine('sqlite:///path/to/database.db') - -An absolute path, which is denoted by starting with a slash, means you need **four** -slashes:: - - # absolute path - e = create_engine('sqlite:////path/to/database.db') - -To use a Windows path, regular drive specifications and backslashes can be used. -Double backslashes are probably needed:: - - # absolute path on Windows - e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') - -The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify -``sqlite://`` and nothing else:: - - # in-memory database - e = create_engine('sqlite://') - -Threading Behavior ------------------- - -Pysqlite connections do not support being moved between threads, unless -the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, -when using an in-memory SQLite database, the full database exists only within -the scope of a single connection. It is reported that an in-memory -database does not support being shared between threads regardless of the -``check_same_thread`` flag - which means that a multithreaded -application **cannot** share data from a ``:memory:`` database across threads -unless access to the connection is limited to a single worker thread which communicates -through a queueing mechanism to concurrent threads. - -To provide a default which accomodates SQLite's default threading capabilities -somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` -be used by default. This pool maintains a single SQLite connection per thread -that is held open up to a count of five concurrent threads. When more than five threads -are used, a cleanup mechanism will dispose of excess unused connections. - -Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: - - * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded - application using an in-memory database, assuming the threading issues inherent in - pysqlite are somehow accomodated for. This pool holds persistently onto a single connection - which is never closed, and is returned for all requests. - - * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that - makes use of a file-based sqlite database. This pool disables any actual "pooling" - behavior, and simply opens and closes real connections corresonding to the :func:`connect()` - and :func:`close()` methods. SQLite can "connect" to a particular file with very high - efficiency, so this option may actually perform better without the extra overhead - of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection - useless since the database would be lost as soon as the connection is "returned" to the pool. - -Date and Time Types -------------------- - -SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide -out of the box functionality for translating values between Python `datetime` objects -and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` -and related types provide date formatting and parsing functionality when SQlite is used. -The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`. -These types represent dates and times as ISO formatted strings, which also nicely -support ordering. There's no reliance on typical "libc" internals for these functions -so historical dates are fully supported. - -Unicode -------- - -In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's -default behavior regarding Unicode is that all strings are returned as Python unicode objects -in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is -*not* used, you will still always receive unicode data back from a result set. It is -**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type -to represent strings, since it will raise a warning if a non-unicode Python string is -passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can -quickly create confusion, particularly when using the ORM as internal data is not -always represented by an actual database result string. - -""" - - -import datetime, re, time - -from sqlalchemy import sql, schema, exc, pool, DefaultClause -from sqlalchemy.engine import default -import sqlalchemy.types as sqltypes -import sqlalchemy.util as util -from sqlalchemy.sql import compiler, functions as sql_functions -from types import NoneType - -class SLNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SLFloat(sqltypes.Float): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - return "FLOAT" - -class SLInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SLSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class DateTimeMixin(object): - def _bind_processor(self, format, elements): - def process(value): - if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): - raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") - elif value is not None: - return format % tuple([getattr(value, attr, 0) for attr in elements]) - else: - return None - return process - - def _result_processor(self, fn, regexp): - def process(value): - if value is not None: - return fn(*[int(x or 0) for x in regexp.match(value).groups()]) - else: - return None - return process - -class SLDateTime(DateTimeMixin, sqltypes.DateTime): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIMESTAMP" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") - def result_processor(self, dialect): - return self._result_processor(datetime.datetime, self._reg) - -class SLDate(DateTimeMixin, sqltypes.Date): - def get_col_spec(self): - return "DATE" - - def bind_processor(self, dialect): - return self._bind_processor( - "%4.4d-%2.2d-%2.2d", - ("year", "month", "day") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)") - def result_processor(self, dialect): - return self._result_processor(datetime.date, self._reg) - -class SLTime(DateTimeMixin, sqltypes.Time): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIME" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%s", - ("hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%06d", - ("hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - def result_processor(self, dialect): - return self._result_processor(datetime.time, self._reg) - -class SLUnicodeMixin(object): - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - def result_processor(self, dialect): - return None - -class SLText(SLUnicodeMixin, sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SLString(SLUnicodeMixin, sqltypes.String): - def get_col_spec(self): - return "VARCHAR" + (self.length and "(%d)" % self.length or "") - -class SLChar(SLUnicodeMixin, sqltypes.CHAR): - def get_col_spec(self): - return "CHAR" + (self.length and "(%d)" % self.length or "") - -class SLBinary(sqltypes.Binary): - def get_col_spec(self): - return "BLOB" - -class SLBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value and 1 or 0 - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value == 1 - return process - -colspecs = { - sqltypes.Binary: SLBinary, - sqltypes.Boolean: SLBoolean, - sqltypes.CHAR: SLChar, - sqltypes.Date: SLDate, - sqltypes.DateTime: SLDateTime, - sqltypes.Float: SLFloat, - sqltypes.Integer: SLInteger, - sqltypes.NCHAR: SLChar, - sqltypes.Numeric: SLNumeric, - sqltypes.Smallinteger: SLSmallInteger, - sqltypes.String: SLString, - sqltypes.Text: SLText, - sqltypes.Time: SLTime, -} - -ischema_names = { - 'BLOB': SLBinary, - 'BOOL': SLBoolean, - 'BOOLEAN': SLBoolean, - 'CHAR': SLChar, - 'DATE': SLDate, - 'DATETIME': SLDateTime, - 'DECIMAL': SLNumeric, - 'FLOAT': SLFloat, - 'INT': SLInteger, - 'INTEGER': SLInteger, - 'NUMERIC': SLNumeric, - 'REAL': SLNumeric, - 'SMALLINT': SLSmallInteger, - 'TEXT': SLText, - 'TIME': SLTime, - 'TIMESTAMP': SLDateTime, - 'VARCHAR': SLString, -} - -class SQLiteExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.compiled.isinsert and not self.executemany: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - -class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' - supports_alter = False - supports_unicode_statements = True - default_paramstyle = 'qmark' - supports_default_values = True - supports_empty_insert = False - - def __init__(self, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - def vers(num): - return tuple([int(x) for x in num.split('.')]) - if self.dbapi is not None: - sqlite_ver = self.dbapi.version_info - if sqlite_ver < (2, 1, '3'): - util.warn( - ("The installed version of pysqlite2 (%s) is out-dated " - "and will cause errors in some cases. Version 2.1.3 " - "or greater is recommended.") % - '.'.join([str(subver) for subver in sqlite_ver])) - if self.dbapi.sqlite_version_info < (3, 3, 8): - self.supports_default_values = False - self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) - - def dbapi(cls): - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError, e: - try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: - raise e - return sqlite - dbapi = classmethod(dbapi) - - def server_version_info(self, connection): - return self.dbapi.sqlite_version_info - - def create_connect_args(self, url): - if url.username or url.password or url.host or url.port: - raise exc.ArgumentError( - "Invalid SQLite URL: %s\n" - "Valid SQLite URL forms are:\n" - " sqlite:///:memory: (or, sqlite://)\n" - " sqlite:///relative/path/to/file.db\n" - " sqlite:////absolute/path/to/file.db" % (url,)) - filename = url.database or ':memory:' - - opts = url.query.copy() - util.coerce_kw_type(opts, 'timeout', float) - util.coerce_kw_type(opts, 'isolation_level', str) - util.coerce_kw_type(opts, 'detect_types', int) - util.coerce_kw_type(opts, 'check_same_thread', bool) - util.coerce_kw_type(opts, 'cached_statements', int) - - return ([filename], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) - - def table_names(self, connection, schema): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) - rs = connection.execute(s) - else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - raise - s = ("SELECT name FROM sqlite_master " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - - return [row[0] for row in rs] - - def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) - - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while cursor.fetchone() is not None: - pass - - return (row is not None) - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is None: - pragma = "PRAGMA " - else: - pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema) - qtable = preparer.format_table(table, False) - - c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - - found_table = True - (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) - name = re.sub(r'^\"|\"$', '', name) - if include_columns and name not in include_columns: - continue - match = re.match(r'(\w+)(\(.*?\))?', type_) - if match: - coltype = match.group(1) - args = match.group(2) - else: - coltype = "VARCHAR" - args = '' - - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, name)) - coltype = sqltypes.NullType - - if args is not None: - args = re.findall(r'(\d+)', args) - coltype = coltype(*[int(a) for a in args]) - - colargs = [] - if has_default: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))) - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4]) - tablename = re.sub(r'^\"|\"$', '', tablename) - localcol = re.sub(r'^\"|\"$', '', localcol) - remotecol = re.sub(r'^\"|\"$', '', remotecol) - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - - # look up the table based on the given table's engine, not 'self', - # since it could be a ProxyEngine - remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) - constrained_column = table.c[localcol].name - refspec = ".".join([tablename, remotecol]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True)) - # check for UNIQUE indexes - c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) - unique_indexes = [] - while True: - row = c.fetchone() - if row is None: - break - if (row[2] == 1): - unique_indexes.append(row[1]) - # loop thru unique indexes for one that includes the primary key - for idx in unique_indexes: - c = connection.execute("%sindex_info(%s)" % (pragma, idx)) - cols = [] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) - -def _pragma_cursor(cursor): - if cursor.closed: - cursor._fetchone_impl = lambda: None - return cursor - -class SQLiteCompiler(compiler.DefaultCompiler): - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.char_length: 'length%(expr)s' - } - ) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update({ - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W' - }) - - def visit_cast(self, cast, **kwargs): - if self.dialect.supports_cast: - return super(SQLiteCompiler, self).visit_cast(cast) - else: - return self.process(cast.clause) - - def visit_extract(self, extract): - try: - return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( - self.extract_map[extract.field], self.process(extract.expr)) - except KeyError: - raise exc.ArgumentError( - "%s is not a valid extract argument." % extract.field) - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT -1" - text += " OFFSET " + str(select._offset) - else: - text += " OFFSET 0" - return text - - def for_update_clause(self, select): - # sqlite has no "FOR UPDATE" AFAICT - return '' - - -class SQLiteSchemaGenerator(compiler.SchemaGenerator): - - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - -class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', - 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', - 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', - 'plan', 'pragma', 'primary', 'query', 'raise', 'references', - 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', - 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', - 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', - 'vacuum', 'values', 'view', 'virtual', 'when', 'where', 'indexed', - ]) - - def __init__(self, dialect): - super(SQLiteIdentifierPreparer, self).__init__(dialect) - -dialect = SQLiteDialect -dialect.poolclass = pool.SingletonThreadPool -dialect.statement_compiler = SQLiteCompiler -dialect.schemagenerator = SQLiteSchemaGenerator -dialect.preparer = SQLiteIdentifierPreparer -dialect.execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py deleted file mode 100644 index f5b48e147..000000000 --- a/lib/sqlalchemy/databases/sybase.py +++ /dev/null @@ -1,875 +0,0 @@ -# sybase.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -Sybase database backend. - -Known issues / TODO: - - * Uses the mx.ODBC driver from egenix (version 2.1.0) - * The current version of sqlalchemy.databases.sybase only supports - mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need - some development) - * Support for pyodbc has been built in but is not yet complete (needs - further development) - * Results of running tests/alltests.py: - Ran 934 tests in 287.032s - FAILED (failures=3, errors=1) - * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) -""" - -import datetime, operator - -from sqlalchemy import util, sql, schema, exc -from sqlalchemy.sql import compiler, expression -from sqlalchemy.engine import default, base -from sqlalchemy import types as sqltypes -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import MetaData, Table, Column -from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey - - -__all__ = [ - 'SybaseTypeError' - 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger', - 'SybaseTinyInteger', 'SybaseSmallInteger', - 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc', - 'SybaseDate_mxodbc', 'SybaseDate_pyodbc', - 'SybaseTime_mxodbc', 'SybaseTime_pyodbc', - 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary', - 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney', - 'SybaseUniqueIdentifier', - ] - - -RESERVED_WORDS = set([ - "add", "all", "alter", "and", - "any", "as", "asc", "backup", - "begin", "between", "bigint", "binary", - "bit", "bottom", "break", "by", - "call", "capability", "cascade", "case", - "cast", "char", "char_convert", "character", - "check", "checkpoint", "close", "comment", - "commit", "connect", "constraint", "contains", - "continue", "convert", "create", "cross", - "cube", "current", "current_timestamp", "current_user", - "cursor", "date", "dbspace", "deallocate", - "dec", "decimal", "declare", "default", - "delete", "deleting", "desc", "distinct", - "do", "double", "drop", "dynamic", - "else", "elseif", "encrypted", "end", - "endif", "escape", "except", "exception", - "exec", "execute", "existing", "exists", - "externlogin", "fetch", "first", "float", - "for", "force", "foreign", "forward", - "from", "full", "goto", "grant", - "group", "having", "holdlock", "identified", - "if", "in", "index", "index_lparen", - "inner", "inout", "insensitive", "insert", - "inserting", "install", "instead", "int", - "integer", "integrated", "intersect", "into", - "iq", "is", "isolation", "join", - "key", "lateral", "left", "like", - "lock", "login", "long", "match", - "membership", "message", "mode", "modify", - "natural", "new", "no", "noholdlock", - "not", "notify", "null", "numeric", - "of", "off", "on", "open", - "option", "options", "or", "order", - "others", "out", "outer", "over", - "passthrough", "precision", "prepare", "primary", - "print", "privileges", "proc", "procedure", - "publication", "raiserror", "readtext", "real", - "reference", "references", "release", "remote", - "remove", "rename", "reorganize", "resource", - "restore", "restrict", "return", "revoke", - "right", "rollback", "rollup", "save", - "savepoint", "scroll", "select", "sensitive", - "session", "set", "setuser", "share", - "smallint", "some", "sqlcode", "sqlstate", - "start", "stop", "subtrans", "subtransaction", - "synchronize", "syntax_error", "table", "temporary", - "then", "time", "timestamp", "tinyint", - "to", "top", "tran", "trigger", - "truncate", "tsequal", "unbounded", "union", - "unique", "unknown", "unsigned", "update", - "updating", "user", "using", "validate", - "values", "varbinary", "varchar", "variable", - "varying", "view", "wait", "waitfor", - "when", "where", "while", "window", - "with", "with_cube", "with_lparen", "with_rollup", - "within", "work", "writetext", - ]) - -ischema = MetaData() - -tables = Table("SYSTABLE", ischema, - Column("table_id", Integer, primary_key=True), - Column("file_id", SMALLINT), - Column("table_name", CHAR(128)), - Column("table_type", CHAR(10)), - Column("creator", Integer), - #schema="information_schema" - ) - -domains = Table("SYSDOMAIN", ischema, - Column("domain_id", Integer, primary_key=True), - Column("domain_name", CHAR(128)), - Column("type_id", SMALLINT), - Column("precision", SMALLINT, quote=True), - #schema="information_schema" - ) - -columns = Table("SYSCOLUMN", ischema, - Column("column_id", Integer, primary_key=True), - Column("table_id", Integer, ForeignKey(tables.c.table_id)), - Column("pkey", CHAR(1)), - Column("column_name", CHAR(128)), - Column("nulls", CHAR(1)), - Column("width", SMALLINT), - Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), - # FIXME: should be mx.BIGINT - Column("max_identity", Integer), - # FIXME: should be mx.ODBC.Windows.LONGVARCHAR - Column("default", String), - Column("scale", Integer), - #schema="information_schema" - ) - -foreignkeys = Table("SYSFOREIGNKEY", ischema, - Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, primary_key=True), - Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), - #schema="information_schema" - ) -fkcols = Table("SYSFKCOL", ischema, - Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), - Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), - Column("primary_column_id", Integer), - #schema="information_schema" - ) - -class SybaseTypeError(sqltypes.TypeEngine): - def result_processor(self, dialect): - return None - - def bind_processor(self, dialect): - def process(value): - raise exc.InvalidRequestError("Data type not supported", [value]) - return process - - def get_col_spec(self): - raise exc.CompileError("Data type not supported") - -class SybaseNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.scale is None: - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s)" % {'precision' : self.precision} - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SybaseFloat(sqltypes.FLOAT, SybaseNumeric): - def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs): - super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs) - self.scale = scale - - def get_col_spec(self): - # if asdecimal is True, handle same way as SybaseNumeric - if self.asdecimal: - return SybaseNumeric.get_col_spec(self) - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return float(value) - if self.asdecimal: - return SybaseNumeric.result_processor(self, dialect) - return process - -class SybaseInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SybaseBigInteger(SybaseInteger): - def get_col_spec(self): - return "BIGINT" - -class SybaseTinyInteger(SybaseInteger): - def get_col_spec(self): - return "TINYINT" - -class SybaseSmallInteger(SybaseInteger): - def get_col_spec(self): - return "SMALLINT" - -class SybaseDateTime_mxodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - -class SybaseDateTime_pyodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return value - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value - return process - -class SybaseDate_mxodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseDate_pyodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseTime_mxodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseTime_pyodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SybaseString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class SybaseChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class SybaseBinary(sqltypes.Binary): - def get_col_spec(self): - return "IMAGE" - -class SybaseBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -class SybaseTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - -class SybaseMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" - -class SybaseSmallMoney(SybaseMoney): - def get_col_spec(self): - return "SMALLMONEY" - -class SybaseUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" - -class SybaseSQLExecutionContext(default.DefaultExecutionContext): - pass - -class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext): - - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_mxodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_mxodbc, self).post_exec() - -class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_pyodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_pyodbc, self).post_exec() - -class SybaseSQLDialect(default.DefaultDialect): - colspecs = { - # FIXME: unicode support - #sqltypes.Unicode : SybaseUnicode, - sqltypes.Integer : SybaseInteger, - sqltypes.SmallInteger : SybaseSmallInteger, - sqltypes.Numeric : SybaseNumeric, - sqltypes.Float : SybaseFloat, - sqltypes.String : SybaseString, - sqltypes.Binary : SybaseBinary, - sqltypes.Boolean : SybaseBoolean, - sqltypes.Text : SybaseText, - sqltypes.CHAR : SybaseChar, - sqltypes.TIMESTAMP : SybaseTimeStamp, - sqltypes.FLOAT : SybaseFloat, - } - - ischema_names = { - 'integer' : SybaseInteger, - 'unsigned int' : SybaseInteger, - 'unsigned smallint' : SybaseInteger, - 'unsigned bigint' : SybaseInteger, - 'bigint': SybaseBigInteger, - 'smallint' : SybaseSmallInteger, - 'tinyint' : SybaseTinyInteger, - 'varchar' : SybaseString, - 'long varchar' : SybaseText, - 'char' : SybaseChar, - 'decimal' : SybaseNumeric, - 'numeric' : SybaseNumeric, - 'float' : SybaseFloat, - 'double' : SybaseFloat, - 'binary' : SybaseBinary, - 'long binary' : SybaseBinary, - 'varbinary' : SybaseBinary, - 'bit': SybaseBoolean, - 'image' : SybaseBinary, - 'timestamp': SybaseTimeStamp, - 'money': SybaseMoney, - 'smallmoney': SybaseSmallMoney, - 'uniqueidentifier': SybaseUniqueIdentifier, - - 'java.lang.Object' : SybaseTypeError, - 'java serialization' : SybaseTypeError, - } - - name = 'sybase' - # Sybase backend peculiarities - supports_unicode_statements = False - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - execution_ctx_cls = SybaseSQLExecutionContext - - def __new__(cls, dbapi=None, *args, **kwargs): - if cls != SybaseSQLDialect: - return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs) - if dbapi: - print dbapi.__name__ - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(*args, **kwargs) - else: - return object.__new__(cls, *args, **kwargs) - - def __init__(self, **params): - super(SybaseSQLDialect, self).__init__(**params) - self.text_as_varchar = False - # FIXME: what is the default schema for sybase connections (DBA?) ? - self.set_default_schema_name("dba") - - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) - else: - for dialect_cls in dialect_mapping.values(): - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc') - dbapi = classmethod(dbapi) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - return newobj - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def get_default_schema_name(self, connection): - return self.schema_name - - def set_default_schema_name(self, schema_name): - self.schema_name = schema_name - - def do_execute(self, cursor, statement, params, **kwargs): - params = tuple(params) - super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs) - - # FIXME: remove ? - def _execute(self, c, statement, parameters): - try: - if parameters == {}: - parameters = () - c.execute(statement, parameters) - self.context.rowcount = c.rowcount - c.DBPROP_COMMITPRESERVE = "Y" - except Exception, e: - raise exc.DBAPIError.instance(statement, parameters, e) - - def table_names(self, connection, schema): - """Ignore the schema and the charset for now.""" - s = sql.select([tables.c.table_name], - sql.not_(tables.c.table_name.like("SYS%")) and - tables.c.creator >= 100 - ) - rp = connection.execute(s) - return [row[0] for row in rp.fetchall()] - - def has_table(self, connection, tablename, schema=None): - # FIXME: ignore schemas for sybase - s = sql.select([tables.c.table_name], tables.c.table_name == tablename) - - c = connection.execute(s) - row = c.fetchone() - print "has_table: " + tablename + ": " + str(bool(row is not None)) - return row is not None - - def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.get_default_schema_name(connection) - - s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) - - c = connection.execute(s) - found_table = False - # makes sure we append the columns in the correct order - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( - row[columns.c.column_name], - row[domains.c.domain_name], - row[columns.c.nulls] == 'Y', - row[columns.c.width], - row[domains.c.precision], - row[columns.c.scale], - row[columns.c.default], - row[columns.c.pkey] == 'Y', - row[columns.c.max_identity], - row[tables.c.table_id], - row[columns.c.column_id], - ) - if include_columns and name not in include_columns: - continue - - # FIXME: else problems with SybaseBinary(size) - if numericscale == 0: - numericscale = None - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = self.ischema_names.get(type, None) - if coltype == SybaseString and charlen == -1: - coltype = SybaseText() - else: - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (type, name)) - coltype = sqltypes.NULLTYPE - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - # any sequences ? - col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) - if int(max_identity) > 0: - col.sequence = schema.Sequence(name + '_identity') - col.sequence.start = int(max_identity) - col.sequence.increment = 1 - - # append the column - table.append_column(col) - - # any foreign key constraint for this table ? - # note: no multi-column foreign keys are considered - s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } - c = connection.execute(s) - foreignKeys = {} - while True: - row = c.fetchone() - if row is None: - break - (foreign_table, foreign_column, primary_table, primary_column) = ( - row[0], row[1], row[2], row[3], - ) - if not primary_table in foreignKeys.keys(): - foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] - else: - foreignKeys[primary_table][0].append('%s'%(foreign_column)) - foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) - for primary_table in foreignKeys.keys(): - #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) - table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - -class SybaseSQLDialect_mxodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_mxodbc - - def __init__(self, **params): - super(SybaseSQLDialect_mxodbc, self).__init__(**params) - - self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()} - - def import_dbapi(cls): - #import mx.ODBC.Windows as module - import mxODBC as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_mxodbc - colspecs[sqltypes.Date] = SybaseDate_mxodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_mxodbc - ischema_names['date'] = SybaseDate_mxodbc - ischema_names['datetime'] = SybaseDateTime_mxodbc - ischema_names['smalldatetime'] = SybaseDateTime_mxodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle mx.odbc.Windows proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - argsDict = {} - argsDict['user'] = opts['user'] - argsDict['password'] = opts['password'] - connArgs = [[opts['dsn']], argsDict] - return connArgs - - -class SybaseSQLDialect_pyodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_pyodbc - - def __init__(self, **params): - super(SybaseSQLDialect_pyodbc, self).__init__(**params) - self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()} - - def import_dbapi(cls): - import mypyodbc as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_pyodbc - colspecs[sqltypes.Date] = SybaseDate_pyodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_pyodbc - ischema_names['date'] = SybaseDate_pyodbc - ischema_names['datetime'] = SybaseDateTime_pyodbc - ischema_names['smalldatetime'] = SybaseDateTime_pyodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle pyodbc proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - - self.autocommit = False - if 'autocommit' in opts: - self.autocommit = bool(int(opts.pop('autocommit'))) - - argsDict = {} - argsDict['UID'] = opts['user'] - argsDict['PWD'] = opts['password'] - argsDict['DSN'] = opts['dsn'] - connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}] - return connArgs - - -dialect_mapping = { - 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc, -# 'pyodbc' : SybaseSQLDialect_pyodbc, - } - - -class SybaseSQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update({ - sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), - }) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond' - }) - - - def bindparam_string(self, name): - res = super(SybaseSQLCompiler, self).bindparam_string(name) - if name.lower().startswith('literal'): - res = 'STRING(%s)' % res - return res - - def get_select_precolumns(self, select): - s = select._distinct and "DISTINCT " or "" - if select._limit: - #if select._limit == 1: - #s += "FIRST " - #else: - #s += "TOP %s " % (select._limit,) - s += "TOP %s " % (select._limit,) - if select._offset: - if not select._limit: - # FIXME: sybase doesn't allow an offset without a limit - # so use a huge value for TOP here - s += "TOP 1000000 " - s += "START AT %s " % (select._offset+1,) - return s - - def limit_clause(self, select): - # Limit in sybase is after the select keyword - return "" - - def visit_binary(self, binary): - """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) - else: - return super(SybaseSQLCompiler, self).visit_binary(binary) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - - function_rewrites = {'current_date': 'getdate', - } - def visit_function(self, func): - func.name = self.function_rewrites.get(func.name, func.name) - res = super(SybaseSQLCompiler, self).visit_function(func) - if func.name.lower() == 'getdate': - # apply CAST operator - # FIXME: what about _pyodbc ? - cast = expression._Cast(func, SybaseDate_mxodbc) - # infinite recursion - # res = self.visit_cast(cast) - res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) - return res - - def visit_extract(self, extract): - field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - - colspec = self.preparer.format_column(column) - - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, sqltypes.Integer): - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): - column.sequence = schema.Sequence(column.name + '_seq') - - if hasattr(column, 'sequence'): - column.table.has_sequence = column - #colspec += " numeric(30,0) IDENTITY" - colspec += " Integer IDENTITY" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if not column.nullable: - colspec += " NOT NULL" - - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - - -class SybaseSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class SybaseSQLDefaultRunner(base.DefaultRunner): - pass - - -class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(SybaseSQLIdentifierPreparer, self).__init__(dialect) - - def _escape_identifier(self, value): - #TODO: determin SybaseSQL's escapeing rules - return value - - def _fold_identifier_case(self, value): - #TODO: determin SybaseSQL's case folding rules - return value - - -dialect = SybaseSQLDialect -dialect.statement_compiler = SybaseSQLCompiler -dialect.schemagenerator = SybaseSQLSchemaGenerator -dialect.schemadropper = SybaseSQLSchemaDropper -dialect.preparer = SybaseSQLIdentifierPreparer -dialect.defaultrunner = SybaseSQLDefaultRunner |