diff options
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 159 |
1 files changed, 71 insertions, 88 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 1852edefb..6d2ff66cd 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -52,7 +52,22 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions - +def dbapi(module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) + else: + for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]: + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') + class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): return value @@ -142,9 +157,6 @@ class MSString(sqltypes.String): return "VARCHAR(%(length)s)" % {'length' : self.length} class MSNVarchar(MSString): - """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """ - impl = sqltypes.Unicode - def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -154,19 +166,7 @@ class MSNVarchar(MSString): return "NTEXT" class AdoMSNVarchar(MSNVarchar): - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value - -class MSUnicode(sqltypes.Unicode): - """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl.""" - impl = MSNVarchar - -class AdoMSUnicode(MSUnicode): - impl = AdoMSNVarchar - + """overrides bindparam/result processing to not convert any unicode strings""" def convert_bind_param(self, value, dialect): return value @@ -215,9 +215,9 @@ def descriptor(): ]} class MSSQLExecutionContext(default.DefaultExecutionContext): - def __init__(self, dialect): + def __init__(self, *args, **kwargs): self.IINSERT = self.HASIDENT = False - super(MSSQLExecutionContext, self).__init__(dialect) + super(MSSQLExecutionContext, self).__init__(*args, **kwargs) def _has_implicit_sequence(self, column): if column.primary_key and column.autoincrement: @@ -227,14 +227,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): return True return False - def pre_exec(self, engine, proxy, compiled, parameters, **kwargs): + def pre_exec(self): """MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if the feature is turned on and needed. """ - if getattr(compiled, "isinsert", False): - tbl = compiled.statement.table + if self.compiled.isinsert: + tbl = self.compiled.statement.table if not hasattr(tbl, 'has_sequence'): tbl.has_sequence = None for column in tbl.c: @@ -243,39 +243,43 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): break self.HASIDENT = bool(tbl.has_sequence) - if engine.dialect.auto_identity_insert and self.HASIDENT: - if isinstance(parameters, list): - self.IINSERT = tbl.has_sequence.key in parameters[0] + if self.dialect.auto_identity_insert and self.HASIDENT: + if isinstance(self.compiled_parameters, list): + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0] else: - self.IINSERT = tbl.has_sequence.key in parameters + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters else: self.IINSERT = False if self.IINSERT: - proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name) - super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs) + super(MSSQLExecutionContext, self).pre_exec() - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + def post_exec(self): """Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column). """ - if getattr(compiled, "isinsert", False): + if self.compiled.isinsert: if self.IINSERT: - proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name) self.IINSERT = False elif self.HASIDENT: - cursor = proxy("SELECT @@IDENTITY AS lastrowid") - row = cursor.fetchone() + self.cursor.execute("SELECT @@IDENTITY AS lastrowid") + row = self.cursor.fetchone() self._last_inserted_ids = [int(row[0])] # print "LAST ROW ID", self._last_inserted_ids self.HASIDENT = False + super(MSSQLExecutionContext, self).post_exec() class MSSQLDialect(ansisql.ANSIDialect): colspecs = { + sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, sqltypes.Smallinteger: MSSmallInteger, sqltypes.Numeric : MSNumeric, @@ -283,7 +287,6 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.DateTime : MSDateTime, sqltypes.Date : MSDate, sqltypes.String : MSString, - sqltypes.Unicode : MSUnicode, sqltypes.Binary : MSBinary, sqltypes.Boolean : MSBoolean, sqltypes.TEXT : MSText, @@ -296,7 +299,7 @@ class MSSQLDialect(ansisql.ANSIDialect): 'smallint' : MSSmallInteger, 'tinyint' : MSTinyInteger, 'varchar' : MSString, - 'nvarchar' : MSUnicode, + 'nvarchar' : MSNVarchar, 'char' : MSChar, 'nchar' : MSNChar, 'text' : MSText, @@ -312,30 +315,16 @@ class MSSQLDialect(ansisql.ANSIDialect): 'image' : MSBinary } - def __new__(cls, module_name=None, *args, **kwargs): - module = kwargs.get('module', None) + def __new__(cls, dbapi=None, *args, **kwargs): if cls != MSSQLDialect: return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs) - if module_name: - dialect = dialect_mapping.get(module_name) - if not dialect: - raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name) - if not hasattr(dialect, 'module'): - raise dialect.saved_import_error + if dbapi: + dialect = dialect_mapping.get(dbapi.__name__) return dialect(*args, **kwargs) - elif module: - return object.__new__(cls, *args, **kwargs) else: - for dialect in dialect_preference: - if hasattr(dialect, 'module'): - return dialect(*args, **kwargs) - #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') - else: - return object.__new__(cls, *args, **kwargs) + return object.__new__(cls, *args, **kwargs) - def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params): - if not hasattr(self, 'module'): - self.module = module + def __init__(self, auto_identity_insert=True, **params): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False @@ -352,8 +341,8 @@ class MSSQLDialect(ansisql.ANSIDialect): self.text_as_varchar = bool(opts.pop('text_as_varchar')) return self.make_connect_string(opts) - def create_execution_context(self): - return MSSQLExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return MSSQLExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): newobj = sqltypes.adapt_type(typeobj, self.colspecs) @@ -373,13 +362,13 @@ class MSSQLDialect(ansisql.ANSIDialect): return MSSQLCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(*args, **kwargs) + return MSSQLSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(*args, **kwargs) + return MSSQLSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return MSSQLDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return MSSQLDefaultRunner(connection, **kwargs) def preparer(self): return MSSQLIdentifierPreparer(self) @@ -411,19 +400,12 @@ class MSSQLDialect(ansisql.ANSIDialect): def raw_connection(self, connection): """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes""" try: + # TODO: probably want to move this to individual dialect subclasses to + # save on the exception throw + simplify return connection.connection.__dict__['_pymssqlCnx__cnx'] except: return connection.connection.adoConn - def connection(self): - """returns a managed DBAPI connection from this SQLEngine's connection pool.""" - c = self._pool.connect() - c.supportsTransactions = 0 - return c - - def dbapi(self): - return self.module - def uppercase_table(self, t): # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive t.name = t.name.upper() @@ -558,13 +540,14 @@ class MSSQLDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) class MSSQLDialect_pymssql(MSSQLDialect): - try: + def import_dbapi(cls): import pymssql as module # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal module.Binary = lambda st: str(st) - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + def supports_sane_rowcount(self): return True @@ -578,7 +561,7 @@ class MSSQLDialect_pymssql(MSSQLDialect): def create_connect_args(self, url): r = super(MSSQLDialect_pymssql, self).create_connect_args(url) if hasattr(self, 'query_timeout'): - self.module._mssql.set_query_timeout(self.query_timeout) + self.dbapi._mssql.set_query_timeout(self.query_timeout) return r def make_connect_string(self, keys): @@ -621,15 +604,16 @@ class MSSQLDialect_pymssql(MSSQLDialect): ## r.fetch_array() class MSSQLDialect_pyodbc(MSSQLDialect): - try: + + def import_dbapi(cls): import pyodbc as module - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return False @@ -648,15 +632,15 @@ class MSSQLDialect_pyodbc(MSSQLDialect): class MSSQLDialect_adodbapi(MSSQLDialect): - try: + def import_dbapi(cls): import adodbapi as module - except ImportError, e: - saved_import_error = e + return module + import_dbapi = classmethod(import_dbapi) colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return True @@ -676,13 +660,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect): connectors.append("Integrated Security=SSPI") return [[";".join (connectors)], {}] - dialect_mapping = { 'pymssql': MSSQLDialect_pymssql, 'pyodbc': MSSQLDialect_pyodbc, 'adodbapi': MSSQLDialect_adodbapi } -dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc] class MSSQLCompiler(ansisql.ANSICompiler): @@ -770,7 +752,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ @@ -797,6 +779,7 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): + # TODO: does ms-sql have standalone sequences ? pass class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): |