diff options
author | Eric Masseran <eric.masseran@gmail.com> | 2021-10-08 10:02:58 -0400 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2021-10-12 22:46:57 +0000 |
commit | ee9b8836a160484733baa556c5d3ade4810aa999 (patch) | |
tree | 623c4fa6e17d2366934b931a9695f12dc1a34e9f /lib/sqlalchemy/sql/compiler.py | |
parent | de9db9940fbcf32ccd93169d2ed6aa874869b84d (diff) | |
download | sqlalchemy-ee9b8836a160484733baa556c5d3ade4810aa999.tar.gz |
Fix recursive CTE to support nesting
Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter
introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using
:paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION
would not compile correctly. Additionally makes some adjustments so that
the :class:`_sql.CTE` construct creates a correct cache key.
Pull request courtesy Eric Masseran.
Fixes: #4123
> This has not been caught by the tests because the nesting recursive
queries there did not union against itself, eg there was only the i
root clause...
- Now tests are real recursive queries
- Add tests on aliased nested CTEs (recursive or not)
- Adapt the `_restates` attribute to use it as a reference
- Add some docs around to explain some variables usage
Closes: #7133
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7133
Pull-request-sha: 2633f34f7f5336a4a85bd3f71d07bca33ce27a2c
Change-Id: I15512c94e1bc1f52afc619d82057ca647d274e92
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 92 |
1 files changed, 58 insertions, 34 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 333ed36f4..efcfe0e51 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -840,10 +840,17 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] self.ctes = util.OrderedDict() - # Detect same CTE references - self.ctes_by_name = {} - self.level_by_ctes = {} + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name)] + self.level_name_by_cte = {} + self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -2515,8 +2522,6 @@ class SQLCompiler(Compiled): ): self._init_cte_state() - cte_level = len(self.stack) if cte.nesting else 1 - kwargs["visiting_cte"] = cte cte_name = cte.name @@ -2527,44 +2532,60 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte in self.level_by_ctes: - cte_level = self.level_by_ctes[cte] + _reference_cte = cte._get_reference_cte() + + if _reference_cte in self.level_name_by_cte: + cte_level, _ = self.level_name_by_cte[_reference_cte] + assert _ == cte_name + else: + cte_level = len(self.stack) if cte.nesting else 1 cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_level_name] + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. - if cte in existing_cte._restates or cte is existing_cte: + if cte is existing_cte._restates or cte is existing_cte: is_new_cte = False - elif existing_cte in cte._restates: + elif existing_cte is cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + # TODO: determine if these assertions are correct. they + # pass for current test cases + # assert existing_cte_reference_cte is _reference_cte + # assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " "the same name: %r" % cte_name ) - if asfrom or is_new_cte: - if cte._cte_alias is not None: - pre_alias_cte = cte._cte_alias - cte_pre_alias_name = cte._cte_alias.name - if isinstance(cte_pre_alias_name, elements._truncated_label): - cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name - ) - else: - pre_alias_cte = cte - cte_pre_alias_name = None + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_level_name] = cte + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name if ( "autocommit" in cte.element._execution_options @@ -2649,7 +2670,6 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text - self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3475,7 +3495,9 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level = self.level_by_ctes[cte] + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) @@ -3484,14 +3506,6 @@ class SQLCompiler(Compiled): ctes[cte] = self.ctes[cte] - del self.ctes[cte] - del self.level_by_ctes[cte] - - cte_name = cte.name - if isinstance(cte_name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte_name) - - del self.ctes_by_name[(cte_level, cte_name)] else: ctes = self.ctes @@ -3508,6 +3522,16 @@ class SQLCompiler(Compiled): cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + return cte_text def get_cte_preamble(self, recursive): |