diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases/mssql.py | |
parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 113 |
1 files changed, 58 insertions, 55 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ba1c0fd9d..206291404 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -25,7 +25,7 @@ * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT`` -* ``select.limit`` implemented as ``SELECT TOP n`` +* ``select._limit`` implemented as ``SELECT TOP n`` Known issues / TODO: @@ -39,16 +39,11 @@ Known issues / TODO: """ -import sys, StringIO, string, types, re, datetime, random +import datetime, random, warnings -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql, schema, ansisql, exceptions import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions - +from sqlalchemy.engine import default class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect): row = c.fetchone() return row is not None - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): import sqlalchemy.databases.information_schema as ischema # Get base columns @@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect): row[columns.c.numeric_scale], row[columns.c.column_default] ) + if include_columns and name not in include_columns: + continue args = [] for a in (charlen, numericprec, numericscale): if a is not None: args.append(a) - coltype = self.ischema_names[type] + coltype = self.ischema_names.get(type, None) if coltype == MSString and charlen == -1: coltype = MSText() else: - if coltype == MSNVarchar and charlen == -1: + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name))) + coltype = sqltypes.NULLTYPE + + elif coltype == MSNVarchar and charlen == -1: charlen = None coltype = coltype(*args) colargs= [] @@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - s = select.distinct and "DISTINCT " or "" - if select.limit: - s += "TOP %s " % (select.limit,) - if select.offset: + s = select._distinct and "DISTINCT " or "" + if select._limit: + s += "TOP %s " % (select._limit,) + if select._offset: raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s @@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler): # Limit in mssql is after the select keyword return "" - def visit_table(self, table): + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if not self.tablealiases.has_key(table): + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + # alias schema-qualified tables - if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table): - alias = table.alias() - self.tablealiases[table] = alias - self.traverse(alias) - self.froms[('alias', table)] = self.froms[table] - for c in alias.c: - self.traverse(c) - self.traverse(alias.oid_column) - self.tablealiases[alias] = self.froms[table] - self.froms[table] = self.froms[alias] + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) else: - super(MSSQLCompiler, self).visit_table(table) + return super(MSSQLCompiler, self).visit_table(table, **kwargs) - def visit_alias(self, alias): + def visit_alias(self, alias, **kwargs): # translate for schema-qualified table aliases - if self.froms.has_key(('alias', alias.original)): - self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name - self.strings[alias] = "" - else: - super(MSSQLCompiler, self).visit_alias(alias) + self.tablealiases[alias.original] = alias + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) def visit_column(self, column): - # translate for schema-qualified table aliases - super(MSSQLCompiler, self).visit_column(column) - if column.table is not None and self.tablealiases.has_key(column.table): - self.strings[column] = \ - self.strings[self.tablealiases[column.table].corresponding_column(column)] + if column.table is not None: + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + return self.process(t.corresponding_column(column)) + return super(MSSQLCompiler, self).visit_column(column) def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=': - binary.left, binary.right = binary.right, binary.left - super(MSSQLCompiler, self).visit_binary(binary) - - def visit_select(self, select): - # label function calls, so they return a name in cursor.description - for i,c in enumerate(select._raw_columns): - if isinstance(c, sql._Function): - select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) + if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: + return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(MSSQLCompiler, self).visit_binary(binary) - super(MSSQLCompiler, self).visit_select(select) + def label_select_column(self, select, column): + if isinstance(column, sql._Function): + return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(MSSQLCompiler, self).label_select_column(select, column) function_rewrites = {'current_date': 'getdate', 'length': 'len', @@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler): return '' def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.process(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not select.is_subquery or select.limit): + if order_by and (not self.is_subquery(select) or select._limit): return " ORDER BY " + order_by else: return "" @@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote_identifier(index.name))) + self.preparer.quote_identifier(index.table.name), + self.preparer.quote_identifier(index.name) + )) self.execute() + class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): # TODO: does ms-sql have standalone sequences ? pass @@ -940,4 +944,3 @@ dialect = MSSQLDialect - |