diff options
author | Michael Trier <mtrier@gmail.com> | 2008-12-30 06:39:37 +0000 |
---|---|---|
committer | Michael Trier <mtrier@gmail.com> | 2008-12-30 06:39:37 +0000 |
commit | f62a78242dd7d16f40722db5b8310f900b04efec (patch) | |
tree | 58797258b080390634d43d7c24791119a8b473a0 /lib/sqlalchemy/databases/mssql.py | |
parent | dfd80ba089c0d0637f54cbd6b21332d5f5115999 (diff) | |
download | sqlalchemy-f62a78242dd7d16f40722db5b8310f900b04efec.tar.gz |
Modifications to the mssql dialect in order to to pass through unicode in the pyodbc dialect.
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 67 |
1 files changed, 30 insertions, 37 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ad9ba847a..490e562e9 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -221,6 +221,7 @@ from decimal import Decimal as _python_Decimal MSSQL_RESERVED_WORDS = set(['function']) + class _StringType(object): """Base for MSSQL string types.""" @@ -253,6 +254,33 @@ class _StringType(object): return "%s(%s)" % (self.__class__.__name__, ', '.join(['%s=%r' % (k, params[k]) for k in params])) + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + if self.assert_unicode is None: + assert_unicode = dialect.assert_unicode + else: + assert_unicode = self.assert_unicode + + if not assert_unicode: + return None + + def process(value): + if not isinstance(value, (unicode, sqltypes.NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + else: + return value + return process + else: + return None + + def result_processor(self, dialect): + return None + class MSNumeric(sqltypes.Numeric): def result_processor(self, dialect): @@ -573,36 +601,6 @@ class MSNVarchar(_StringType, sqltypes.Unicode): return self._extend("NVARCHAR") -class AdoMSNVarchar(_StringType, sqltypes.Unicode): - """overrides bindparam/result processing to not convert any unicode strings""" - - def __init__(self, length=None, **kwargs): - """Construct a NVARCHAR. - - :param length: Optional, Maximum data length, in characters. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Unicode.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - return None - - def get_col_spec(self): - if self.length: - return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NVARCHAR") - - class MSChar(_StringType, sqltypes.CHAR): """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum of 8,000 characters.""" @@ -1086,7 +1084,7 @@ class MSSQLDialect(default.DefaultDialect): coltype = self.ischema_names.get(type, None) kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, AdoMSNVarchar, MSNChar, MSText, MSNText): + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText): if collation: kwargs.update(collation=collation) @@ -1098,7 +1096,7 @@ class MSSQLDialect(default.DefaultDialect): (type, name)) coltype = sqltypes.NULLTYPE - elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1: + elif coltype in (MSNVarchar,) and charlen == -1: args[0] = None coltype = coltype(*args, **kwargs) colargs = [] @@ -1262,9 +1260,6 @@ class MSSQLDialect_pyodbc(MSSQLDialect): ischema_names = MSSQLDialect.ischema_names.copy() ischema_names['smalldatetime'] = MSDate_pyodbc ischema_names['datetime'] = MSDateTime_pyodbc - if supports_unicode: - colspecs[sqltypes.Unicode] = AdoMSNVarchar - ischema_names['nvarchar'] = AdoMSNVarchar def make_connect_string(self, keys, query): if 'max_identifier_length' in keys: @@ -1335,11 +1330,9 @@ class MSSQLDialect_adodbapi(MSSQLDialect): return module colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSNVarchar colspecs[sqltypes.DateTime] = MSDateTime_adodbapi ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSNVarchar ischema_names['datetime'] = MSDateTime_adodbapi def make_connect_string(self, keys, query): |