summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
author: Simon Charette <charette.s@gmail.com> 2022-08-10 08:22:01 -0400
committer: Mariusz Felisiak <felisiak.mariusz@gmail.com> 2022-08-15 08:26:26 +0200
commit: f387d024fc75569d2a4a338bfda76cc2f328f627 (patch)
tree: 61994be69d4dfa545158f7e887f0673ec281e479 /django/db/models/sql/compiler.py
parent: f3f9d03edf17ccfa17263c7efa0b1350d1ac9278 (diff)
download: django-f387d024fc75569d2a4a338bfda76cc2f328f627.tar.gz
Refs #28333 -- Added partial support for filtering against window functions.
Adds support for joint predicates against window annotations through subquery wrapping, while maintaining errors for disjointed filter attempts. The "qualify" wording was used to refer to predicates against window annotations, as it's the name of a specialized Snowflake extension to SQL that is to window functions what HAVING is to aggregates. While not complete, the implementation should cover most of the common use cases for filtering against window functions without requiring the complex subquery pushdown and predicate re-aliasing machinery needed to deal with disjointed predicates against columns, aggregates, and window functions. A complete disjointed filtering implementation should likely be deferred until proper QUALIFY support lands or the ORM gains a proper subquery pushdown interface.
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py76
1 files changed, 75 insertions, 1 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 1c0ab2d212..858142913b 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -9,6 +9,7 @@ from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
from django.db.models.functions import Cast, Random
+from django.db.models.lookups import Lookup
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (
CURSOR,
@@ -73,7 +74,9 @@ class SQLCompiler:
"""
self.setup_query(with_col_aliases=with_col_aliases)
order_by = self.get_order_by()
- self.where, self.having = self.query.where.split_having()
+ self.where, self.having, self.qualify = self.query.where.split_having_qualify(
+ must_group_by=self.query.group_by is not None
+ )
extra_select = self.get_extra_select(order_by, self.select)
self.has_extra_select = bool(extra_select)
group_by = self.get_group_by(self.select + extra_select, order_by)
@@ -584,6 +587,74 @@ class SQLCompiler:
params.extend(part)
return result, params
+ def get_qualify_sql(self):  # Wrap the query in a subquery and apply the qualify (window) predicates outside it.
+ where_parts = []
+ if self.where:
+ where_parts.append(self.where)
+ if self.having:
+ where_parts.append(self.having)
+ inner_query = self.query.clone()
+ inner_query.subquery = True
+ inner_query.where = inner_query.where.__class__(where_parts)  # Inner query keeps only the where/having parts.
+ # Augment the inner query with any window function references that
+ # might have been masked via values() and alias(). If any masked
+ # aliases are added, the `if qual_aliases` branch below masks them
+ # again so that their data is not fetched.
+ select = {  # Maps each expression from get_select(...)[0] to its alias.
+ expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
+ }
+ qual_aliases = set()  # Aliases generated below for expressions missing from `select`.
+ replacements = {}  # Expression -> alias used to reference it in the outer WHERE.
+ expressions = list(self.qualify.leaves())  # Work list of expressions still to resolve.
+ while expressions:
+ expr = expressions.pop()
+ if select_alias := (select.get(expr) or replacements.get(expr)):  # Known alias: reuse it.
+ replacements[expr] = select_alias
+ elif isinstance(expr, Lookup):  # Resolve the lookup's source expressions instead.
+ expressions.extend(expr.get_source_expressions())
+ else:  # No alias yet: annotate the inner query under a generated "qualN" alias.
+ num_qual_alias = len(qual_aliases)
+ select_alias = f"qual{num_qual_alias}"
+ qual_aliases.add(select_alias)
+ inner_query.add_annotation(expr, select_alias)
+ replacements[expr] = select_alias
+ self.qualify = self.qualify.replace_expressions(  # Point the predicates at the aliases via Ref.
+ {expr: Ref(alias, expr) for expr, alias in replacements.items()}
+ )
+ inner_query_compiler = inner_query.get_compiler(
+ self.using, elide_empty=self.elide_empty
+ )
+ inner_sql, inner_params = inner_query_compiler.as_sql(
+ # The limits must be applied to the outer query to avoid pruning
+ # results too eagerly.
+ with_limits=False,
+ # Force unique aliasing of selected columns to avoid collisions
+ # and make rhs predicates referencing easier.
+ with_col_aliases=True,
+ )
+ qualify_sql, qualify_params = self.compile(self.qualify)
+ result = [  # Outer query: filters the aliased subquery with the qualify predicates.
+ "SELECT * FROM (",
+ inner_sql,
+ ")",
+ self.connection.ops.quote_name("qualify"),
+ "WHERE",
+ qualify_sql,
+ ]
+ if qual_aliases:
+ # If some select aliases were unmasked for filtering purposes they
+ # must be masked back.
+ cols = [self.connection.ops.quote_name(alias) for alias in select.values()]
+ result = [  # Second wrapper selecting only the aliases collected in `select`.
+ "SELECT",
+ ", ".join(cols),
+ "FROM (",
+ *result,
+ ")",
+ self.connection.ops.quote_name("qualify_mask"),
+ ]
+ return result, list(inner_params) + qualify_params  # (SQL fragments, params); inner params come first.
+
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Create the SQL for this query. Return the SQL string and list of
@@ -614,6 +685,9 @@ class SQLCompiler:
result, params = self.get_combinator_sql(
combinator, self.query.combinator_all
)
+ elif self.qualify:
+ result, params = self.get_qualify_sql()
+ order_by = None
else:
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'