summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-07-07 18:58:18 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-07-07 18:58:18 +0000
commita6d8b674e92ef1cabdb2ab85490397f3ed12a42c (patch)
treeda1bc2912b3465939bee40a5f649053977ca85c7
parentbcb2421e9faaab8ce48e2731b9a2f7411204f393 (diff)
parent6eea9ca437084feae6a7b00276547e70ef6b40ad (diff)
downloadsqlalchemy-a6d8b674e92ef1cabdb2ab85490397f3ed12a42c.tar.gz
Merge "ensure we unwrap desc() /label() all the way w/ order by"
-rw-r--r--lib/sqlalchemy/sql/util.py20
-rw-r--r--test/orm/test_deprecations.py15
-rw-r--r--test/sql/test_utils.py32
3 files changed, 65 insertions, 2 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e8726000b..b803ef912 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -27,6 +27,7 @@ from .elements import _textual_label_reference
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
+from .elements import Label
from .elements import Null
from .elements import UnaryExpression
from .schema import Column
@@ -279,14 +280,31 @@ def unwrap_order_by(clause):
cols = util.column_set()
result = []
stack = deque([clause])
+
+ # examples
+ # column -> ASC/DESC == column
+ # column -> ASC/DESC -> label == column
+ # column -> label -> ASC/DESC -> label == column
+ # scalar_select -> label -> ASC/DESC == scalar_select -> label
+
while stack:
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
or not operators.is_ordering_modifier(t.modifier)
):
- if isinstance(t, _label_reference):
+ if isinstance(t, Label) and not isinstance(
+ t.element, ScalarSelect
+ ):
+ t = t.element
+
+ stack.append(t)
+ continue
+ elif isinstance(t, _label_reference):
t = t.element
+
+ stack.append(t)
+ continue
if isinstance(t, (_textual_label_reference)):
continue
if t not in cols:
diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py
index 4e9f50661..4cfded25c 100644
--- a/test/orm/test_deprecations.py
+++ b/test/orm/test_deprecations.py
@@ -1597,6 +1597,21 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL):
):
eq_([User(id=7), User(id=9), User(id=8)], q.all())
+ def test_columns_augmented_roundtrip_two(self):
+ User, Address = self.classes.User, self.classes.Address
+
+ sess = create_session()
+ q = (
+ sess.query(User)
+ .join("addresses")
+ .distinct()
+ .order_by(desc(Address.email_address).label("foo"))
+ )
+ with testing.expect_deprecated(
+ "ORDER BY columns added implicitly due to "
+ ):
+ eq_([User(id=7), User(id=9), User(id=8)], q.all())
+
def test_columns_augmented_roundtrip_three(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py
index d68a74475..676ad4298 100644
--- a/test/sql/test_utils.py
+++ b/test/sql/test_utils.py
@@ -4,9 +4,14 @@ from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import util
from sqlalchemy.sql import base as sql_base
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import column
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import roles
from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
@@ -89,3 +94,28 @@ class MiscTest(fixtures.TestBase):
eq_(o4.bat, "hi")
assert_raises(TypeError, opt2.safe_merge, o4)
+
+ @testing.combinations(
+ (column("q"), [column("q")]),
+ (column("q").desc(), [column("q")]),
+ (column("q").desc().label(None), [column("q")]),
+ (column("q").label(None).desc(), [column("q")]),
+ (column("q").label(None).desc().label(None), [column("q")]),
+ ("foo", []), # textual label reference
+ (
+ select([column("q")]).scalar_subquery().label(None),
+ [select([column("q")]).scalar_subquery().label(None)],
+ ),
+ (
+ select([column("q")]).scalar_subquery().label(None).desc(),
+ [select([column("q")]).scalar_subquery().label(None)],
+ ),
+ )
+ def test_unwrap_order_by(self, expr, expected):
+
+ expr = coercions.expect(roles.OrderByRole, expr)
+
+ unwrapped = sql_util.unwrap_order_by(expr)
+
+ for a, b in util.zip_longest(unwrapped, expected):
+ assert a is not None and a.compare(b)