summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mssql.py
diff options
context:
space:
mode:
authorMichael Trier <mtrier@gmail.com>2008-12-30 06:39:37 +0000
committerMichael Trier <mtrier@gmail.com>2008-12-30 06:39:37 +0000
commitf62a78242dd7d16f40722db5b8310f900b04efec (patch)
tree58797258b080390634d43d7c24791119a8b473a0 /lib/sqlalchemy/databases/mssql.py
parentdfd80ba089c0d0637f54cbd6b21332d5f5115999 (diff)
downloadsqlalchemy-f62a78242dd7d16f40722db5b8310f900b04efec.tar.gz
Modifications to the mssql dialect in order to to pass through unicode in the pyodbc dialect.
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r--lib/sqlalchemy/databases/mssql.py67
1 files changed, 30 insertions, 37 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index ad9ba847a..490e562e9 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -221,6 +221,7 @@ from decimal import Decimal as _python_Decimal
MSSQL_RESERVED_WORDS = set(['function'])
+
class _StringType(object):
"""Base for MSSQL string types."""
@@ -253,6 +254,33 @@ class _StringType(object):
return "%s(%s)" % (self.__class__.__name__,
', '.join(['%s=%r' % (k, params[k]) for k in params]))
+ def bind_processor(self, dialect):
+ if self.convert_unicode or dialect.convert_unicode:
+ if self.assert_unicode is None:
+ assert_unicode = dialect.assert_unicode
+ else:
+ assert_unicode = self.assert_unicode
+
+ if not assert_unicode:
+ return None
+
+ def process(value):
+ if not isinstance(value, (unicode, sqltypes.NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+ else:
+ return value
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect):
+ return None
+
class MSNumeric(sqltypes.Numeric):
def result_processor(self, dialect):
@@ -573,36 +601,6 @@ class MSNVarchar(_StringType, sqltypes.Unicode):
return self._extend("NVARCHAR")
-class AdoMSNVarchar(_StringType, sqltypes.Unicode):
- """overrides bindparam/result processing to not convert any unicode strings"""
-
- def __init__(self, length=None, **kwargs):
- """Construct a NVARCHAR.
-
- :param length: Optional, Maximum data length, in characters.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.Unicode.__init__(self, length=length,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- return None
-
- def get_col_spec(self):
- if self.length:
- return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length})
- else:
- return self._extend("NVARCHAR")
-
-
class MSChar(_StringType, sqltypes.CHAR):
"""MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
of 8,000 characters."""
@@ -1086,7 +1084,7 @@ class MSSQLDialect(default.DefaultDialect):
coltype = self.ischema_names.get(type, None)
kwargs = {}
- if coltype in (MSString, MSChar, MSNVarchar, AdoMSNVarchar, MSNChar, MSText, MSNText):
+ if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText):
if collation:
kwargs.update(collation=collation)
@@ -1098,7 +1096,7 @@ class MSSQLDialect(default.DefaultDialect):
(type, name))
coltype = sqltypes.NULLTYPE
- elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
+ elif coltype in (MSNVarchar,) and charlen == -1:
args[0] = None
coltype = coltype(*args, **kwargs)
colargs = []
@@ -1262,9 +1260,6 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
ischema_names = MSSQLDialect.ischema_names.copy()
ischema_names['smalldatetime'] = MSDate_pyodbc
ischema_names['datetime'] = MSDateTime_pyodbc
- if supports_unicode:
- colspecs[sqltypes.Unicode] = AdoMSNVarchar
- ischema_names['nvarchar'] = AdoMSNVarchar
def make_connect_string(self, keys, query):
if 'max_identifier_length' in keys:
@@ -1335,11 +1330,9 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
return module
colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Unicode] = AdoMSNVarchar
colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['nvarchar'] = AdoMSNVarchar
ischema_names['datetime'] = MSDateTime_adodbapi
def make_connect_string(self, keys, query):