summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py90
1 files changed, 73 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f1fe53b73..0afcdfaec 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -621,9 +621,15 @@ class SQLCompiler(engine.Compiled):
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=0, **kwargs):
- entry = self.stack and self.stack[-1] or {}
- self.stack.append({'from': entry.get('from', None),
- 'iswrapper': not entry})
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ self.stack.append(
+ {
+ 'correlate_froms': entry['correlate_froms'],
+ 'iswrapper': toplevel,
+ 'asfrom_froms': entry['asfrom_froms']
+ })
keyword = self.compound_keywords.get(cs.keyword)
@@ -644,7 +650,7 @@ class SQLCompiler(engine.Compiled):
self.limit_clause(cs) or ""
if self.ctes and \
- compound_index == 0 and not entry:
+ compound_index == 0 and toplevel:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -1197,12 +1203,42 @@ class SQLCompiler(engine.Compiled):
objs = tuple([d.get(col, col) for col in objs])
self.result_map[key] = (name, objs, typ)
+
+ _default_stack_entry = util.immutabledict([
+ ('iswrapper', False),
+ ('correlate_froms', frozenset()),
+ ('asfrom_froms', frozenset())
+ ])
+
+ def _display_froms_for_select(self, select, asfrom):
+ # utility method to help external dialects
+ # get the correct from list for a select.
+ # specifically the oracle dialect needs this feature
+ # right now.
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ correlate_froms = entry['correlate_froms']
+ asfrom_froms = entry['asfrom_froms']
+
+ if asfrom:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=\
+ correlate_froms.difference(asfrom_froms),
+ implicit_correlate_froms=())
+ else:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
+ return froms
+
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=0,
force_result_map=False,
positional_names=None,
- nested_join_translation=False, **kwargs):
+ nested_join_translation=False,
+ **kwargs):
needs_nested_translation = \
select.use_labels and \
@@ -1221,12 +1257,14 @@ class SQLCompiler(engine.Compiled):
nested_join_translation=True, **kwargs
)
- entry = self.stack and self.stack[-1] or {}
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
populate_result_map = force_result_map or (
compound_index == 0 and (
- not entry or \
- entry.get('iswrapper', False)
+ toplevel or \
+ entry['iswrapper']
)
)
@@ -1236,15 +1274,28 @@ class SQLCompiler(engine.Compiled):
select, transformed_select)
return text
- existingfroms = entry.get('from', None)
+ correlate_froms = entry['correlate_froms']
+ asfrom_froms = entry['asfrom_froms']
- froms = select._get_display_froms(existingfroms, asfrom=asfrom)
-
- correlate_froms = set(sql._from_objects(*froms))
+ if asfrom:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=
+ correlate_froms.difference(asfrom_froms),
+ implicit_correlate_froms=())
+ else:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
+ new_correlate_froms = set(sql._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
- self.stack.append({'from': correlate_froms,
- 'iswrapper': iswrapper})
+ new_entry = {
+ 'asfrom_froms': new_correlate_froms,
+ 'iswrapper': iswrapper,
+ 'correlate_froms': all_correlate_froms
+ }
+ self.stack.append(new_entry)
column_clause_args = kwargs.copy()
column_clause_args.update({
@@ -1333,7 +1384,7 @@ class SQLCompiler(engine.Compiled):
text += self.for_update_clause(select)
if self.ctes and \
- compound_index == 0 and not entry:
+ compound_index == 0 and toplevel:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -1546,7 +1597,10 @@ class SQLCompiler(engine.Compiled):
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
- self.stack.append({'from': set([update_stmt.table])})
+ self.stack.append(
+ {'correlate_froms': set([update_stmt.table]),
+ "iswrapper": False,
+ "asfrom_froms": set([update_stmt.table])})
self.isupdate = True
@@ -1880,7 +1934,9 @@ class SQLCompiler(engine.Compiled):
return values
def visit_delete(self, delete_stmt, **kw):
- self.stack.append({'from': set([delete_stmt.table])})
+ self.stack.append({'correlate_froms': set([delete_stmt.table]),
+ "iswrapper": False,
+ "asfrom_froms": set([delete_stmt.table])})
self.isdelete = True
text = "DELETE "