diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-09-26 02:33:19 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-09-26 02:33:19 +0000 |
commit | 1657cea73d5ec9aeedd541001e125e03e581a34b (patch) | |
tree | b1d8527435fa51f7cec399972ea5af29d4f74a67 /lib/sqlalchemy | |
parent | e708cfea0bdaae82ac30dd7d33f9442115b9af6d (diff) | |
parent | c86ec8f8c98b756ef06933174a3f4a0f3cfbed41 (diff) | |
download | sqlalchemy-1657cea73d5ec9aeedd541001e125e03e581a34b.tar.gz |
Merge "`aggregate_order_by` now supports cache generation." into main
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ext.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 108 |
2 files changed, 119 insertions, 1 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 0192cf581..ebaad2734 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -5,8 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors +from __future__ import annotations from itertools import zip_longest +from typing import TYPE_CHECKING from .array import ARRAY from ...sql import coercions @@ -16,6 +18,10 @@ from ...sql import functions from ...sql import roles from ...sql import schema from ...sql.schema import ColumnCollectionConstraint +from ...sql.visitors import InternalTraversal + +if TYPE_CHECKING: + from ...sql.visitors import _TraverseInternalsType class aggregate_order_by(expression.ColumnElement): @@ -56,7 +62,11 @@ class aggregate_order_by(expression.ColumnElement): __visit_name__ = "aggregate_order_by" stringify_dialect = "postgresql" - inherit_cache = False + _traverse_internals: _TraverseInternalsType = [ + ("target", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ("order_by", InternalTraversal.dp_clauseelement), + ] def __init__(self, target, *order_by): self.target = coercions.expect(roles.ExpressionElementRole, target) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ef284babc..5fb547cbc 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -9,6 +9,7 @@ from __future__ import annotations +import itertools import re import sys @@ -16,6 +17,8 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema +from .assertions import eq_ +from .assertions import ne_ from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa @@ -27,6 +30,8 @@ from ..orm import DeclarativeBase from ..orm import MappedAsDataclass from ..orm import registry from ..schema import sort_tables_and_constraints +from ..sql import visitors +from ..sql.elements import ClauseElement @config.mark_base_test_class() @@ -881,3 +886,106 @@ class ComputedReflectionFixtureTest(TablesTest): Computed("normal * 42", persisted=True), ) ) + + +class CacheKeyFixture: + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) |