summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9635.rst10
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py30
-rw-r--r--lib/sqlalchemy/orm/context.py74
-rw-r--r--test/orm/test_relationship_criteria.py236
4 files changed, 315 insertions, 35 deletions
diff --git a/doc/build/changelog/unreleased_20/9635.rst b/doc/build/changelog/unreleased_20/9635.rst
new file mode 100644
index 000000000..73281c7e1
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9635.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 9635
+
+ Made an improvement to the :func:`_orm.with_loader_criteria` loader option
+ to allow it to be indicated in the :meth:`.Executable.options` method of a
+ top-level statement that is not itself an ORM statement. Examples include
+ :func:`_sql.select` that's embedded in compound statements such as
+ :func:`_sql.union`, within an :meth:`_dml.Insert.from_select` construct, as
+ well as within CTE expressions that are not ORM related at the top level.
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 1b3cce47a..f9d9d6a43 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -1346,15 +1346,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self.mapper = mapper = ext_info.mapper
- self.extra_criteria_entities = {}
-
self._resolved_values = self._get_resolved_values(mapper, statement)
- extra_criteria_attributes = {}
-
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(extra_criteria_attributes)
+ self._init_global_attributes(
+ statement,
+ compiler,
+ toplevel=True,
+ process_criteria_for_toplevel=True,
+ )
if statement._values:
self._resolved_values = dict(self._resolved_values)
@@ -1372,7 +1371,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
new_stmt._values = self._resolved_values
new_crit = self._adjust_for_extra_criteria(
- extra_criteria_attributes, mapper
+ self.global_attributes, mapper
)
if new_crit:
new_stmt = new_stmt.where(*new_crit)
@@ -1741,19 +1740,18 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
- self.extra_criteria_entities = {}
-
- extra_criteria_attributes = {}
-
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(extra_criteria_attributes)
+ self._init_global_attributes(
+ statement,
+ compiler,
+ toplevel=True,
+ process_criteria_for_toplevel=True,
+ )
new_stmt = statement._clone()
new_stmt.table = mapper.local_table
new_crit = cls._adjust_for_extra_criteria(
- extra_criteria_attributes, mapper
+ self.global_attributes, mapper
)
if new_crit:
new_stmt = new_stmt.where(*new_crit)
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 2b45b5adc..e778c4840 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -209,6 +209,45 @@ _orm_load_exec_options = util.immutabledict(
class AbstractORMCompileState(CompileState):
is_dml_returning = False
+ def _init_global_attributes(
+ self, statement, compiler, *, toplevel, process_criteria_for_toplevel
+ ):
+ self.attributes = {}
+
+ if compiler is None:
+ # this is the legacy / testing only ORM _compile_state() use case.
+ # there is no need to apply criteria options for this.
+ self.global_attributes = ga = {}
+ assert toplevel
+ return
+ else:
+ self.global_attributes = ga = compiler._global_attributes
+
+ if toplevel:
+ ga["toplevel_orm"] = True
+
+ if process_criteria_for_toplevel:
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.process_compile_state(self)
+
+ return
+ elif ga.get("toplevel_orm", False):
+ return
+
+ stack_0 = compiler.stack[0]
+
+ try:
+ toplevel_stmt = stack_0["selectable"]
+ except KeyError:
+ pass
+ else:
+ for opt in toplevel_stmt._with_options:
+ if opt._is_compile_state and opt._is_criteria_option:
+ opt.process_compile_state(self)
+
+ ga["toplevel_orm"] = True
+
@classmethod
def create_for_statement(
cls,
@@ -622,17 +661,13 @@ class ORMFromStatementCompileState(ORMCompileState):
assert isinstance(statement_container, FromStatement)
- if compiler is not None:
- toplevel = not compiler.stack
- else:
- toplevel = True
-
- if not toplevel:
+ if compiler is not None and compiler.stack:
raise sa_exc.CompileError(
"The ORM FromStatement construct only supports being "
"invoked as the topmost statement, as it is only intended to "
"define how result rows should be returned."
)
+
self = cls.__new__(cls)
self._primary_entity = None
@@ -680,18 +715,18 @@ class ORMFromStatementCompileState(ORMCompileState):
self.current_path = statement_container._compile_options._current_path
- if toplevel and statement_container._with_options:
- self.attributes = {}
- self.global_attributes = compiler._global_attributes
+ self._init_global_attributes(
+ statement_container,
+ compiler,
+ process_criteria_for_toplevel=False,
+ toplevel=True,
+ )
+ if statement_container._with_options:
for opt in statement_container._with_options:
if opt._is_compile_state:
opt.process_compile_state(self)
- else:
- self.attributes = {}
- self.global_attributes = compiler._global_attributes
-
if statement_container._with_context_options:
for fn, key in statement_container._with_context_options:
fn(self)
@@ -911,10 +946,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if compiler is not None:
toplevel = not compiler.stack
- self.global_attributes = compiler._global_attributes
else:
toplevel = True
- self.global_attributes = {}
select_statement = statement
@@ -1002,11 +1035,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.eager_order_by = ()
+ self._init_global_attributes(
+ select_statement,
+ compiler,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=False,
+ )
+
if toplevel and (
select_statement._with_options
or select_statement._memoized_select_entities
):
- self.attributes = {}
for (
memoized_entities
@@ -1028,9 +1067,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if opt._is_compile_state:
opt.process_compile_state(self)
- else:
- self.attributes = {}
-
# uncomment to print out the context.attributes structure
# after it's been set up above
# self._dump_option_struct()
diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py
index 58244c462..c02f7af4c 100644
--- a/test/orm/test_relationship_criteria.py
+++ b/test/orm/test_relationship_criteria.py
@@ -3,10 +3,12 @@ import random
from sqlalchemy import Column
from sqlalchemy import DateTime
+from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import exc as sa_exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import orm
@@ -14,6 +16,8 @@ from sqlalchemy import select
from sqlalchemy import sql
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import union
+from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import column_property
from sqlalchemy.orm import defer
@@ -588,6 +592,238 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
"FROM users WHERE users.name != :name_1",
)
+ @testing.variation("style", ["direct_union", "from_statement"])
+ @testing.variation("add_nested_union", [True, False])
+ def test_select_mapper_columns_w_union_mapper_criteria(
+ self, multi_mixin_fixture, style: testing.Variation, add_nested_union
+ ):
+ """test #9635"""
+ HasFoob, Order, Item = multi_mixin_fixture
+
+ stmt = (
+ select(Order.id, Order.description)
+ .where(Order.id > 8)
+ .union(select(Order.id, Order.description).where(Order.id <= 8))
+ )
+
+ if add_nested_union:
+ stmt = union(
+ stmt,
+ union(
+ select(Item.id, Item.description).where(Item.id <= 8),
+ select(Item.id, Item.description).where(Item.id > 8),
+ ),
+ )
+
+ if style.direct_union:
+ stmt = stmt.options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.description != "name",
+ include_aliases=True,
+ )
+ )
+ elif style.from_statement:
+
+ stmt = (
+ select(Order.id, Order.description)
+ .from_statement(stmt)
+ .options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.description != "name",
+ include_aliases=True,
+ )
+ )
+ )
+
+ else:
+ style.fail()
+
+ if add_nested_union:
+ # the criteria is embedded into all UNIONS regardless of nesting.
+ self.assert_compile(
+ stmt,
+ "(SELECT orders.id, orders.description FROM orders WHERE "
+ "orders.id > :id_1 AND orders.description != :description_1 "
+ "UNION SELECT orders.id, orders.description FROM orders WHERE "
+ "orders.id <= :id_2 AND orders.description != :description_2) "
+ "UNION (SELECT items.id, items.description FROM items WHERE "
+ "items.id <= :id_3 AND items.description != :description_3 "
+ "UNION SELECT items.id, items.description FROM items WHERE "
+ "items.id > :id_4 AND items.description != :description_4)",
+ checkparams={
+ "id_1": 8,
+ "description_1": "name",
+ "id_2": 8,
+ "description_2": "name",
+ "id_3": 8,
+ "description_3": "name",
+ "id_4": 8,
+ "description_4": "name",
+ },
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT orders.id, orders.description FROM orders WHERE "
+ "orders.id > :id_1 AND orders.description != :description_1 "
+ "UNION SELECT orders.id, orders.description FROM orders WHERE "
+ "orders.id <= :id_2 AND orders.description != :description_2",
+ checkparams={
+ "description_1": "name",
+ "description_2": "name",
+ "id_1": 8,
+ "id_2": 8,
+ },
+ )
+
+ def test_select_mapper_columns_w_core_dml_mapper_criteria(
+ self, multi_mixin_fixture
+ ):
+ """test #9635"""
+ HasFoob, Order, Item = multi_mixin_fixture
+
+ stmt = (
+ insert(Order)
+ .from_select(
+ ["id", "description"],
+ select(Order.id, Order.description).where(Order.id > 8),
+ )
+ .options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.description != "name",
+ include_aliases=True,
+ )
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO orders (id, description) SELECT orders.id, "
+ "orders.description FROM orders WHERE orders.id > :id_1 "
+ "AND orders.description != :description_1",
+ checkparams={"description_1": "name", "id_1": 8},
+ )
+
+ @testing.variation("update_is_orm", [True, False])
+ def test_select_mapper_columns_w_core_cte_update_mapper_criteria(
+ self, multi_mixin_fixture, update_is_orm
+ ):
+ """test #9635"""
+ HasFoob, Order, Item = multi_mixin_fixture
+
+ cte = select(Order).cte("pd")
+
+ if update_is_orm:
+ stmt = (
+ update(Order)
+ .where(Order.id == cte.c.id)
+ .values(description="newname")
+ )
+ else:
+ stmt = (
+ update(Order.__table__)
+ .where(Order.__table__.c.id == cte.c.id)
+ .values(description="newname")
+ )
+
+ stmt = stmt.options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.description != "name",
+ include_aliases=True,
+ )
+ )
+
+ if update_is_orm:
+ self.assert_compile(
+ stmt,
+ "WITH pd AS (SELECT orders.id AS id, "
+ "orders.user_id AS user_id, "
+ "orders.address_id AS address_id, "
+ "orders.description AS description, orders.isopen AS isopen "
+ "FROM orders WHERE orders.description != %(description_1)s) "
+ "UPDATE orders SET description=%(description)s "
+ "FROM pd WHERE orders.id = pd.id "
+ "AND orders.description != %(description_2)s",
+ dialect="postgresql",
+ checkparams={
+ "description": "newname",
+ "description_1": "name",
+ "description_2": "name",
+ },
+ )
+ else:
+ # non ORM update, no criteria, but criteria still gets rendered
+ # inside the SELECT
+ self.assert_compile(
+ stmt,
+ "WITH pd AS (SELECT orders.id AS id, "
+ "orders.user_id AS user_id, "
+ "orders.address_id AS address_id, "
+ "orders.description AS description, orders.isopen AS isopen "
+ "FROM orders WHERE orders.description != %(description_1)s) "
+ "UPDATE orders SET description=%(description)s "
+ "FROM pd WHERE orders.id = pd.id",
+ dialect="postgresql",
+ checkparams={
+ "description": "newname",
+ "description_1": "name",
+ },
+ )
+
+ @testing.variation("delete_is_orm", [True, False])
+ def test_select_mapper_columns_w_core_cte_delete_mapper_criteria(
+ self, multi_mixin_fixture, delete_is_orm
+ ):
+ """test #9635"""
+ HasFoob, Order, Item = multi_mixin_fixture
+
+ cte = select(Order).cte("pd")
+
+ if delete_is_orm:
+ stmt = delete(Order).where(Order.id == cte.c.id)
+ else:
+ stmt = delete(Order.__table__).where(
+ Order.__table__.c.id == cte.c.id
+ )
+
+ stmt = stmt.options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.description != "name",
+ include_aliases=True,
+ )
+ )
+
+ if delete_is_orm:
+ self.assert_compile(
+ stmt,
+ "WITH pd AS (SELECT orders.id AS id, orders.user_id AS "
+ "user_id, orders.address_id AS address_id, "
+ "orders.description AS description, orders.isopen AS isopen "
+ "FROM orders WHERE orders.description != %(description_1)s) "
+ "DELETE FROM orders USING pd WHERE orders.id = pd.id "
+ "AND orders.description != %(description_2)s",
+ dialect="postgresql",
+ checkparams={"description_1": "name", "description_2": "name"},
+ )
+ else:
+ # non ORM update, no criteria, but criteria still gets rendered
+ # inside the SELECT
+ self.assert_compile(
+ stmt,
+ "WITH pd AS (SELECT orders.id AS id, orders.user_id AS "
+ "user_id, orders.address_id AS address_id, "
+ "orders.description AS description, orders.isopen AS isopen "
+ "FROM orders WHERE orders.description != %(description_1)s) "
+ "DELETE FROM orders USING pd WHERE orders.id = pd.id",
+ dialect="postgresql",
+ checkparams={"description_1": "name"},
+ )
+
def test_select_join_mapper_mapper_criteria(self, user_address_fixture):
User, Address = user_address_fixture