diff options
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index f86a95548..329109828 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -922,12 +922,15 @@ class MSSQLCompiler(compiler.DefaultCompiler): def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - if not self.dialect.has_window_funcs: + if select._distinct or select._limit: s = select._distinct and "DISTINCT " or "" + if select._limit: - s += "TOP %s " % (select._limit,) - if select._offset: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') + if not select._offset: + s += "TOP %s " % (select._limit,) + else: + if not self.dialect.has_window_funcs: + raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s return compiler.DefaultCompiler.get_select_precolumns(self, select) @@ -938,13 +941,13 @@ class MSSQLCompiler(compiler.DefaultCompiler): def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. + """ - if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._limit is not None or select._offset is not None): + if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.process(select._order_by_clause) if not orderby: - orderby = list(select.oid_column.proxies)[0] - orderby = self.process(orderby) + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') _offset = select._offset _limit = select._limit @@ -952,12 +955,9 @@ class MSSQLCompiler(compiler.DefaultCompiler): select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) - if _offset is not None: - limitselect.append_whereclause("mssql_rn>=%d" % _offset) - if _limit is not None: - limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) - else: - limitselect.append_whereclause("mssql_rn<=%d" % _limit) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) return self.process(limitselect, iswrapper=True, **kwargs) else: return compiler.DefaultCompiler.visit_select(self, select, **kwargs) @@ -1003,10 +1003,17 @@ class MSSQLCompiler(compiler.DefaultCompiler): def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ and not isinstance(binary.right, expression._BindParamClause): return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) else: + if (binary.operator in (operator.eq, operator.ne)) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._SelectBaseMixin)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._SelectBaseMixin)) or \ + isinstance(binary.left, expression._SelectBaseMixin) or isinstance(binary.right, expression._SelectBaseMixin)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def label_select_column(self, select, column, asfrom): |