From c8dd7affff3c253256be81b5e912bcdd1359cf1a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 28 Mar 2022 18:39:19 -0400 Subject: apply loader criteria more specifically when refresh is true Fixed bug in :func:`_orm.with_loader_criteria` function where loader criteria would not be applied to a joined eager load that were invoked within the scope of a refresh operation for the parent object. Fixes: #7862 Change-Id: If1ac86eaa95880b5ec5bdeee292d6e8000aac705 (cherry picked from commit 9c52d9a507a738ae68f0a6eae09d87959995b981) --- doc/build/changelog/unreleased_14/7862.rst | 7 ++++ lib/sqlalchemy/orm/context.py | 5 ++- lib/sqlalchemy/orm/util.py | 3 +- test/orm/test_relationship_criteria.py | 61 ++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7862.rst diff --git a/doc/build/changelog/unreleased_14/7862.rst b/doc/build/changelog/unreleased_14/7862.rst new file mode 100644 index 000000000..00252ec8d --- /dev/null +++ b/doc/build/changelog/unreleased_14/7862.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 7862 + + Fixed bug in :func:`_orm.with_loader_criteria` function where loader + criteria would not be applied to a joined eager load that were invoked + within the scope of a refresh operation for the parent object. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 7a63543a6..49d354cb3 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -2254,7 +2254,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): single_crit = ext_info.mapper._single_table_criterion - additional_entity_criteria = self._get_extra_criteria(ext_info) + if self.compile_options._for_refresh_state: + additional_entity_criteria = [] + else: + additional_entity_criteria = self._get_extra_criteria(ext_info) if single_crit is not None: additional_entity_criteria += (single_crit,) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0cd6b8f41..9ec2ad076 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1202,8 +1202,7 @@ class LoaderCriteriaOption(CriteriaOption): "Please migrate code to use the with_polymorphic() standalone " "function before using with_loader_criteria()." ) - if not compile_state.compile_options._for_refresh_state: - self.get_global_criteria(compile_state.global_attributes) + self.get_global_criteria(compile_state.global_attributes) def get_global_criteria(self, attributes): for mp in self._all_mappers(): diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index 932f80d9f..5f47b49ac 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -55,6 +55,33 @@ class _Fixtures(_fixtures.FixtureTest): ) return User, Address + @testing.fixture + def user_address_custom_strat_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + def go(strat): + self.mapper_registry.map_imperatively( + User, + users, + properties={ + "addresses": relationship( + self.mapper_registry.map_imperatively( + Address, addresses + ), + lazy=strat, + order_by=Address.id, + ) + }, + ) + return User, Address + + return go + @testing.fixture def order_item_fixture(self): Order, Item = self.classes("Order", "Item") @@ -220,6 +247,40 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): "WHERE users.name != :name_1", ) + @testing.combinations( + "select", + "joined", + "subquery", + "selectin", + "immediate", + argnames="loader_strategy", + ) + def test_loader_strategy_on_refresh( + self, loader_strategy, user_address_custom_strat_fixture + ): + User, Address = user_address_custom_strat_fixture(loader_strategy) + + sess = fixture_session() + + @event.listens_for(sess, "do_orm_execute") + def add_criteria(orm_context): + orm_context.statement = orm_context.statement.options( + with_loader_criteria( + Address, + ~Address.id.in_([5, 3]), + ) + ) + + u1 = sess.get(User, 7) + u2 = sess.get(User, 8) + eq_(u1.addresses, [Address(id=1)]) + eq_(u2.addresses, [Address(id=2), Address(id=4)]) + + for i in range(3): + sess.expire_all() + eq_(u1.addresses, [Address(id=1)]) + eq_(u2.addresses, [Address(id=2), Address(id=4)]) + def test_criteria_post_replace_legacy(self, user_address_fixture): User, Address = user_address_fixture -- cgit v1.2.1