summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2022-08-18 12:30:20 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-08-30 08:43:53 +0200
commitb3db6c8dcb5145f7d45eff517bcd96460475c879 (patch)
treeca51349fab4db9de0f86dcb315c24caa02ae1e2a /django/db/models/sql/compiler.py
parent5d12650ed9269acb3cba97fd70e8df2e35a55a54 (diff)
downloaddjango-b3db6c8dcb5145f7d45eff517bcd96460475c879.tar.gz
Fixed #21204 -- Tracked field deferrals by field instead of models.
This ensures field deferral works properly when a model is involved more than once in the same query with a distinct deferral mask.
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py52
1 files changed, 28 insertions, 24 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 858142913b..96d10b9eda 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -256,8 +256,9 @@ class SQLCompiler:
select.append((RawSQL(sql, params), alias))
select_idx += 1
assert not (self.query.select and self.query.default_cols)
+ select_mask = self.query.get_select_mask()
if self.query.default_cols:
- cols = self.get_default_columns()
+ cols = self.get_default_columns(select_mask)
else:
# self.query.select is a special case. These columns never go to
# any model.
@@ -278,7 +279,7 @@ class SQLCompiler:
select_idx += 1
if self.query.select_related:
- related_klass_infos = self.get_related_selections(select)
+ related_klass_infos = self.get_related_selections(select, select_mask)
klass_info["related_klass_infos"] = related_klass_infos
def get_select_from_parent(klass_info):
@@ -870,7 +871,9 @@ class SQLCompiler:
# Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before)
- def get_default_columns(self, start_alias=None, opts=None, from_parent=None):
+ def get_default_columns(
+ self, select_mask, start_alias=None, opts=None, from_parent=None
+ ):
"""
Compute the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
@@ -886,7 +889,6 @@ class SQLCompiler:
if opts is None:
if (opts := self.query.get_meta()) is None:
return result
- only_load = self.deferred_to_columns()
start_alias = start_alias or self.query.get_initial_alias()
# The 'seen_models' is used to optimize checking the needed parent
# alias for a given field. This also includes None -> start_alias to
@@ -912,7 +914,7 @@ class SQLCompiler:
# parent model data is already present in the SELECT clause,
# and we want to avoid reloading the same data again.
continue
- if field.model in only_load and field.attname not in only_load[field.model]:
+ if select_mask and field not in select_mask:
continue
alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias)
@@ -1063,6 +1065,7 @@ class SQLCompiler:
def get_related_selections(
self,
select,
+ select_mask,
opts=None,
root_alias=None,
cur_depth=1,
@@ -1095,7 +1098,6 @@ class SQLCompiler:
if not opts:
opts = self.query.get_meta()
root_alias = self.query.get_initial_alias()
- only_load = self.deferred_to_columns()
# Setup for the case when only particular related fields should be
# included in the related selection.
@@ -1109,7 +1111,6 @@ class SQLCompiler:
klass_info["related_klass_infos"] = related_klass_infos
for f in opts.fields:
- field_model = f.model._meta.concrete_model
fields_found.add(f.name)
if restricted:
@@ -1129,10 +1130,9 @@ class SQLCompiler:
else:
next = False
- if not select_related_descend(
- f, restricted, requested, only_load.get(field_model)
- ):
+ if not select_related_descend(f, restricted, requested, select_mask):
continue
+ related_select_mask = select_mask.get(f) or {}
klass_info = {
"model": f.remote_field.model,
"field": f,
@@ -1148,7 +1148,7 @@ class SQLCompiler:
_, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
alias = joins[-1]
columns = self.get_default_columns(
- start_alias=alias, opts=f.remote_field.model._meta
+ related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
)
for col in columns:
select_fields.append(len(select))
@@ -1156,6 +1156,7 @@ class SQLCompiler:
klass_info["select_fields"] = select_fields
next_klass_infos = self.get_related_selections(
select,
+ related_select_mask,
f.remote_field.model._meta,
alias,
cur_depth + 1,
@@ -1171,8 +1172,9 @@ class SQLCompiler:
if o.field.unique and not o.many_to_many
]
for f, model in related_fields:
+ related_select_mask = select_mask.get(f) or {}
if not select_related_descend(
- f, restricted, requested, only_load.get(model), reverse=True
+ f, restricted, requested, related_select_mask, reverse=True
):
continue
@@ -1195,7 +1197,10 @@ class SQLCompiler:
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta, from_parent=opts.model
+ related_select_mask,
+ start_alias=alias,
+ opts=model._meta,
+ from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
@@ -1203,7 +1208,13 @@ class SQLCompiler:
klass_info["select_fields"] = select_fields
next = requested.get(f.related_query_name(), {})
next_klass_infos = self.get_related_selections(
- select, model._meta, alias, cur_depth + 1, next, restricted
+ select,
+ related_select_mask,
+ model._meta,
+ alias,
+ cur_depth + 1,
+ next,
+ restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
@@ -1239,7 +1250,9 @@ class SQLCompiler:
}
related_klass_infos.append(klass_info)
select_fields = []
+ field_select_mask = select_mask.get((name, f)) or {}
columns = self.get_default_columns(
+ field_select_mask,
start_alias=alias,
opts=model._meta,
from_parent=opts.model,
@@ -1251,6 +1264,7 @@ class SQLCompiler:
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
select,
+ field_select_mask,
opts=model._meta,
root_alias=alias,
cur_depth=cur_depth + 1,
@@ -1377,16 +1391,6 @@ class SQLCompiler:
)
return result
- def deferred_to_columns(self):
- """
- Convert the self.deferred_loading data structure to mapping of table
- names to sets of column names which are to be loaded. Return the
- dictionary.
- """
- columns = {}
- self.query.deferred_to_data(columns)
- return columns
-
def get_converters(self, expressions):
converters = {}
for i, expression in enumerate(expressions):