summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-03-29 14:05:28 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-03-29 14:05:28 +0000
commit6c9fa2dc8667ea5de66023dc7140881befa788da (patch)
tree1e3ecf2aa9d53c6506eb6be7a032a6c3a700c79d
parent25ac5edaba54a7880fe782dfabe6c4559557ae5d (diff)
parentc8dd7affff3c253256be81b5e912bcdd1359cf1a (diff)
downloadsqlalchemy-6c9fa2dc8667ea5de66023dc7140881befa788da.tar.gz
Merge "apply loader criteria more specifically when refresh is true" into rel_1_4
-rw-r--r--doc/build/changelog/unreleased_14/7862.rst7
-rw-r--r--lib/sqlalchemy/orm/context.py5
-rw-r--r--lib/sqlalchemy/orm/util.py3
-rw-r--r--test/orm/test_relationship_criteria.py61
4 files changed, 73 insertions, 3 deletions
diff --git a/doc/build/changelog/unreleased_14/7862.rst b/doc/build/changelog/unreleased_14/7862.rst
new file mode 100644
index 000000000..00252ec8d
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/7862.rst
@@ -0,0 +1,7 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 7862
+
+ Fixed bug in :func:`_orm.with_loader_criteria` function where loader
+ criteria would not be applied to a joined eager load that were invoked
+ within the scope of a refresh operation for the parent object.
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 7a63543a6..49d354cb3 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -2254,7 +2254,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
single_crit = ext_info.mapper._single_table_criterion
- additional_entity_criteria = self._get_extra_criteria(ext_info)
+ if self.compile_options._for_refresh_state:
+ additional_entity_criteria = []
+ else:
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
if single_crit is not None:
additional_entity_criteria += (single_crit,)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 0cd6b8f41..9ec2ad076 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1202,8 +1202,7 @@ class LoaderCriteriaOption(CriteriaOption):
"Please migrate code to use the with_polymorphic() standalone "
"function before using with_loader_criteria()."
)
- if not compile_state.compile_options._for_refresh_state:
- self.get_global_criteria(compile_state.global_attributes)
+ self.get_global_criteria(compile_state.global_attributes)
def get_global_criteria(self, attributes):
for mp in self._all_mappers():
diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py
index 932f80d9f..5f47b49ac 100644
--- a/test/orm/test_relationship_criteria.py
+++ b/test/orm/test_relationship_criteria.py
@@ -56,6 +56,33 @@ class _Fixtures(_fixtures.FixtureTest):
return User, Address
@testing.fixture
+ def user_address_custom_strat_fixture(self):
+ users, Address, addresses, User = (
+ self.tables.users,
+ self.classes.Address,
+ self.tables.addresses,
+ self.classes.User,
+ )
+
+ def go(strat):
+ self.mapper_registry.map_imperatively(
+ User,
+ users,
+ properties={
+ "addresses": relationship(
+ self.mapper_registry.map_imperatively(
+ Address, addresses
+ ),
+ lazy=strat,
+ order_by=Address.id,
+ )
+ },
+ )
+ return User, Address
+
+ return go
+
+ @testing.fixture
def order_item_fixture(self):
Order, Item = self.classes("Order", "Item")
orders, items, order_items = self.tables(
@@ -220,6 +247,40 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
"WHERE users.name != :name_1",
)
+ @testing.combinations(
+ "select",
+ "joined",
+ "subquery",
+ "selectin",
+ "immediate",
+ argnames="loader_strategy",
+ )
+ def test_loader_strategy_on_refresh(
+ self, loader_strategy, user_address_custom_strat_fixture
+ ):
+ User, Address = user_address_custom_strat_fixture(loader_strategy)
+
+ sess = fixture_session()
+
+ @event.listens_for(sess, "do_orm_execute")
+ def add_criteria(orm_context):
+ orm_context.statement = orm_context.statement.options(
+ with_loader_criteria(
+ Address,
+ ~Address.id.in_([5, 3]),
+ )
+ )
+
+ u1 = sess.get(User, 7)
+ u2 = sess.get(User, 8)
+ eq_(u1.addresses, [Address(id=1)])
+ eq_(u2.addresses, [Address(id=2), Address(id=4)])
+
+ for i in range(3):
+ sess.expire_all()
+ eq_(u1.addresses, [Address(id=1)])
+ eq_(u2.addresses, [Address(id=2), Address(id=4)])
+
def test_criteria_post_replace_legacy(self, user_address_fixture):
User, Address = user_address_fixture