diff options
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r-- | django/db/models/sql/compiler.py | 76 |
1 files changed, 75 insertions, 1 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1c0ab2d212..858142913b 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -9,6 +9,7 @@ from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value from django.db.models.functions import Cast, Random +from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend from django.db.models.sql.constants import ( CURSOR, @@ -73,7 +74,9 @@ class SQLCompiler: """ self.setup_query(with_col_aliases=with_col_aliases) order_by = self.get_order_by() - self.where, self.having = self.query.where.split_having() + self.where, self.having, self.qualify = self.query.where.split_having_qualify( + must_group_by=self.query.group_by is not None + ) extra_select = self.get_extra_select(order_by, self.select) self.has_extra_select = bool(extra_select) group_by = self.get_group_by(self.select + extra_select, order_by) @@ -584,6 +587,74 @@ class SQLCompiler: params.extend(part) return result, params + def get_qualify_sql(self): + where_parts = [] + if self.where: + where_parts.append(self.where) + if self.having: + where_parts.append(self.having) + inner_query = self.query.clone() + inner_query.subquery = True + inner_query.where = inner_query.where.__class__(where_parts) + # Augment the inner query with any window function references that + # might have been masked via values() and alias(). If any masked + # aliases are added they'll be masked again to avoid fetching + # the data in the `if qual_aliases` branch below. + select = { + expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0] + } + qual_aliases = set() + replacements = {} + expressions = list(self.qualify.leaves()) + while expressions: + expr = expressions.pop() + if select_alias := (select.get(expr) or replacements.get(expr)): + replacements[expr] = select_alias + elif isinstance(expr, Lookup): + expressions.extend(expr.get_source_expressions()) + else: + num_qual_alias = len(qual_aliases) + select_alias = f"qual{num_qual_alias}" + qual_aliases.add(select_alias) + inner_query.add_annotation(expr, select_alias) + replacements[expr] = select_alias + self.qualify = self.qualify.replace_expressions( + {expr: Ref(alias, expr) for expr, alias in replacements.items()} + ) + inner_query_compiler = inner_query.get_compiler( + self.using, elide_empty=self.elide_empty + ) + inner_sql, inner_params = inner_query_compiler.as_sql( + # The limits must be applied to the outer query to avoid pruning + # results too eagerly. + with_limits=False, + # Force unique aliasing of selected columns to avoid collisions + # and make rhs predicates referencing easier. + with_col_aliases=True, + ) + qualify_sql, qualify_params = self.compile(self.qualify) + result = [ + "SELECT * FROM (", + inner_sql, + ")", + self.connection.ops.quote_name("qualify"), + "WHERE", + qualify_sql, + ] + if qual_aliases: + # If some select aliases were unmasked for filtering purposes they + # must be masked back. + cols = [self.connection.ops.quote_name(alias) for alias in select.values()] + result = [ + "SELECT", + ", ".join(cols), + "FROM (", + *result, + ")", + self.connection.ops.quote_name("qualify_mask"), + ] + return result, list(inner_params) + qualify_params + def as_sql(self, with_limits=True, with_col_aliases=False): """ Create the SQL for this query. Return the SQL string and list of @@ -614,6 +685,9 @@ class SQLCompiler: result, params = self.get_combinator_sql( combinator, self.query.combinator_all ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None else: distinct_fields, distinct_params = self.get_distinct() # This must come after 'select', 'ordering', and 'distinct' |