summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-09-26 02:33:19 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-09-26 02:33:19 +0000
commit1657cea73d5ec9aeedd541001e125e03e581a34b (patch)
treeb1d8527435fa51f7cec399972ea5af29d4f74a67 /lib/sqlalchemy
parente708cfea0bdaae82ac30dd7d33f9442115b9af6d (diff)
parentc86ec8f8c98b756ef06933174a3f4a0f3cfbed41 (diff)
downloadsqlalchemy-1657cea73d5ec9aeedd541001e125e03e581a34b.tar.gz
Merge "`aggregate_order_by` now supports cache generation." into main
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py12
-rw-r--r--lib/sqlalchemy/testing/fixtures.py108
2 files changed, 119 insertions, 1 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 0192cf581..ebaad2734 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -5,8 +5,10 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
+from __future__ import annotations
from itertools import zip_longest
+from typing import TYPE_CHECKING
from .array import ARRAY
from ...sql import coercions
@@ -16,6 +18,10 @@ from ...sql import functions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
+from ...sql.visitors import InternalTraversal
+
+if TYPE_CHECKING:
+ from ...sql.visitors import _TraverseInternalsType
class aggregate_order_by(expression.ColumnElement):
@@ -56,7 +62,11 @@ class aggregate_order_by(expression.ColumnElement):
__visit_name__ = "aggregate_order_by"
stringify_dialect = "postgresql"
- inherit_cache = False
+ _traverse_internals: _TraverseInternalsType = [
+ ("target", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index ef284babc..5fb547cbc 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -9,6 +9,7 @@
from __future__ import annotations
+import itertools
import re
import sys
@@ -16,6 +17,8 @@ import sqlalchemy as sa
from . import assertions
from . import config
from . import schema
+from .assertions import eq_
+from .assertions import ne_
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
@@ -27,6 +30,8 @@ from ..orm import DeclarativeBase
from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
+from ..sql import visitors
+from ..sql.elements import ClauseElement
@config.mark_base_test_class()
@@ -881,3 +886,106 @@ class ComputedReflectionFixtureTest(TablesTest):
Computed("normal * 42", persisted=True),
)
)
+
+
+class CacheKeyFixture:
+ def _compare_equal(self, a, b, compare_values):
+ a_key = a._generate_cache_key()
+ b_key = b._generate_cache_key()
+
+ if a_key is None:
+ assert a._annotations.get("nocache")
+
+ assert b_key is None
+ else:
+
+ eq_(a_key.key, b_key.key)
+ eq_(hash(a_key.key), hash(b_key.key))
+
+ for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+ assert a_param.compare(b_param, compare_values=compare_values)
+ return a_key, b_key
+
+ def _run_cache_key_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ a_key, b_key = self._compare_equal(
+ case_a[a], case_b[b], compare_values
+ )
+ if a_key is None:
+ continue
+ else:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+
+ if a_key is None or b_key is None:
+ if a_key is None:
+ assert case_a[a]._annotations.get("nocache")
+ if b_key is None:
+ assert case_b[b]._annotations.get("nocache")
+ continue
+
+ if a_key.key == b_key.key:
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ if not a_param.compare(
+ b_param, compare_values=compare_values
+ ):
+ break
+ else:
+ # this fails unconditionally since we could not
+ # find bound parameter values that differed.
+ # Usually we intended to get two distinct keys here
+ # so the failure will be more descriptive using the
+ # ne_() assertion.
+ ne_(a_key.key, b_key.key)
+ else:
+ ne_(a_key.key, b_key.key)
+
+ # ClauseElement-specific test to ensure the cache key
+ # collected all the bound parameters that aren't marked
+ # as "literal execute"
+ if isinstance(case_a[a], ClauseElement) and isinstance(
+ case_b[b], ClauseElement
+ ):
+ assert_a_params = []
+ assert_b_params = []
+
+ for elem in visitors.iterate(case_a[a]):
+ if elem.__visit_name__ == "bindparam":
+ assert_a_params.append(elem)
+
+ for elem in visitors.iterate(case_b[b]):
+ if elem.__visit_name__ == "bindparam":
+ assert_b_params.append(elem)
+
+ # note we're asserting the order of the params as well as
+ # if there are dupes or not. ordering has to be
+ # deterministic and matches what a traversal would provide.
+ eq_(
+ sorted(a_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_a_params), key=lambda b: b.key
+ ),
+ )
+ eq_(
+ sorted(b_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_b_params), key=lambda b: b.key
+ ),
+ )
+
+ def _run_cache_key_equal_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ self._compare_equal(case_a[a], case_b[b], compare_values)