diff options
author | Simon Charette <charette.s@gmail.com> | 2022-08-18 12:30:20 -0400 |
---|---|---|
committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-08-30 08:43:53 +0200 |
commit | b3db6c8dcb5145f7d45eff517bcd96460475c879 (patch) | |
tree | ca51349fab4db9de0f86dcb315c24caa02ae1e2a /django/db/models/sql/compiler.py | |
parent | 5d12650ed9269acb3cba97fd70e8df2e35a55a54 (diff) | |
download | django-b3db6c8dcb5145f7d45eff517bcd96460475c879.tar.gz |
Fixed #21204 -- Tracked field deferrals by field instead of models.
This ensures field deferral works properly when a model is involved
more than once in the same query with a distinct deferral mask.
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r-- | django/db/models/sql/compiler.py | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 858142913b..96d10b9eda 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -256,8 +256,9 @@ class SQLCompiler: select.append((RawSQL(sql, params), alias)) select_idx += 1 assert not (self.query.select and self.query.default_cols) + select_mask = self.query.get_select_mask() if self.query.default_cols: - cols = self.get_default_columns() + cols = self.get_default_columns(select_mask) else: # self.query.select is a special case. These columns never go to # any model. @@ -278,7 +279,7 @@ class SQLCompiler: select_idx += 1 if self.query.select_related: - related_klass_infos = self.get_related_selections(select) + related_klass_infos = self.get_related_selections(select, select_mask) klass_info["related_klass_infos"] = related_klass_infos def get_select_from_parent(klass_info): @@ -870,7 +871,9 @@ class SQLCompiler: # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) - def get_default_columns(self, start_alias=None, opts=None, from_parent=None): + def get_default_columns( + self, select_mask, start_alias=None, opts=None, from_parent=None + ): """ Compute the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -886,7 +889,6 @@ class SQLCompiler: if opts is None: if (opts := self.query.get_meta()) is None: return result - only_load = self.deferred_to_columns() start_alias = start_alias or self.query.get_initial_alias() # The 'seen_models' is used to optimize checking the needed parent # alias for a given field. This also includes None -> start_alias to @@ -912,7 +914,7 @@ class SQLCompiler: # parent model data is already present in the SELECT clause, # and we want to avoid reloading the same data again. continue - if field.model in only_load and field.attname not in only_load[field.model]: + if select_mask and field not in select_mask: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) @@ -1063,6 +1065,7 @@ class SQLCompiler: def get_related_selections( self, select, + select_mask, opts=None, root_alias=None, cur_depth=1, @@ -1095,7 +1098,6 @@ class SQLCompiler: if not opts: opts = self.query.get_meta() root_alias = self.query.get_initial_alias() - only_load = self.deferred_to_columns() # Setup for the case when only particular related fields should be # included in the related selection. @@ -1109,7 +1111,6 @@ class SQLCompiler: klass_info["related_klass_infos"] = related_klass_infos for f in opts.fields: - field_model = f.model._meta.concrete_model fields_found.add(f.name) if restricted: @@ -1129,10 +1130,9 @@ class SQLCompiler: else: next = False - if not select_related_descend( - f, restricted, requested, only_load.get(field_model) - ): + if not select_related_descend(f, restricted, requested, select_mask): continue + related_select_mask = select_mask.get(f) or {} klass_info = { "model": f.remote_field.model, "field": f, @@ -1148,7 +1148,7 @@ class SQLCompiler: _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias) alias = joins[-1] columns = self.get_default_columns( - start_alias=alias, opts=f.remote_field.model._meta + related_select_mask, start_alias=alias, opts=f.remote_field.model._meta ) for col in columns: select_fields.append(len(select)) @@ -1156,6 +1156,7 @@ class SQLCompiler: klass_info["select_fields"] = select_fields next_klass_infos = self.get_related_selections( select, + related_select_mask, f.remote_field.model._meta, alias, cur_depth + 1, @@ -1171,8 +1172,9 @@ class SQLCompiler: if o.field.unique and not o.many_to_many ] for f, model in related_fields: + related_select_mask = select_mask.get(f) or {} if not select_related_descend( - f, restricted, requested, only_load.get(model), reverse=True + f, restricted, requested, related_select_mask, reverse=True ): continue @@ -1195,7 +1197,10 @@ class SQLCompiler: related_klass_infos.append(klass_info) select_fields = [] columns = self.get_default_columns( - start_alias=alias, opts=model._meta, from_parent=opts.model + related_select_mask, + start_alias=alias, + opts=model._meta, + from_parent=opts.model, ) for col in columns: select_fields.append(len(select)) @@ -1203,7 +1208,13 @@ class SQLCompiler: klass_info["select_fields"] = select_fields next = requested.get(f.related_query_name(), {}) next_klass_infos = self.get_related_selections( - select, model._meta, alias, cur_depth + 1, next, restricted + select, + related_select_mask, + model._meta, + alias, + cur_depth + 1, + next, + restricted, ) get_related_klass_infos(klass_info, next_klass_infos) @@ -1239,7 +1250,9 @@ class SQLCompiler: } related_klass_infos.append(klass_info) select_fields = [] + field_select_mask = select_mask.get((name, f)) or {} columns = self.get_default_columns( + field_select_mask, start_alias=alias, opts=model._meta, from_parent=opts.model, @@ -1251,6 +1264,7 @@ class SQLCompiler: next_requested = requested.get(name, {}) next_klass_infos = self.get_related_selections( select, + field_select_mask, opts=model._meta, root_alias=alias, cur_depth=cur_depth + 1, @@ -1377,16 +1391,6 @@ class SQLCompiler: ) return result - def deferred_to_columns(self): - """ - Convert the self.deferred_loading data structure to mapping of table - names to sets of column names which are to be loaded. Return the - dictionary. - """ - columns = {} - self.query.deferred_to_data(columns) - return columns - def get_converters(self, expressions): converters = {} for i, expression in enumerate(expressions): |