summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mssql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r--lib/sqlalchemy/databases/mssql.py113
1 files changed, 58 insertions, 55 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index ba1c0fd9d..206291404 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -25,7 +25,7 @@
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
@@ -39,16 +39,11 @@ Known issues / TODO:
"""
-import sys, StringIO, string, types, re, datetime, random
+import datetime, random, warnings
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+from sqlalchemy import sql, schema, ansisql, exceptions
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-
+from sqlalchemy.engine import default
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
@@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
row = c.fetchone()
return row is not None
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
import sqlalchemy.databases.information_schema as ischema
# Get base columns
@@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
+ if include_columns and name not in include_columns:
+ continue
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
args.append(a)
- coltype = self.ischema_names[type]
+ coltype = self.ischema_names.get(type, None)
if coltype == MSString and charlen == -1:
coltype = MSText()
else:
- if coltype == MSNVarchar and charlen == -1:
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name)))
+ coltype = sqltypes.NULLTYPE
+
+ elif coltype == MSNVarchar and charlen == -1:
charlen = None
coltype = coltype(*args)
colargs= []
@@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- s = select.distinct and "DISTINCT " or ""
- if select.limit:
- s += "TOP %s " % (select.limit,)
- if select.offset:
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
@@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler):
# Limit in mssql is after the select keyword
return ""
- def visit_table(self, table):
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if not self.tablealiases.has_key(table):
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
# alias schema-qualified tables
- if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
- alias = table.alias()
- self.tablealiases[table] = alias
- self.traverse(alias)
- self.froms[('alias', table)] = self.froms[table]
- for c in alias.c:
- self.traverse(c)
- self.traverse(alias.oid_column)
- self.tablealiases[alias] = self.froms[table]
- self.froms[table] = self.froms[alias]
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- super(MSSQLCompiler, self).visit_table(table)
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- def visit_alias(self, alias):
+ def visit_alias(self, alias, **kwargs):
# translate for schema-qualified table aliases
- if self.froms.has_key(('alias', alias.original)):
- self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
- self.strings[alias] = ""
- else:
- super(MSSQLCompiler, self).visit_alias(alias)
+ self.tablealiases[alias.original] = alias
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
def visit_column(self, column):
- # translate for schema-qualified table aliases
- super(MSSQLCompiler, self).visit_column(column)
- if column.table is not None and self.tablealiases.has_key(column.table):
- self.strings[column] = \
- self.strings[self.tablealiases[column.table].corresponding_column(column)]
+ if column.table is not None:
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ return self.process(t.corresponding_column(column))
+ return super(MSSQLCompiler, self).visit_column(column)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
- binary.left, binary.right = binary.right, binary.left
- super(MSSQLCompiler, self).visit_binary(binary)
-
- def visit_select(self, select):
- # label function calls, so they return a name in cursor.description
- for i,c in enumerate(select._raw_columns):
- if isinstance(c, sql._Function):
- select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
+ if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+ return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(MSSQLCompiler, self).visit_binary(binary)
- super(MSSQLCompiler, self).visit_select(select)
+ def label_select_column(self, select, column):
+ if isinstance(column, sql._Function):
+ return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
@@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ''
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not select.is_subquery or select.limit):
+ if order_by and (not self.is_subquery(select) or select._limit):
return " ORDER BY " + order_by
else:
return ""
@@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote_identifier(index.name)))
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote_identifier(index.name)
+ ))
self.execute()
+
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
@@ -940,4 +944,3 @@ dialect = MSSQLDialect
-