From c86ec8f8c98b756ef06933174a3f4a0f3cfbed41 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 25 Sep 2022 16:37:15 +0200 Subject: `aggregate_order_by` now supports cache generation. also adjusted CacheKeyFixture to be a general purpose fixture so that sub-components / dialects can run their own cache key tests. Fixes: #8574 Change-Id: I6c66107856aee11e548d357cea77bceee3e316a0 --- lib/sqlalchemy/testing/fixtures.py | 108 +++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) (limited to 'lib/sqlalchemy/testing/fixtures.py') diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 20dee5273..2a5f97dbb 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -9,6 +9,7 @@ from __future__ import annotations +import itertools import re import sys @@ -16,6 +17,8 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema +from .assertions import eq_ +from .assertions import ne_ from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa @@ -28,6 +31,8 @@ from ..orm import DeclarativeBase from ..orm import MappedAsDataclass from ..orm import registry from ..schema import sort_tables_and_constraints +from ..sql import visitors +from ..sql.elements import ClauseElement @config.mark_base_test_class() @@ -881,3 +886,106 @@ class ComputedReflectionFixtureTest(TablesTest): Computed("normal * 42", persisted=True), ) ) + + +class CacheKeyFixture: + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) -- cgit v1.2.1