diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 100 |
1 files changed, 55 insertions, 45 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7922054f8..13219ee68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -488,7 +488,7 @@ class SQLCompiler(Compiled): """ if False, means we can't be sure the list of entries in _result_columns is actually the rendered order. Usually - True unless using an unordered TextAsFrom. + True unless using an unordered TextualSelect. """ _numeric_binds = False @@ -916,8 +916,8 @@ class SQLCompiler(Compiled): ), ) - def visit_text_as_from( - self, taf, compound_index=None, asfrom=False, parens=True, **kw + def visit_textual_select( + self, taf, compound_index=None, asfrom=False, **kw ): toplevel = not self.stack @@ -943,10 +943,7 @@ class SQLCompiler(Compiled): add_to_result_map=self._add_to_result_map, ) - text = self.process(taf.element, **kw) - if asfrom and parens: - text = "(%s)" % text - return text + return self.process(taf.element, **kw) def visit_null(self, expr, **kw): return "NULL" @@ -1120,7 +1117,7 @@ class SQLCompiler(Compiled): return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( - self, cs, asfrom=False, parens=True, compound_index=0, **kwargs + self, cs, asfrom=False, compound_index=0, **kwargs ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] @@ -1143,16 +1140,13 @@ class SQLCompiler(Compiled): text = (" " + keyword + " ").join( ( c._compiler_dispatch( - self, - asfrom=asfrom, - parens=False, - compound_index=i, - **kwargs + self, asfrom=asfrom, compound_index=i, **kwargs ) for i, c in enumerate(cs.selects) ) ) + kwargs["include_table"] = False text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) text += self.order_by_clause(cs, **kwargs) text += ( @@ -1165,10 +1159,7 @@ class SQLCompiler(Compiled): text = self._render_cte_clause() + text self.stack.pop(-1) - if asfrom and parens: - return "(" + text + ")" - else: - return text + return text def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): attrname = "visit_%s_%s%s" % ( @@ -1682,8 +1673,11 @@ class SQLCompiler(Compiled): if self.positional: kwargs["positional_names"] = self.cte_positional[cte] = [] - text += " AS \n" + cte.element._compiler_dispatch( - self, asfrom=True, **kwargs + assert kwargs.get("subquery", False) is False + text += " AS \n(%s)" % ( + cte.element._compiler_dispatch( + self, asfrom=True, **kwargs + ), ) if cte._suffixes: @@ -1713,8 +1707,28 @@ class SQLCompiler(Compiled): ashint=False, iscrud=False, fromhints=None, + subquery=False, + lateral=False, + enclosing_alias=None, **kwargs ): + if enclosing_alias is not None and enclosing_alias.element is alias: + inner = alias.element._compiler_dispatch( + self, + asfrom=asfrom, + ashint=ashint, + iscrud=iscrud, + fromhints=fromhints, + lateral=lateral, + enclosing_alias=alias, + **kwargs + ) + if subquery and (asfrom or lateral): + inner = "(%s)" % (inner,) + return inner + else: + enclosing_alias = kwargs["enclosing_alias"] = alias + if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1724,12 +1738,15 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.element._compiler_dispatch( - self, asfrom=True, **kwargs - ) + self.get_render_as_alias_suffix( - self.preparer.format_alias(alias, alias_name) + inner = alias.element._compiler_dispatch( + self, asfrom=True, lateral=lateral, **kwargs ) + if subquery: + inner = "(%s)" % (inner,) + ret = inner + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) if fromhints and alias in fromhints: ret = self.format_from_hint_text( ret, alias, fromhints[alias], iscrud @@ -1737,7 +1754,14 @@ class SQLCompiler(Compiled): return ret else: - return alias.element._compiler_dispatch(self, **kwargs) + # note we cancel the "subquery" flag here as well + return alias.element._compiler_dispatch( + self, lateral=lateral, **kwargs + ) + + def visit_subquery(self, subquery, **kw): + kw["subquery"] = True + return self.visit_alias(subquery, **kw) def visit_lateral(self, lateral, **kw): kw["lateral"] = True @@ -2004,7 +2028,6 @@ class SQLCompiler(Compiled): self, select, asfrom=False, - parens=True, fromhints=None, compound_index=0, nested_join_translation=False, @@ -2027,7 +2050,6 @@ class SQLCompiler(Compiled): text = self.visit_select( transformed_select, asfrom=asfrom, - parens=parens, fromhints=fromhints, compound_index=compound_index, nested_join_translation=True, @@ -2138,10 +2160,7 @@ class SQLCompiler(Compiled): self.stack.pop(-1) - if (asfrom or lateral) and parens: - return "(" + text + ")" - else: - return text + return text def _setup_select_hints(self, select): byfrom = dict( @@ -2371,7 +2390,7 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text - def visit_insert(self, insert_stmt, asfrom=False, **kw): + def visit_insert(self, insert_stmt, **kw): toplevel = not self.stack self.stack.append( @@ -2475,10 +2494,7 @@ class SQLCompiler(Compiled): self.stack.pop(-1) - if asfrom: - return "(" + text + ")" - else: - return text + return text def update_limit_clause(self, update_stmt): """Provide a hook for MySQL to add LIMIT to the UPDATE""" @@ -2508,7 +2524,7 @@ class SQLCompiler(Compiled): "criteria within UPDATE" ) - def visit_update(self, update_stmt, asfrom=False, **kw): + def visit_update(self, update_stmt, **kw): toplevel = not self.stack extra_froms = update_stmt._extra_froms @@ -2605,10 +2621,7 @@ class SQLCompiler(Compiled): self.stack.pop(-1) - if asfrom: - return "(" + text + ")" - else: - return text + return text @util.memoized_property def _key_getters_for_crud_column(self): @@ -2633,7 +2646,7 @@ class SQLCompiler(Compiled): def delete_table_clause(self, delete_stmt, from_table, extra_froms): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) - def visit_delete(self, delete_stmt, asfrom=False, **kw): + def visit_delete(self, delete_stmt, **kw): toplevel = not self.stack crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) @@ -2702,10 +2715,7 @@ class SQLCompiler(Compiled): self.stack.pop(-1) - if asfrom: - return "(" + text + ")" - else: - return text + return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) |