author    django-bot <ops@djangoproject.com>  2022-02-03 20:24:19 +0100
committer Mariusz Felisiak <felisiak.mariusz@gmail.com>  2022-02-07 20:37:05 +0100
commit    9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
tree      f0506b668a013d0063e5fba3dbf4863b466713ba /django/db/models/sql/compiler.py
parent    f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
download  django-9c19aff7c7561e3a82978a272ecdaad40dda5c00.tar.gz
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--  django/db/models/sql/compiler.py | 770
1 file changed, 487 insertions, 283 deletions
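
The hunks below are mechanical: Black normalizes string literals to double quotes, wraps anything past its default 88-character line length, and adds a trailing comma when an argument list is split one element per line. A minimal sketch of the same transformations through Black's Python API (illustrative only; it assumes `pip install black`, and the Black version Django pins may format edge cases differently):

    import black

    # Quote normalization: the one-character hunks in this diff.
    src = "self.quote_cache = {'*': '*'}\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # self.quote_cache = {"*": "*"}

    # Line wrapping: a 95-column statement gets its call exploded,
    # as in the select_for_update hunks in as_sql() below.
    src = (
        "raise TransactionManagementError("
        "'select_for_update cannot be used outside of a transaction.')\n"
    )
    print(black.format_str(src, mode=black.Mode()), end="")
    # raise TransactionManagementError(
    #     "select_for_update cannot be used outside of a transaction."
    # )
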
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index d405a203ee..13a7ec7263 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -11,7 +11,12 @@ from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
from django.db.models.functions import Cast, Random
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (
- CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE,
+ CURSOR,
+ GET_ITERATOR_CHUNK_SIZE,
+ MULTI,
+ NO_RESULTS,
+ ORDER_DIR,
+ SINGLE,
)
from django.db.models.sql.query import Query, get_order_dir
from django.db.transaction import TransactionManagementError
@@ -23,7 +28,7 @@ from django.utils.regex_helper import _lazy_re_compile
class SQLCompiler:
# Multiline ordering SQL clause may appear from RawSQL.
ordering_parts = _lazy_re_compile(
- r'^(.*)\s(?:ASC|DESC).*',
+ r"^(.*)\s(?:ASC|DESC).*",
re.MULTILINE | re.DOTALL,
)
@@ -34,7 +39,7 @@ class SQLCompiler:
# Some queries, e.g. coalesced aggregation, need to be executed even if
# they would return an empty result set.
self.elide_empty = elide_empty
- self.quote_cache = {'*': '*'}
+ self.quote_cache = {"*": "*"}
# The select, klass_info, and annotations are needed by QuerySet.iterator()
# these are set as a side-effect of executing the query. Note that we calculate
# separately a list of extra select columns needed for grammatical correctness
@@ -46,9 +51,9 @@ class SQLCompiler:
def __repr__(self):
return (
- f'<{self.__class__.__qualname__} '
- f'model={self.query.model.__qualname__} '
- f'connection={self.connection!r} using={self.using!r}>'
+ f"<{self.__class__.__qualname__} "
+ f"model={self.query.model.__qualname__} "
+ f"connection={self.connection!r} using={self.using!r}>"
)
def setup_query(self):
@@ -118,16 +123,14 @@ class SQLCompiler:
# when we have public API way of forcing the GROUP BY clause.
# Converts string references to expressions.
for expr in self.query.group_by:
- if not hasattr(expr, 'as_sql'):
+ if not hasattr(expr, "as_sql"):
expressions.append(self.query.resolve_ref(expr))
else:
expressions.append(expr)
# Note that even if the group_by is set, it is only the minimal
# set to group by. So, we need to add cols in select, order_by, and
# having into the select in any case.
- ref_sources = {
- expr.source for expr in expressions if isinstance(expr, Ref)
- }
+ ref_sources = {expr.source for expr in expressions if isinstance(expr, Ref)}
for expr, _, _ in select:
# Skip members of the select clause that are already included
# by reference.
@@ -169,8 +172,10 @@ class SQLCompiler:
for expr in expressions:
# Is this a reference to query's base table primary key? If the
# expression isn't a Col-like, then skip the expression.
- if (getattr(expr, 'target', None) == self.query.model._meta.pk and
- getattr(expr, 'alias', None) == self.query.base_table):
+ if (
+ getattr(expr, "target", None) == self.query.model._meta.pk
+ and getattr(expr, "alias", None) == self.query.base_table
+ ):
pk = expr
break
# If the main model's primary key is in the query, group by that
@@ -178,13 +183,17 @@ class SQLCompiler:
# that don't have a primary key included in the grouped columns.
if pk:
pk_aliases = {
- expr.alias for expr in expressions
- if hasattr(expr, 'target') and expr.target.primary_key
+ expr.alias
+ for expr in expressions
+ if hasattr(expr, "target") and expr.target.primary_key
}
expressions = [pk] + [
- expr for expr in expressions
- if expr in having or (
- getattr(expr, 'alias', None) is not None and expr.alias not in pk_aliases
+ expr
+ for expr in expressions
+ if expr in having
+ or (
+ getattr(expr, "alias", None) is not None
+ and expr.alias not in pk_aliases
)
]
elif self.connection.features.allows_group_by_selected_pks:
@@ -195,16 +204,21 @@ class SQLCompiler:
# Unmanaged models are excluded because they could be representing
# database views on which the optimization might not be allowed.
pks = {
- expr for expr in expressions
+ expr
+ for expr in expressions
if (
- hasattr(expr, 'target') and
- expr.target.primary_key and
- self.connection.features.allows_group_by_selected_pks_on_model(expr.target.model)
+ hasattr(expr, "target")
+ and expr.target.primary_key
+ and self.connection.features.allows_group_by_selected_pks_on_model(
+ expr.target.model
+ )
)
}
aliases = {expr.alias for expr in pks}
expressions = [
- expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases
+ expr
+ for expr in expressions
+ if expr in pks or getattr(expr, "alias", None) not in aliases
]
return expressions
@@ -248,8 +262,8 @@ class SQLCompiler:
select.append((col, None))
select_idx += 1
klass_info = {
- 'model': self.query.model,
- 'select_fields': select_list,
+ "model": self.query.model,
+ "select_fields": select_list,
}
for alias, annotation in self.query.annotation_select.items():
annotations[alias] = select_idx
@@ -258,14 +272,16 @@ class SQLCompiler:
if self.query.select_related:
related_klass_infos = self.get_related_selections(select)
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
def get_select_from_parent(klass_info):
- for ki in klass_info['related_klass_infos']:
- if ki['from_parent']:
- ki['select_fields'] = (klass_info['select_fields'] +
- ki['select_fields'])
+ for ki in klass_info["related_klass_infos"]:
+ if ki["from_parent"]:
+ ki["select_fields"] = (
+ klass_info["select_fields"] + ki["select_fields"]
+ )
get_select_from_parent(ki)
+
get_select_from_parent(klass_info)
ret = []
@@ -273,10 +289,12 @@ class SQLCompiler:
try:
sql, params = self.compile(col)
except EmptyResultSet:
- empty_result_set_value = getattr(col, 'empty_result_set_value', NotImplemented)
+ empty_result_set_value = getattr(
+ col, "empty_result_set_value", NotImplemented
+ )
if empty_result_set_value is NotImplemented:
# Select a predicate that's always False.
- sql, params = '0', ()
+ sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
else:
@@ -297,12 +315,12 @@ class SQLCompiler:
else:
ordering = []
if self.query.standard_ordering:
- default_order, _ = ORDER_DIR['ASC']
+ default_order, _ = ORDER_DIR["ASC"]
else:
- default_order, _ = ORDER_DIR['DESC']
+ default_order, _ = ORDER_DIR["DESC"]
for field in ordering:
- if hasattr(field, 'resolve_expression'):
+ if hasattr(field, "resolve_expression"):
if isinstance(field, Value):
# output_field must be resolved for constants.
field = Cast(field, field.output_field)
@@ -313,12 +331,12 @@ class SQLCompiler:
field.reverse_ordering()
yield field, False
continue
- if field == '?': # random
+ if field == "?": # random
yield OrderBy(Random()), False
continue
col, order = get_order_dir(field, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
if col in self.query.annotation_select:
# Reference to expression in SELECT clause
@@ -345,13 +363,15 @@ class SQLCompiler:
yield OrderBy(expr, descending=descending), False
continue
- if '.' in field:
+ if "." in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
- table, col = col.split('.', 1)
+ table, col = col.split(".", 1)
yield (
OrderBy(
- RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []),
+ RawSQL(
+ "%s.%s" % (self.quote_name_unless_alias(table), col), []
+ ),
descending=descending,
),
False,
@@ -361,7 +381,10 @@ class SQLCompiler:
if self.query.extra and col in self.query.extra:
if col in self.query.extra_select:
yield (
- OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending),
+ OrderBy(
+ Ref(col, RawSQL(*self.query.extra[col])),
+ descending=descending,
+ ),
True,
)
else:
@@ -378,7 +401,9 @@ class SQLCompiler:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
yield from self.find_ordering_name(
- field, self.query.get_meta(), default_order=default_order,
+ field,
+ self.query.get_meta(),
+ default_order=default_order,
)
def get_order_by(self):
@@ -409,19 +434,21 @@ class SQLCompiler:
):
continue
if src == sel_expr:
- resolved.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
+ resolved.set_source_expressions([RawSQL("%d" % (idx + 1), ())])
break
else:
if col_alias:
- raise DatabaseError('ORDER BY term does not match any column in the result set.')
+ raise DatabaseError(
+ "ORDER BY term does not match any column in the result set."
+ )
# Add column used in ORDER BY clause to the selected
# columns and to each combined query.
order_by_idx = len(self.query.select) + 1
- col_name = f'__orderbycol{order_by_idx}'
+ col_name = f"__orderbycol{order_by_idx}"
for q in self.query.combined_queries:
q.add_annotation(expr_src, col_name)
self.query.add_select_col(resolved, col_name)
- resolved.set_source_expressions([RawSQL(f'{order_by_idx}', ())])
+ resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
sql, params = self.compile(resolved)
# Don't add the same column twice, but the order direction is
# not taken into account so we strip it. When this entire method
@@ -453,9 +480,14 @@ class SQLCompiler:
"""
if name in self.quote_cache:
return self.quote_cache[name]
- if ((name in self.query.alias_map and name not in self.query.table_map) or
- name in self.query.extra_select or (
- self.query.external_aliases.get(name) and name not in self.query.table_map)):
+ if (
+ (name in self.query.alias_map and name not in self.query.table_map)
+ or name in self.query.extra_select
+ or (
+ self.query.external_aliases.get(name)
+ and name not in self.query.table_map
+ )
+ ):
self.quote_cache[name] = name
return name
r = self.connection.ops.quote_name(name)
@@ -463,7 +495,7 @@ class SQLCompiler:
return r
def compile(self, node):
- vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
+ vendor_impl = getattr(node, "as_" + self.connection.vendor, None)
if vendor_impl:
sql, params = vendor_impl(self, self.connection)
else:
@@ -474,14 +506,19 @@ class SQLCompiler:
features = self.connection.features
compilers = [
query.get_compiler(self.using, self.connection, self.elide_empty)
- for query in self.query.combined_queries if not query.is_empty()
+ for query in self.query.combined_queries
+ if not query.is_empty()
]
if not features.supports_slicing_ordering_in_compound:
for query, compiler in zip(self.query.combined_queries, compilers):
if query.low_mark or query.high_mark:
- raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "LIMIT/OFFSET not allowed in subqueries of compound statements."
+ )
if compiler.get_order_by():
- raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "ORDER BY not allowed in subqueries of compound statements."
+ )
parts = ()
for compiler in compilers:
try:
@@ -490,41 +527,45 @@ class SQLCompiler:
# the query on all combined queries, if not already set.
if not compiler.query.values_select and self.query.values_select:
compiler.query = compiler.query.clone()
- compiler.query.set_values((
- *self.query.extra_select,
- *self.query.values_select,
- *self.query.annotation_select,
- ))
+ compiler.query.set_values(
+ (
+ *self.query.extra_select,
+ *self.query.values_select,
+ *self.query.annotation_select,
+ )
+ )
part_sql, part_args = compiler.as_sql()
if compiler.query.combinator:
# Wrap in a subquery if wrapping in parentheses isn't
# supported.
if not features.supports_parentheses_in_compound:
- part_sql = 'SELECT * FROM ({})'.format(part_sql)
+ part_sql = "SELECT * FROM ({})".format(part_sql)
# Add parentheses when combining with compound query if not
# already added for all compound queries.
elif (
- self.query.subquery or
- not features.supports_slicing_ordering_in_compound
+ self.query.subquery
+ or not features.supports_slicing_ordering_in_compound
):
- part_sql = '({})'.format(part_sql)
+ part_sql = "({})".format(part_sql)
parts += ((part_sql, part_args),)
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
- if combinator == 'union' or (combinator == 'difference' and parts):
+ if combinator == "union" or (combinator == "difference" and parts):
continue
raise
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
- if all and combinator == 'union':
- combinator_sql += ' ALL'
- braces = '{}'
+ if all and combinator == "union":
+ combinator_sql += " ALL"
+ braces = "{}"
if not self.query.subquery and features.supports_slicing_ordering_in_compound:
- braces = '({})'
- sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
- result = [' {} '.format(combinator_sql).join(sql_parts)]
+ braces = "({})"
+ sql_parts, args_parts = zip(
+ *((braces.format(sql), args) for sql, args in parts)
+ )
+ result = [" {} ".format(combinator_sql).join(sql_parts)]
params = []
for part in args_parts:
params.extend(part)
@@ -543,27 +584,39 @@ class SQLCompiler:
extra_select, order_by, group_by = self.pre_sql_setup()
for_update_part = None
# Is a LIMIT/OFFSET clause needed?
- with_limit_offset = with_limits and (self.query.high_mark is not None or self.query.low_mark)
+ with_limit_offset = with_limits and (
+ self.query.high_mark is not None or self.query.low_mark
+ )
combinator = self.query.combinator
features = self.connection.features
if combinator:
- if not getattr(features, 'supports_select_{}'.format(combinator)):
- raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
- result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
+ if not getattr(features, "supports_select_{}".format(combinator)):
+ raise NotSupportedError(
+ "{} is not supported on this database backend.".format(
+ combinator
+ )
+ )
+ result, params = self.get_combinator_sql(
+ combinator, self.query.combinator_all
+ )
else:
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
try:
- where, w_params = self.compile(self.where) if self.where is not None else ('', [])
+ where, w_params = (
+ self.compile(self.where) if self.where is not None else ("", [])
+ )
except EmptyResultSet:
if self.elide_empty:
raise
# Use a predicate that's always False.
- where, w_params = '0 = 1', []
- having, h_params = self.compile(self.having) if self.having is not None else ("", [])
- result = ['SELECT']
+ where, w_params = "0 = 1", []
+ having, h_params = (
+ self.compile(self.having) if self.having is not None else ("", [])
+ )
+ result = ["SELECT"]
params = []
if self.query.distinct:
@@ -578,27 +631,38 @@ class SQLCompiler:
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
- s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
+ s_sql = "%s AS %s" % (
+ s_sql,
+ self.connection.ops.quote_name(alias),
+ )
elif with_col_aliases:
- s_sql = '%s AS %s' % (
+ s_sql = "%s AS %s" % (
s_sql,
- self.connection.ops.quote_name('col%d' % col_idx),
+ self.connection.ops.quote_name("col%d" % col_idx),
)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)
- result += [', '.join(out_cols), 'FROM', *from_]
+ result += [", ".join(out_cols), "FROM", *from_]
params.extend(f_params)
- if self.query.select_for_update and self.connection.features.has_select_for_update:
+ if (
+ self.query.select_for_update
+ and self.connection.features.has_select_for_update
+ ):
if self.connection.get_autocommit():
- raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')
+ raise TransactionManagementError(
+ "select_for_update cannot be used outside of a transaction."
+ )
- if with_limit_offset and not self.connection.features.supports_select_for_update_with_limit:
+ if (
+ with_limit_offset
+ and not self.connection.features.supports_select_for_update_with_limit
+ ):
raise NotSupportedError(
- 'LIMIT/OFFSET is not supported with '
- 'select_for_update on this database backend.'
+ "LIMIT/OFFSET is not supported with "
+ "select_for_update on this database backend."
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
@@ -607,16 +671,31 @@ class SQLCompiler:
# If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the
# backend doesn't support it, raise NotSupportedError to
# prevent a possible deadlock.
- if nowait and not self.connection.features.has_select_for_update_nowait:
- raise NotSupportedError('NOWAIT is not supported on this database backend.')
- elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
- raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
+ if (
+ nowait
+ and not self.connection.features.has_select_for_update_nowait
+ ):
+ raise NotSupportedError(
+ "NOWAIT is not supported on this database backend."
+ )
+ elif (
+ skip_locked
+ and not self.connection.features.has_select_for_update_skip_locked
+ ):
+ raise NotSupportedError(
+ "SKIP LOCKED is not supported on this database backend."
+ )
elif of and not self.connection.features.has_select_for_update_of:
- raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
- elif no_key and not self.connection.features.has_select_for_no_key_update:
raise NotSupportedError(
- 'FOR NO KEY UPDATE is not supported on this '
- 'database backend.'
+ "FOR UPDATE OF is not supported on this database backend."
+ )
+ elif (
+ no_key
+ and not self.connection.features.has_select_for_no_key_update
+ ):
+ raise NotSupportedError(
+ "FOR NO KEY UPDATE is not supported on this "
+ "database backend."
)
for_update_part = self.connection.ops.for_update_sql(
nowait=nowait,
@@ -629,7 +708,7 @@ class SQLCompiler:
result.append(for_update_part)
if where:
- result.append('WHERE %s' % where)
+ result.append("WHERE %s" % where)
params.extend(w_params)
grouping = []
@@ -638,30 +717,39 @@ class SQLCompiler:
params.extend(g_params)
if grouping:
if distinct_fields:
- raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
+ raise NotImplementedError(
+ "annotate() + distinct(fields) is not implemented."
+ )
order_by = order_by or self.connection.ops.force_no_ordering()
- result.append('GROUP BY %s' % ', '.join(grouping))
+ result.append("GROUP BY %s" % ", ".join(grouping))
if self._meta_ordering:
order_by = None
if having:
- result.append('HAVING %s' % having)
+ result.append("HAVING %s" % having)
params.extend(h_params)
if self.query.explain_info:
- result.insert(0, self.connection.ops.explain_query_prefix(
- self.query.explain_info.format,
- **self.query.explain_info.options
- ))
+ result.insert(
+ 0,
+ self.connection.ops.explain_query_prefix(
+ self.query.explain_info.format,
+ **self.query.explain_info.options,
+ ),
+ )
if order_by:
ordering = []
for _, (o_sql, o_params, _) in order_by:
ordering.append(o_sql)
params.extend(o_params)
- result.append('ORDER BY %s' % ', '.join(ordering))
+ result.append("ORDER BY %s" % ", ".join(ordering))
if with_limit_offset:
- result.append(self.connection.ops.limit_offset_sql(self.query.low_mark, self.query.high_mark))
+ result.append(
+ self.connection.ops.limit_offset_sql(
+ self.query.low_mark, self.query.high_mark
+ )
+ )
if for_update_part and not self.connection.features.for_update_after_from:
result.append(for_update_part)
@@ -677,23 +765,30 @@ class SQLCompiler:
sub_params = []
for index, (select, _, alias) in enumerate(self.select, start=1):
if not alias and with_col_aliases:
- alias = 'col%d' % index
+ alias = "col%d" % index
if alias:
- sub_selects.append("%s.%s" % (
- self.connection.ops.quote_name('subquery'),
- self.connection.ops.quote_name(alias),
- ))
+ sub_selects.append(
+ "%s.%s"
+ % (
+ self.connection.ops.quote_name("subquery"),
+ self.connection.ops.quote_name(alias),
+ )
+ )
else:
- select_clone = select.relabeled_clone({select.alias: 'subquery'})
- subselect, subparams = select_clone.as_sql(self, self.connection)
+ select_clone = select.relabeled_clone(
+ {select.alias: "subquery"}
+ )
+ subselect, subparams = select_clone.as_sql(
+ self, self.connection
+ )
sub_selects.append(subselect)
sub_params.extend(subparams)
- return 'SELECT %s FROM (%s) subquery' % (
- ', '.join(sub_selects),
- ' '.join(result),
+ return "SELECT %s FROM (%s) subquery" % (
+ ", ".join(sub_selects),
+ " ".join(result),
), tuple(sub_params + params)
- return ' '.join(result), tuple(params)
+ return " ".join(result), tuple(params)
finally:
# Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before)
@@ -726,8 +821,13 @@ class SQLCompiler:
# will assign None if the field belongs to this model.
if model == opts.model:
model = None
- if from_parent and model is not None and issubclass(
- from_parent._meta.concrete_model, model._meta.concrete_model):
+ if (
+ from_parent
+ and model is not None
+ and issubclass(
+ from_parent._meta.concrete_model, model._meta.concrete_model
+ )
+ ):
# Avoid loading data for already loaded parents.
# We end up here in the case select_related() resolution
# proceeds from parent model to child model. In that case the
@@ -736,8 +836,7 @@ class SQLCompiler:
continue
if field.model in only_load and field.attname not in only_load[field.model]:
continue
- alias = self.query.join_parent_model(opts, model, start_alias,
- seen_models)
+ alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias)
result.append(column)
return result
@@ -755,7 +854,9 @@ class SQLCompiler:
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
+ _, targets, alias, joins, path, _, transform_function = self._setup_joins(
+ parts, opts, None
+ )
targets, alias, _ = self.query.trim_joins(targets, joins, path)
for target in targets:
if name in self.query.annotation_select:
@@ -766,46 +867,63 @@ class SQLCompiler:
params.append(p)
return result, params
- def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
- already_seen=None):
+ def find_ordering_name(
+ self, name, opts, alias=None, default_order="ASC", already_seen=None
+ ):
"""
Return the table alias (the name might be ambiguous, the alias will
not be) and column name for ordering by the given 'name' parameter.
The 'name' is of the form 'field1__field2__...__fieldN'.
"""
name, order = get_order_dir(name, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
pieces = name.split(LOOKUP_SEP)
- field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
+ (
+ field,
+ targets,
+ alias,
+ joins,
+ path,
+ opts,
+ transform_function,
+ ) = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model unless it is the pk
# shortcut or the attribute name of the field that is specified.
if (
- field.is_relation and
- opts.ordering and
- getattr(field, 'attname', None) != pieces[-1] and
- name != 'pk'
+ field.is_relation
+ and opts.ordering
+ and getattr(field, "attname", None) != pieces[-1]
+ and name != "pk"
):
# Firstly, avoid infinite loops.
already_seen = already_seen or set()
- join_tuple = tuple(getattr(self.query.alias_map[j], 'join_cols', None) for j in joins)
+ join_tuple = tuple(
+ getattr(self.query.alias_map[j], "join_cols", None) for j in joins
+ )
if join_tuple in already_seen:
- raise FieldError('Infinite loop caused by ordering.')
+ raise FieldError("Infinite loop caused by ordering.")
already_seen.add(join_tuple)
results = []
for item in opts.ordering:
- if hasattr(item, 'resolve_expression') and not isinstance(item, OrderBy):
+ if hasattr(item, "resolve_expression") and not isinstance(
+ item, OrderBy
+ ):
item = item.desc() if descending else item.asc()
if isinstance(item, OrderBy):
results.append((item, False))
continue
- results.extend(self.find_ordering_name(item, opts, alias,
- order, already_seen))
+ results.extend(
+ self.find_ordering_name(item, opts, alias, order, already_seen)
+ )
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
- return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
+ return [
+ (OrderBy(transform_function(t, alias), descending=descending), False)
+ for t in targets
+ ]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -816,7 +934,9 @@ class SQLCompiler:
match. Executing SQL where this is not true is an error.
"""
alias = alias or self.query.get_initial_alias()
- field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
+ field, targets, opts, joins, path, transform_function = self.query.setup_joins(
+ pieces, opts, alias
+ )
alias = joins[-1]
return field, targets, alias, joins, path, opts, transform_function
@@ -850,25 +970,39 @@ class SQLCompiler:
# Only add the alias if it's not already present (the table_alias()
# call increments the refcount, so an alias refcount of one means
# this is the only reference).
- if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
- result.append(', %s' % self.quote_name_unless_alias(alias))
+ if (
+ alias not in self.query.alias_map
+ or self.query.alias_refcount[alias] == 1
+ ):
+ result.append(", %s" % self.quote_name_unless_alias(alias))
return result, params
- def get_related_selections(self, select, opts=None, root_alias=None, cur_depth=1,
- requested=None, restricted=None):
+ def get_related_selections(
+ self,
+ select,
+ opts=None,
+ root_alias=None,
+ cur_depth=1,
+ requested=None,
+ restricted=None,
+ ):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
(for example, cur_depth=1 means we are looking at models with direct
connections to the root model).
"""
+
def _get_field_choices():
direct_choices = (f.name for f in opts.fields if f.is_relation)
reverse_choices = (
f.field.related_query_name()
- for f in opts.related_objects if f.field.unique
+ for f in opts.related_objects
+ if f.field.unique
+ )
+ return chain(
+ direct_choices, reverse_choices, self.query._filtered_relations
)
- return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth:
@@ -889,7 +1023,7 @@ class SQLCompiler:
requested = self.query.select_related
def get_related_klass_infos(klass_info, related_klass_infos):
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
for f in opts.fields:
field_model = f.model._meta.concrete_model
@@ -903,37 +1037,48 @@ class SQLCompiler:
if next or f.name in requested:
raise FieldError(
"Non-relational field given in select_related: '%s'. "
- "Choices are: %s" % (
+ "Choices are: %s"
+ % (
f.name,
- ", ".join(_get_field_choices()) or '(none)',
+ ", ".join(_get_field_choices()) or "(none)",
)
)
else:
next = False
- if not select_related_descend(f, restricted, requested,
- only_load.get(field_model)):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(field_model)
+ ):
continue
klass_info = {
- 'model': f.remote_field.model,
- 'field': f,
- 'reverse': False,
- 'local_setter': f.set_cached_value,
- 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
- 'from_parent': False,
+ "model": f.remote_field.model,
+ "field": f,
+ "reverse": False,
+ "local_setter": f.set_cached_value,
+ "remote_setter": f.remote_field.set_cached_value
+ if f.unique
+ else lambda x, y: None,
+ "from_parent": False,
}
related_klass_infos.append(klass_info)
select_fields = []
- _, _, _, joins, _, _ = self.query.setup_joins(
- [f.name], opts, root_alias)
+ _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
alias = joins[-1]
- columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
+ columns = self.get_default_columns(
+ start_alias=alias, opts=f.remote_field.model._meta
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_klass_infos = self.get_related_selections(
- select, f.remote_field.model._meta, alias, cur_depth + 1, next, restricted)
+ select,
+ f.remote_field.model._meta,
+ alias,
+ cur_depth + 1,
+ next,
+ restricted,
+ )
get_related_klass_infos(klass_info, next_klass_infos)
if restricted:
@@ -943,36 +1088,40 @@ class SQLCompiler:
if o.field.unique and not o.many_to_many
]
for f, model in related_fields:
- if not select_related_descend(f, restricted, requested,
- only_load.get(model), reverse=True):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(model), reverse=True
+ ):
continue
related_field_name = f.related_query_name()
fields_found.add(related_field_name)
- join_info = self.query.setup_joins([related_field_name], opts, root_alias)
+ join_info = self.query.setup_joins(
+ [related_field_name], opts, root_alias
+ )
alias = join_info.joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': f.remote_field.set_cached_value,
- 'remote_setter': f.set_cached_value,
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": f.remote_field.set_cached_value,
+ "remote_setter": f.set_cached_value,
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta, from_parent=opts.model)
+ start_alias=alias, opts=model._meta, from_parent=opts.model
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next = requested.get(f.related_query_name(), {})
next_klass_infos = self.get_related_selections(
- select, model._meta, alias, cur_depth + 1,
- next, restricted)
+ select, model._meta, alias, cur_depth + 1, next, restricted
+ )
get_related_klass_infos(klass_info, next_klass_infos)
def local_setter(obj, from_obj):
@@ -989,32 +1138,40 @@ class SQLCompiler:
break
if name in self.query._filtered_relations:
fields_found.add(name)
- f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
+ f, _, join_opts, joins, _, _ = self.query.setup_joins(
+ [name], opts, root_alias
+ )
model = join_opts.model
alias = joins[-1]
- from_parent = issubclass(model, opts.model) and model is not opts.model
+ from_parent = (
+ issubclass(model, opts.model) and model is not opts.model
+ )
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': local_setter,
- 'remote_setter': partial(remote_setter, name),
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": local_setter,
+ "remote_setter": partial(remote_setter, name),
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta,
+ start_alias=alias,
+ opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
- select, opts=model._meta, root_alias=alias,
- cur_depth=cur_depth + 1, requested=next_requested,
+ select,
+ opts=model._meta,
+ root_alias=alias,
+ cur_depth=cur_depth + 1,
+ requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
@@ -1022,10 +1179,11 @@ class SQLCompiler:
if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError(
- 'Invalid field name(s) given in select_related: %s. '
- 'Choices are: %s' % (
- ', '.join(invalid_fields),
- ', '.join(_get_field_choices()) or '(none)',
+ "Invalid field name(s) given in select_related: %s. "
+ "Choices are: %s"
+ % (
+ ", ".join(invalid_fields),
+ ", ".join(_get_field_choices()) or "(none)",
)
)
return related_klass_infos
@@ -1035,21 +1193,22 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
+
def _get_parent_klass_info(klass_info):
- concrete_model = klass_info['model']._meta.concrete_model
+ concrete_model = klass_info["model"]._meta.concrete_model
for parent_model, parent_link in concrete_model._meta.parents.items():
parent_list = parent_model._meta.get_parent_list()
yield {
- 'model': parent_model,
- 'field': parent_link,
- 'reverse': False,
- 'select_fields': [
+ "model": parent_model,
+ "field": parent_link,
+ "reverse": False,
+ "select_fields": [
select_index
- for select_index in klass_info['select_fields']
+ for select_index in klass_info["select_fields"]
# Selected columns from a model or its parents.
if (
- self.select[select_index][0].target.model == parent_model or
- self.select[select_index][0].target.model in parent_list
+ self.select[select_index][0].target.model == parent_model
+ or self.select[select_index][0].target.model in parent_list
)
],
}
@@ -1062,8 +1221,8 @@ class SQLCompiler:
select_fields is filled recursively, so it also contains fields
from the parent models.
"""
- concrete_model = klass_info['model']._meta.concrete_model
- for select_index in klass_info['select_fields']:
+ concrete_model = klass_info["model"]._meta.concrete_model
+ for select_index in klass_info["select_fields"]:
if self.select[select_index][0].target.model == concrete_model:
return self.select[select_index][0]
@@ -1074,10 +1233,10 @@ class SQLCompiler:
parent_path, klass_info = queue.popleft()
if parent_path is None:
path = []
- yield 'self'
+ yield "self"
else:
- field = klass_info['field']
- if klass_info['reverse']:
+ field = klass_info["field"]
+ if klass_info["reverse"]:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path)
@@ -1087,25 +1246,26 @@ class SQLCompiler:
)
queue.extend(
(path, klass_info)
- for klass_info in klass_info.get('related_klass_infos', [])
+ for klass_info in klass_info.get("related_klass_infos", [])
)
+
if not self.klass_info:
return []
result = []
invalid_names = []
for name in self.query.select_for_update_of:
klass_info = self.klass_info
- if name == 'self':
+ if name == "self":
col = _get_first_selected_col_from_model(klass_info)
else:
for part in name.split(LOOKUP_SEP):
klass_infos = (
- *klass_info.get('related_klass_infos', []),
+ *klass_info.get("related_klass_infos", []),
*_get_parent_klass_info(klass_info),
)
for related_klass_info in klass_infos:
- field = related_klass_info['field']
- if related_klass_info['reverse']:
+ field = related_klass_info["field"]
+ if related_klass_info["reverse"]:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
@@ -1124,11 +1284,12 @@ class SQLCompiler:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
- 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
- 'Only relational fields followed in the query are allowed. '
- 'Choices are: %s.' % (
- ', '.join(invalid_names),
- ', '.join(_get_field_choices()),
+ "Invalid field name(s) given in select_for_update(of=(...)): %s. "
+ "Only relational fields followed in the query are allowed. "
+ "Choices are: %s."
+ % (
+ ", ".join(invalid_names),
+ ", ".join(_get_field_choices()),
)
)
return result
@@ -1164,12 +1325,19 @@ class SQLCompiler:
row[pos] = value
yield row
- def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False,
- chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def results_iter(
+ self,
+ results=None,
+ tuple_expected=False,
+ chunked_fetch=False,
+ chunk_size=GET_ITERATOR_CHUNK_SIZE,
+ ):
"""Return an iterator over the results from executing this query."""
if results is None:
- results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
- fields = [s[0] for s in self.select[0:self.col_count]]
+ results = self.execute_sql(
+ MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size
+ )
+ fields = [s[0] for s in self.select[0 : self.col_count]]
converters = self.get_converters(fields)
rows = chain.from_iterable(results)
if converters:
@@ -1185,7 +1353,9 @@ class SQLCompiler:
"""
return bool(self.execute_sql(SINGLE))
- def execute_sql(self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def execute_sql(
+ self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
+ ):
"""
Run the query against the database and return the result(s). The
return value is a single data item if result_type is SINGLE, or an
@@ -1226,7 +1396,7 @@ class SQLCompiler:
try:
val = cursor.fetchone()
if val:
- return val[0:self.col_count]
+ return val[0 : self.col_count]
return val
finally:
# done with the cursor
@@ -1236,7 +1406,8 @@ class SQLCompiler:
return
result = cursor_iter(
- cursor, self.connection.features.empty_fetchmany_value,
+ cursor,
+ self.connection.features.empty_fetchmany_value,
self.col_count if self.has_extra_select else None,
chunk_size,
)
@@ -1254,21 +1425,22 @@ class SQLCompiler:
for index, select_col in enumerate(self.query.select):
lhs_sql, lhs_params = self.compile(select_col)
- rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
- self.query.where.add(
- RawSQL('%s = %s' % (lhs_sql, rhs), lhs_params), 'AND')
+ rhs = "%s.%s" % (qn(alias), qn2(columns[index]))
+ self.query.where.add(RawSQL("%s = %s" % (lhs_sql, rhs), lhs_params), "AND")
sql, params = self.as_sql()
- return 'EXISTS (%s)' % sql, params
+ return "EXISTS (%s)" % sql, params
def explain_query(self):
result = list(self.execute_sql())
# Some backends return 1 item tuples with strings, and others return
# tuples with integers and strings. Flatten them out into strings.
- output_formatter = json.dumps if self.query.explain_info.format == 'json' else str
+ output_formatter = (
+ json.dumps if self.query.explain_info.format == "json" else str
+ )
for row in result[0]:
if not isinstance(row, str):
- yield ' '.join(output_formatter(c) for c in row)
+ yield " ".join(output_formatter(c) for c in row)
else:
yield row
@@ -1289,16 +1461,16 @@ class SQLInsertCompiler(SQLCompiler):
if field is None:
# A field value of None means the value is raw.
sql, params = val, []
- elif hasattr(val, 'as_sql'):
+ elif hasattr(val, "as_sql"):
# This is an expression, let's compile it.
sql, params = self.compile(val)
- elif hasattr(field, 'get_placeholder'):
+ elif hasattr(field, "get_placeholder"):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
sql, params = field.get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
- sql, params = '%s', [val]
+ sql, params = "%s", [val]
# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
@@ -1314,24 +1486,26 @@ class SQLInsertCompiler(SQLCompiler):
Prepare a value to be used in a query by resolving it if it is an
expression and otherwise calling the field's get_db_prep_save().
"""
- if hasattr(value, 'resolve_expression'):
- value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(value, "resolve_expression"):
+ value = value.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
# Don't allow values containing Col expressions. They refer to
# existing columns on a row, but in the case of insert the row
# doesn't exist yet.
if value.contains_column_references:
raise ValueError(
'Failed to insert expression "%s" on %s. F() expressions '
- 'can only be used to update, not to insert.' % (value, field)
+ "can only be used to update, not to insert." % (value, field)
)
if value.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, value)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, value)
)
if value.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query (%s=%r).'
+ "Window expressions are not allowed in this query (%s=%r)."
% (field.name, value)
)
else:
@@ -1390,25 +1564,32 @@ class SQLInsertCompiler(SQLCompiler):
insert_statement = self.connection.ops.insert_statement(
on_conflict=self.query.on_conflict,
)
- result = ['%s %s' % (insert_statement, qn(opts.db_table))]
+ result = ["%s %s" % (insert_statement, qn(opts.db_table))]
fields = self.query.fields or [opts.pk]
- result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
+ result.append("(%s)" % ", ".join(qn(f.column) for f in fields))
if self.query.fields:
value_rows = [
- [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
+ [
+ self.prepare_value(field, self.pre_save_val(field, obj))
+ for field in fields
+ ]
for obj in self.query.objs
]
else:
# An empty object.
- value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
+ value_rows = [
+ [self.connection.ops.pk_default_value()] for _ in self.query.objs
+ ]
fields = [None]
# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
- can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert)
+ can_bulk = (
+ not self.returning_fields and self.connection.features.has_bulk_insert
+ )
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
@@ -1418,9 +1599,14 @@ class SQLInsertCompiler(SQLCompiler):
self.query.update_fields,
self.query.unique_fields,
)
- if self.returning_fields and self.connection.features.can_return_columns_from_insert:
+ if (
+ self.returning_fields
+ and self.connection.features.can_return_columns_from_insert
+ ):
if self.connection.features.can_return_rows_from_bulk_insert:
- result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+ result.append(
+ self.connection.ops.bulk_insert_sql(fields, placeholder_rows)
+ )
params = param_rows
else:
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
@@ -1429,7 +1615,9 @@ class SQLInsertCompiler(SQLCompiler):
result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
- r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
+ r_sql, self.returning_params = self.connection.ops.return_insert_columns(
+ self.returning_fields
+ )
if r_sql:
result.append(r_sql)
params += [self.returning_params]
@@ -1450,8 +1638,9 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, returning_fields=None):
assert not (
- returning_fields and len(self.query.objs) != 1 and
- not self.connection.features.can_return_rows_from_bulk_insert
+ returning_fields
+ and len(self.query.objs) != 1
+ and not self.connection.features.can_return_rows_from_bulk_insert
)
opts = self.query.get_meta()
self.returning_fields = returning_fields
@@ -1460,17 +1649,29 @@ class SQLInsertCompiler(SQLCompiler):
cursor.execute(sql, params)
if not self.returning_fields:
return []
- if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
+ if (
+ self.connection.features.can_return_rows_from_bulk_insert
+ and len(self.query.objs) > 1
+ ):
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
- rows = [self.connection.ops.fetch_returned_insert_columns(
- cursor, self.returning_params,
- )]
+ rows = [
+ self.connection.ops.fetch_returned_insert_columns(
+ cursor,
+ self.returning_params,
+ )
+ ]
else:
- rows = [(self.connection.ops.last_insert_id(
- cursor, opts.db_table, opts.pk.column,
- ),)]
+ rows = [
+ (
+ self.connection.ops.last_insert_id(
+ cursor,
+ opts.db_table,
+ opts.pk.column,
+ ),
+ )
+ ]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
@@ -1489,7 +1690,7 @@ class SQLDeleteCompiler(SQLCompiler):
def _expr_refs_base_model(cls, expr, base_model):
if isinstance(expr, Query):
return expr.model == base_model
- if not hasattr(expr, 'get_source_expressions'):
+ if not hasattr(expr, "get_source_expressions"):
return False
return any(
cls._expr_refs_base_model(source_expr, base_model)
@@ -1500,17 +1701,17 @@ class SQLDeleteCompiler(SQLCompiler):
def contains_self_reference_subquery(self):
return any(
self._expr_refs_base_model(expr, self.query.model)
- for expr in chain(self.query.annotations.values(), self.query.where.children)
+ for expr in chain(
+ self.query.annotations.values(), self.query.where.children
+ )
)
def _as_sql(self, query):
- result = [
- 'DELETE FROM %s' % self.quote_name_unless_alias(query.base_table)
- ]
+ result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
where, params = self.compile(query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(params)
def as_sql(self):
"""
@@ -1523,16 +1724,14 @@ class SQLDeleteCompiler(SQLCompiler):
innerq.__class__ = Query
innerq.clear_select_clause()
pk = self.query.model._meta.pk
- innerq.select = [
- pk.get_col(self.query.get_initial_alias())
- ]
+ innerq.select = [pk.get_col(self.query.get_initial_alias())]
outerq = Query(self.query.model)
if not self.connection.features.update_can_self_select:
# Force the materialization of the inner query to allow reference
# to the target table on MySQL.
sql, params = innerq.get_compiler(connection=self.connection).as_sql()
- innerq = RawSQL('SELECT * FROM (%s) subquery' % sql, params)
- outerq.add_filter('pk__in', innerq)
+ innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params)
+ outerq.add_filter("pk__in", innerq)
return self._as_sql(outerq)
@@ -1544,23 +1743,25 @@ class SQLUpdateCompiler(SQLCompiler):
"""
self.pre_sql_setup()
if not self.query.values:
- return '', ()
+ return "", ()
qn = self.quote_name_unless_alias
values, update_params = [], []
for field, model, val in self.query.values:
- if hasattr(val, 'resolve_expression'):
- val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(val, "resolve_expression"):
+ val = val.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
if val.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
if val.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Window expressions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
- elif hasattr(val, 'prepare_database_save'):
+ elif hasattr(val, "prepare_database_save"):
if field.remote_field:
val = field.get_db_prep_save(
val.prepare_database_save(field),
@@ -1576,29 +1777,29 @@ class SQLUpdateCompiler(SQLCompiler):
val = field.get_db_prep_save(val, connection=self.connection)
# Getting the placeholder for the field.
- if hasattr(field, 'get_placeholder'):
+ if hasattr(field, "get_placeholder"):
placeholder = field.get_placeholder(val, self, self.connection)
else:
- placeholder = '%s'
+ placeholder = "%s"
name = field.column
- if hasattr(val, 'as_sql'):
+ if hasattr(val, "as_sql"):
sql, params = self.compile(val)
- values.append('%s = %s' % (qn(name), placeholder % sql))
+ values.append("%s = %s" % (qn(name), placeholder % sql))
update_params.extend(params)
elif val is not None:
- values.append('%s = %s' % (qn(name), placeholder))
+ values.append("%s = %s" % (qn(name), placeholder))
update_params.append(val)
else:
- values.append('%s = NULL' % qn(name))
+ values.append("%s = NULL" % qn(name))
table = self.query.base_table
result = [
- 'UPDATE %s SET' % qn(table),
- ', '.join(values),
+ "UPDATE %s SET" % qn(table),
+ ", ".join(values),
]
where, params = self.compile(self.query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(update_params + params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type):
"""
@@ -1644,7 +1845,9 @@ class SQLUpdateCompiler(SQLCompiler):
query.add_fields([query.get_meta().pk.name])
super().pre_sql_setup()
- must_pre_select = count > 1 and not self.connection.features.update_can_self_select
+ must_pre_select = (
+ count > 1 and not self.connection.features.update_can_self_select
+ )
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
@@ -1656,11 +1859,11 @@ class SQLUpdateCompiler(SQLCompiler):
idents = []
for rows in query.get_compiler(self.using).execute_sql(MULTI):
idents.extend(r[0] for r in rows)
- self.query.add_filter('pk__in', idents)
+ self.query.add_filter("pk__in", idents)
self.query.related_ids = idents
else:
# The fast path. Filters and updates in one query.
- self.query.add_filter('pk__in', query)
+ self.query.add_filter("pk__in", query)
self.query.reset_refcounts(refcounts_before)
@@ -1677,13 +1880,14 @@ class SQLAggregateCompiler(SQLCompiler):
sql.append(ann_sql)
params.extend(ann_params)
self.col_count = len(self.query.annotation_select)
- sql = ', '.join(sql)
+ sql = ", ".join(sql)
params = tuple(params)
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
- self.using, elide_empty=self.elide_empty,
+ self.using,
+ elide_empty=self.elide_empty,
).as_sql(with_col_aliases=True)
- sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
+ sql = "SELECT %s FROM (%s) subquery" % (sql, inner_query_sql)
params = params + inner_query_params
return sql, params
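
For whole files the same pass is a command-line invocation, and a non-mutating check is the usual way to confirm a tree already matches; these are generic examples, not Django's pinned CI setup:

    python -m black django/db/models/sql/compiler.py   # rewrite in place
    python -m black --check --diff django/             # report only, don't modify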