diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 90 |
1 files changed, 73 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index f1fe53b73..0afcdfaec 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -621,9 +621,15 @@ class SQLCompiler(engine.Compiled): def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=0, **kwargs): - entry = self.stack and self.stack[-1] or {} - self.stack.append({'from': entry.get('from', None), - 'iswrapper': not entry}) + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + self.stack.append( + { + 'correlate_froms': entry['correlate_froms'], + 'iswrapper': toplevel, + 'asfrom_froms': entry['asfrom_froms'] + }) keyword = self.compound_keywords.get(cs.keyword) @@ -644,7 +650,7 @@ class SQLCompiler(engine.Compiled): self.limit_clause(cs) or "" if self.ctes and \ - compound_index == 0 and not entry: + compound_index == 0 and toplevel: text = self._render_cte_clause() + text self.stack.pop(-1) @@ -1197,12 +1203,42 @@ class SQLCompiler(engine.Compiled): objs = tuple([d.get(col, col) for col in objs]) self.result_map[key] = (name, objs, typ) + + _default_stack_entry = util.immutabledict([ + ('iswrapper', False), + ('correlate_froms', frozenset()), + ('asfrom_froms', frozenset()) + ]) + + def _display_froms_for_select(self, select, asfrom): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms=\ + correlate_froms.difference(asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + return froms + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=0, force_result_map=False, positional_names=None, - nested_join_translation=False, **kwargs): + nested_join_translation=False, + **kwargs): needs_nested_translation = \ select.use_labels and \ @@ -1221,12 +1257,14 @@ class SQLCompiler(engine.Compiled): nested_join_translation=True, **kwargs ) - entry = self.stack and self.stack[-1] or {} + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + populate_result_map = force_result_map or ( compound_index == 0 and ( - not entry or \ - entry.get('iswrapper', False) + toplevel or \ + entry['iswrapper'] ) ) @@ -1236,15 +1274,28 @@ class SQLCompiler(engine.Compiled): select, transformed_select) return text - existingfroms = entry.get('from', None) + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] - froms = select._get_display_froms(existingfroms, asfrom=asfrom) - - correlate_froms = set(sql._from_objects(*froms)) + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms= + correlate_froms.difference(asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + new_correlate_froms = set(sql._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) - self.stack.append({'from': correlate_froms, - 'iswrapper': iswrapper}) + new_entry = { + 'asfrom_froms': new_correlate_froms, + 'iswrapper': iswrapper, + 'correlate_froms': all_correlate_froms + } + self.stack.append(new_entry) column_clause_args = kwargs.copy() column_clause_args.update({ @@ -1333,7 +1384,7 @@ class SQLCompiler(engine.Compiled): text += self.for_update_clause(select) if self.ctes and \ - compound_index == 0 and not entry: + compound_index == 0 and toplevel: text = self._render_cte_clause() + text self.stack.pop(-1) @@ -1546,7 +1597,10 @@ class SQLCompiler(engine.Compiled): for t in extra_froms) def visit_update(self, update_stmt, **kw): - self.stack.append({'from': set([update_stmt.table])}) + self.stack.append( + {'correlate_froms': set([update_stmt.table]), + "iswrapper": False, + "asfrom_froms": set([update_stmt.table])}) self.isupdate = True @@ -1880,7 +1934,9 @@ class SQLCompiler(engine.Compiled): return values def visit_delete(self, delete_stmt, **kw): - self.stack.append({'from': set([delete_stmt.table])}) + self.stack.append({'correlate_froms': set([delete_stmt.table]), + "iswrapper": False, + "asfrom_froms": set([delete_stmt.table])}) self.isdelete = True text = "DELETE " |