summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mssql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r--lib/sqlalchemy/databases/mssql.py33
1 files changed, 20 insertions, 13 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index f86a95548..329109828 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -922,12 +922,15 @@ class MSSQLCompiler(compiler.DefaultCompiler):
def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- if not self.dialect.has_window_funcs:
+ if select._distinct or select._limit:
s = select._distinct and "DISTINCT " or ""
+
if select._limit:
- s += "TOP %s " % (select._limit,)
- if select._offset:
- raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+ if not select._offset:
+ s += "TOP %s " % (select._limit,)
+ else:
+ if not self.dialect.has_window_funcs:
+ raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
return compiler.DefaultCompiler.get_select_precolumns(self, select)
@@ -938,13 +941,13 @@ class MSSQLCompiler(compiler.DefaultCompiler):
def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
+
"""
- if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._limit is not None or select._offset is not None):
+ if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.process(select._order_by_clause)
if not orderby:
- orderby = list(select.oid_column.proxies)[0]
- orderby = self.process(orderby)
+ raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
_offset = select._offset
_limit = select._limit
@@ -952,12 +955,9 @@ class MSSQLCompiler(compiler.DefaultCompiler):
select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
- if _offset is not None:
- limitselect.append_whereclause("mssql_rn>=%d" % _offset)
- if _limit is not None:
- limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
- else:
- limitselect.append_whereclause("mssql_rn<=%d" % _limit)
+ limitselect.append_whereclause("mssql_rn>%d" % _offset)
+ if _limit is not None:
+ limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
return self.process(limitselect, iswrapper=True, **kwargs)
else:
return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
@@ -1003,10 +1003,17 @@ class MSSQLCompiler(compiler.DefaultCompiler):
def visit_binary(self, binary, **kwargs):
"""Move bind parameters to the right-hand side of an operator, where possible."""
+
if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
and not isinstance(binary.right, expression._BindParamClause):
return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
else:
+ if (binary.operator in (operator.eq, operator.ne)) and (
+ (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._SelectBaseMixin)) or \
+ (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._SelectBaseMixin)) or \
+ isinstance(binary.left, expression._SelectBaseMixin) or isinstance(binary.right, expression._SelectBaseMixin)):
+ op = binary.operator == operator.eq and "IN" or "NOT IN"
+ return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
def label_select_column(self, select, column, asfrom):