summaryrefslogtreecommitdiff
path: root/django/db/models/sql/query.py
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2022-08-18 12:30:20 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-08-30 08:43:53 +0200
commitb3db6c8dcb5145f7d45eff517bcd96460475c879 (patch)
treeca51349fab4db9de0f86dcb315c24caa02ae1e2a /django/db/models/sql/query.py
parent5d12650ed9269acb3cba97fd70e8df2e35a55a54 (diff)
downloaddjango-b3db6c8dcb5145f7d45eff517bcd96460475c879.tar.gz
Fixed #21204 -- Tracked field deferrals by field instead of models.
This ensures field deferral works properly when a model is involved more than once in the same query with a distinct deferral mask.
Diffstat (limited to 'django/db/models/sql/query.py')
-rw-r--r--django/db/models/sql/query.py153
1 files changed, 63 insertions, 90 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 14ed0c0a63..8419dc0d54 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -718,7 +718,61 @@ class Query(BaseExpression):
self.order_by = rhs.order_by or self.order_by
self.extra_order_by = rhs.extra_order_by or self.extra_order_by
- def deferred_to_data(self, target):
+ def _get_defer_select_mask(self, opts, mask, select_mask=None):
+ if select_mask is None:
+ select_mask = {}
+ select_mask[opts.pk] = {}
+ # All concrete fields that are not part of the defer mask must be
+ # loaded. If a relational field is encountered it gets added to the
+ # mask for it be considered if `select_related` and the cycle continues
+ # by recursively caling this function.
+ for field in opts.concrete_fields:
+ field_mask = mask.pop(field.name, None)
+ if field_mask is None:
+ select_mask.setdefault(field, {})
+ elif field_mask:
+ if not field.is_relation:
+ raise FieldError(next(iter(field_mask)))
+ field_select_mask = select_mask.setdefault(field, {})
+ related_model = field.remote_field.model._meta.concrete_model
+ self._get_defer_select_mask(
+ related_model._meta, field_mask, field_select_mask
+ )
+ # Remaining defer entries must be references to reverse relationships.
+ # The following code is expected to raise FieldError if it encounters
+ # a malformed defer entry.
+ for field_name, field_mask in mask.items():
+ if filtered_relation := self._filtered_relations.get(field_name):
+ relation = opts.get_field(filtered_relation.relation_name)
+ field_select_mask = select_mask.setdefault((field_name, relation), {})
+ field = relation.field
+ else:
+ field = opts.get_field(field_name).field
+ field_select_mask = select_mask.setdefault(field, {})
+ related_model = field.model._meta.concrete_model
+ self._get_defer_select_mask(
+ related_model._meta, field_mask, field_select_mask
+ )
+ return select_mask
+
+ def _get_only_select_mask(self, opts, mask, select_mask=None):
+ if select_mask is None:
+ select_mask = {}
+ select_mask[opts.pk] = {}
+ # Only include fields mentioned in the mask.
+ for field_name, field_mask in mask.items():
+ field = opts.get_field(field_name)
+ field_select_mask = select_mask.setdefault(field, {})
+ if field_mask:
+ if not field.is_relation:
+ raise FieldError(next(iter(field_mask)))
+ related_model = field.remote_field.model._meta.concrete_model
+ self._get_only_select_mask(
+ related_model._meta, field_mask, field_select_mask
+ )
+ return select_mask
+
+ def get_select_mask(self):
"""
Convert the self.deferred_loading data structure to an alternate data
structure, describing the field that *will* be loaded. This is used to
@@ -726,81 +780,19 @@ class Query(BaseExpression):
QuerySet class to work out which fields are being initialized on each
model. Models that have all their fields included aren't mentioned in
the result, only those that have field restrictions in place.
-
- The "target" parameter is the instance that is populated (in place).
"""
field_names, defer = self.deferred_loading
if not field_names:
- return
- orig_opts = self.get_meta()
- seen = {}
- must_include = {orig_opts.concrete_model: {orig_opts.pk}}
+ return {}
+ mask = {}
for field_name in field_names:
- parts = field_name.split(LOOKUP_SEP)
- cur_model = self.model._meta.concrete_model
- opts = orig_opts
- for name in parts[:-1]:
- old_model = cur_model
- if name in self._filtered_relations:
- name = self._filtered_relations[name].relation_name
- source = opts.get_field(name)
- if is_reverse_o2o(source):
- cur_model = source.related_model
- else:
- cur_model = source.remote_field.model
- cur_model = cur_model._meta.concrete_model
- opts = cur_model._meta
- # Even if we're "just passing through" this model, we must add
- # both the current model's pk and the related reference field
- # (if it's not a reverse relation) to the things we select.
- if not is_reverse_o2o(source):
- must_include[old_model].add(source)
- add_to_dict(must_include, cur_model, opts.pk)
- field = opts.get_field(parts[-1])
- is_reverse_object = field.auto_created and not field.concrete
- model = field.related_model if is_reverse_object else field.model
- model = model._meta.concrete_model
- if model == opts.model:
- model = cur_model
- if not is_reverse_o2o(field):
- add_to_dict(seen, model, field)
-
+ part_mask = mask
+ for part in field_name.split(LOOKUP_SEP):
+ part_mask = part_mask.setdefault(part, {})
+ opts = self.get_meta()
if defer:
- # We need to load all fields for each model, except those that
- # appear in "seen" (for all models that appear in "seen"). The only
- # slight complexity here is handling fields that exist on parent
- # models.
- workset = {}
- for model, values in seen.items():
- for field in model._meta.local_fields:
- if field not in values:
- m = field.model._meta.concrete_model
- add_to_dict(workset, m, field)
- for model, values in must_include.items():
- # If we haven't included a model in workset, we don't add the
- # corresponding must_include fields for that model, since an
- # empty set means "include all fields". That's why there's no
- # "else" branch here.
- if model in workset:
- workset[model].update(values)
- for model, fields in workset.items():
- target[model] = {f.attname for f in fields}
- else:
- for model, values in must_include.items():
- if model in seen:
- seen[model].update(values)
- else:
- # As we've passed through this model, but not explicitly
- # included any fields, we have to make sure it's mentioned
- # so that only the "must include" fields are pulled in.
- seen[model] = values
- # Now ensure that every model in the inheritance chain is mentioned
- # in the parent list. Again, it must be mentioned to ensure that
- # only "must include" fields are pulled in.
- for model in orig_opts.get_parent_list():
- seen.setdefault(model, set())
- for model, fields in seen.items():
- target[model] = {f.attname for f in fields}
+ return self._get_defer_select_mask(opts, mask)
+ return self._get_only_select_mask(opts, mask)
def table_alias(self, table_name, create=False, filtered_relation=None):
"""
@@ -2583,25 +2575,6 @@ def get_order_dir(field, default="ASC"):
return field, dirn[0]
-def add_to_dict(data, key, value):
- """
- Add "value" to the set of values for "key", whether or not "key" already
- exists.
- """
- if key in data:
- data[key].add(value)
- else:
- data[key] = {value}
-
-
-def is_reverse_o2o(field):
- """
- Check if the given field is reverse-o2o. The field is expected to be some
- sort of relation field or related object.
- """
- return field.is_relation and field.one_to_one and not field.concrete
-
-
class JoinPromoter:
"""
A class to abstract away join promotion problems for complex filter