diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-02-24 14:42:20 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-02-24 14:42:20 +0000 |
commit | 11333602c0f844e32f12af114f1dfcb160408fcf (patch) | |
tree | 8085956a2f83cbc57d1717dfd6a9eb17e3aa87ac /lib/sqlalchemy | |
parent | 878c37614efd311794aa50467dbb9e3fe972fdff (diff) | |
parent | bef67e58121704a9836e1e5ec2d361cd2086036c (diff) | |
download | sqlalchemy-11333602c0f844e32f12af114f1dfcb160408fcf.tar.gz |
Merge "support add_cte() for TextualSelect" into main
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/context.py | 64 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 18 |
2 files changed, 79 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index f51abde0c..63ed10d50 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -27,6 +27,7 @@ from .. import future from .. import inspect from .. import sql from .. import util +from ..sql import ClauseElement from ..sql import coercions from ..sql import expression from ..sql import roles @@ -486,8 +487,8 @@ class ORMFromStatementCompileState(ORMCompileState): entity.setup_compile_state(self) # we did the setup just to get primary columns. - self.statement = expression.TextualSelect( - self.statement, self.primary_columns, positional=False + self.statement = _AdHocColumnsStatement( + self.statement, self.primary_columns ) else: # allow TextualSelect with implicit columns as well @@ -514,6 +515,65 @@ class ORMFromStatementCompileState(ORMCompileState): return None +class _AdHocColumnsStatement(ClauseElement): + """internal object created to somewhat act like a SELECT when we + are selecting columns from a DML RETURNING. + + + """ + + __visit_name__ = None + + def __init__(self, text, columns): + self.element = text + self.column_args = [ + coercions.expect(roles.ColumnsClauseRole, c) for c in columns + ] + + def _generate_cache_key(self): + raise NotImplementedError() + + def _gen_cache_key(self, anon_map, bindparams): + raise NotImplementedError() + + def _compiler_dispatch( + self, compiler, compound_index=None, asfrom=False, **kw + ): + """provide a fixed _compiler_dispatch method.""" + + toplevel = not compiler.stack + entry = ( + compiler._default_stack_entry if toplevel else compiler.stack[-1] + ) + + populate_result_map = ( + toplevel + # these two might not be needed + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) + + if populate_result_map: + compiler._ordered_columns = ( + compiler._textual_ordered_columns + ) = False + + # enable looser result column matching. this is shown to be + # needed by test_query.py::TextTest + compiler._loose_column_name_matching = True + + for c in self.column_args: + compiler.process( + c, + within_columns_clause=True, + add_to_result_map=compiler._add_to_result_map, + ) + return compiler.process(self.element, **kw) + + @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): _already_joined_edges = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4a169f719..b140f9297 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1596,6 +1596,17 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + new_entry = { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": taf, + } + self.stack.append(new_entry) + + if taf._independent_ctes: + for cte in taf._independent_ctes: + cte._compiler_dispatch(self, **kw) + populate_result_map = ( toplevel or ( @@ -1623,7 +1634,12 @@ class SQLCompiler(Compiled): add_to_result_map=self._add_to_result_map, ) - return self.process(taf.element, **kw) + text = self.process(taf.element, **kw) + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + return text def visit_null(self, expr, **kw): return "NULL" |