summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py100
1 files changed, 55 insertions, 45 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 7922054f8..13219ee68 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -488,7 +488,7 @@ class SQLCompiler(Compiled):
"""
if False, means we can't be sure the list of entries
in _result_columns is actually the rendered order. Usually
- True unless using an unordered TextAsFrom.
+ True unless using an unordered TextualSelect.
"""
_numeric_binds = False
@@ -916,8 +916,8 @@ class SQLCompiler(Compiled):
),
)
- def visit_text_as_from(
- self, taf, compound_index=None, asfrom=False, parens=True, **kw
+ def visit_textual_select(
+ self, taf, compound_index=None, asfrom=False, **kw
):
toplevel = not self.stack
@@ -943,10 +943,7 @@ class SQLCompiler(Compiled):
add_to_result_map=self._add_to_result_map,
)
- text = self.process(taf.element, **kw)
- if asfrom and parens:
- text = "(%s)" % text
- return text
+ return self.process(taf.element, **kw)
def visit_null(self, expr, **kw):
return "NULL"
@@ -1120,7 +1117,7 @@ class SQLCompiler(Compiled):
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(
- self, cs, asfrom=False, parens=True, compound_index=0, **kwargs
+ self, cs, asfrom=False, compound_index=0, **kwargs
):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
@@ -1143,16 +1140,13 @@ class SQLCompiler(Compiled):
text = (" " + keyword + " ").join(
(
c._compiler_dispatch(
- self,
- asfrom=asfrom,
- parens=False,
- compound_index=i,
- **kwargs
+ self, asfrom=asfrom, compound_index=i, **kwargs
)
for i, c in enumerate(cs.selects)
)
)
+ kwargs["include_table"] = False
text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
text += self.order_by_clause(cs, **kwargs)
text += (
@@ -1165,10 +1159,7 @@ class SQLCompiler(Compiled):
text = self._render_cte_clause() + text
self.stack.pop(-1)
- if asfrom and parens:
- return "(" + text + ")"
- else:
- return text
+ return text
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
attrname = "visit_%s_%s%s" % (
@@ -1682,8 +1673,11 @@ class SQLCompiler(Compiled):
if self.positional:
kwargs["positional_names"] = self.cte_positional[cte] = []
- text += " AS \n" + cte.element._compiler_dispatch(
- self, asfrom=True, **kwargs
+ assert kwargs.get("subquery", False) is False
+ text += " AS \n(%s)" % (
+ cte.element._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ ),
)
if cte._suffixes:
@@ -1713,8 +1707,28 @@ class SQLCompiler(Compiled):
ashint=False,
iscrud=False,
fromhints=None,
+ subquery=False,
+ lateral=False,
+ enclosing_alias=None,
**kwargs
):
+ if enclosing_alias is not None and enclosing_alias.element is alias:
+ inner = alias.element._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ ashint=ashint,
+ iscrud=iscrud,
+ fromhints=fromhints,
+ lateral=lateral,
+ enclosing_alias=alias,
+ **kwargs
+ )
+ if subquery and (asfrom or lateral):
+ inner = "(%s)" % (inner,)
+ return inner
+ else:
+ enclosing_alias = kwargs["enclosing_alias"] = alias
+
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1724,12 +1738,15 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.element._compiler_dispatch(
- self, asfrom=True, **kwargs
- ) + self.get_render_as_alias_suffix(
- self.preparer.format_alias(alias, alias_name)
+ inner = alias.element._compiler_dispatch(
+ self, asfrom=True, lateral=lateral, **kwargs
)
+ if subquery:
+ inner = "(%s)" % (inner,)
+ ret = inner + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
if fromhints and alias in fromhints:
ret = self.format_from_hint_text(
ret, alias, fromhints[alias], iscrud
@@ -1737,7 +1754,14 @@ class SQLCompiler(Compiled):
return ret
else:
- return alias.element._compiler_dispatch(self, **kwargs)
+ # note we cancel the "subquery" flag here as well
+ return alias.element._compiler_dispatch(
+ self, lateral=lateral, **kwargs
+ )
+
+ def visit_subquery(self, subquery, **kw):
+ kw["subquery"] = True
+ return self.visit_alias(subquery, **kw)
def visit_lateral(self, lateral, **kw):
kw["lateral"] = True
@@ -2004,7 +2028,6 @@ class SQLCompiler(Compiled):
self,
select,
asfrom=False,
- parens=True,
fromhints=None,
compound_index=0,
nested_join_translation=False,
@@ -2027,7 +2050,6 @@ class SQLCompiler(Compiled):
text = self.visit_select(
transformed_select,
asfrom=asfrom,
- parens=parens,
fromhints=fromhints,
compound_index=compound_index,
nested_join_translation=True,
@@ -2138,10 +2160,7 @@ class SQLCompiler(Compiled):
self.stack.pop(-1)
- if (asfrom or lateral) and parens:
- return "(" + text + ")"
- else:
- return text
+ return text
def _setup_select_hints(self, select):
byfrom = dict(
@@ -2371,7 +2390,7 @@ class SQLCompiler(Compiled):
)
return dialect_hints, table_text
- def visit_insert(self, insert_stmt, asfrom=False, **kw):
+ def visit_insert(self, insert_stmt, **kw):
toplevel = not self.stack
self.stack.append(
@@ -2475,10 +2494,7 @@ class SQLCompiler(Compiled):
self.stack.pop(-1)
- if asfrom:
- return "(" + text + ")"
- else:
- return text
+ return text
def update_limit_clause(self, update_stmt):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
@@ -2508,7 +2524,7 @@ class SQLCompiler(Compiled):
"criteria within UPDATE"
)
- def visit_update(self, update_stmt, asfrom=False, **kw):
+ def visit_update(self, update_stmt, **kw):
toplevel = not self.stack
extra_froms = update_stmt._extra_froms
@@ -2605,10 +2621,7 @@ class SQLCompiler(Compiled):
self.stack.pop(-1)
- if asfrom:
- return "(" + text + ")"
- else:
- return text
+ return text
@util.memoized_property
def _key_getters_for_crud_column(self):
@@ -2633,7 +2646,7 @@ class SQLCompiler(Compiled):
def delete_table_clause(self, delete_stmt, from_table, extra_froms):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
- def visit_delete(self, delete_stmt, asfrom=False, **kw):
+ def visit_delete(self, delete_stmt, **kw):
toplevel = not self.stack
crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
@@ -2702,10 +2715,7 @@ class SQLCompiler(Compiled):
self.stack.pop(-1)
- if asfrom:
- return "(" + text + ")"
- else:
- return text
+ return text
def visit_savepoint(self, savepoint_stmt):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)