diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-03-06 16:04:46 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-03-10 16:55:03 -0400 |
commit | 693938dd6fb2f3ee3e031aed4c62355ac97f3ceb (patch) | |
tree | 94701d7df1b7274151800efd6ca996e1f4203916 /lib/sqlalchemy/sql/compiler.py | |
parent | 851fb8f5a661c66ee76308181118369c8c4df9e0 (diff) | |
download | sqlalchemy-693938dd6fb2f3ee3e031aed4c62355ac97f3ceb.tar.gz |
Rework select(), CompoundSelect() in terms of CompileState
Continuation of I408e0b8be91fddd77cf279da97f55020871f75a9
- add an options() method to the base Generative construct.
this will be where ORM options can go
- Change Null, False_, True_ to be singletons, so that
we aren't instantiating them and having to use isinstance.
The previous issue with this was that they would produce dupe
labels in SELECT statements. Apply the duplicate column
logic, newly added in 1.4, to these objects as well as to
non-apply-labels SELECT statements in general as a means of
improving this.
- create a revised system for generating ClauseList compilation
constructs that simplfies up front creation to not actually
use ClauseList; a simple tuple is rendered by the compiler
using the same constrcution rules as what are used for
ClauseList but without creating the actual object. Apply
to Select, CompoundSelect, revise Update, Delete
- Select, CompoundSelect get an initial CompileState
implementation. All methods used only within compilation
are moved here
- refine update/insert/delete compile state to not require
an outside boolean
- refine and simplify Select._copy_internals
- rework bind(), which is going away, to not use some
of the internal traversal stuff
- remove "autocommit", "for_update" parameters from Select,
references #4643
- remove "autocommit" parameter from TextClause ,
references #4643
- add deprecation warnings for statement.execute(),
engine.execute(), statement.scalar(), engine.scalar().
Fixes: #5193
Change-Id: I04ca0152b046fd42c5054ba10f37e43fc6e5a57b
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 183 |
1 files changed, 131 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3ebcf24b0..c39f59e32 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ import contextlib import itertools import re +from . import base from . import coercions from . import crud from . import elements @@ -1045,9 +1046,13 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]["selectable"] + compile_state = self.stack[-1]["compile_state"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict if within_columns_clause: resolve_dict = only_froms else: @@ -1082,8 +1087,8 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - selectable = self.stack[-1]["selectable"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + compile_state = self.stack[-1]["compile_state"] + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: col = only_froms[element.element] @@ -1313,6 +1318,24 @@ class SQLCompiler(Compiled): if s ) + def _generate_delimited_and_list(self, clauses, **kw): + + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: @@ -1473,6 +1496,12 @@ class SQLCompiler(Compiled): self, cs, asfrom=False, compound_index=0, **kwargs ): toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] need_result_map = toplevel or ( compound_index == 0 @@ -1484,6 +1513,7 @@ class SQLCompiler(Compiled): "correlate_froms": entry["correlate_froms"], "asfrom_froms": entry["asfrom_froms"], "selectable": cs, + "compile_state": compile_state, "need_result_map_for_compound": need_result_map, } ) @@ -1665,7 +1695,6 @@ class SQLCompiler(Compiled): from_linter=None, **kw ): - if from_linter and operators.is_comparison(binary.operator): from_linter.edges.update( itertools.product( @@ -2273,7 +2302,6 @@ class SQLCompiler(Compiled): need_column_expressions=False, ): """produce labeled columns present in a select().""" - impl = column.type.dialect_impl(self.dialect) if impl._has_column_expression and ( @@ -2349,7 +2377,12 @@ class SQLCompiler(Compiled): or isinstance(column, functions.FunctionElement) ) ): - result_expr = _CompileLabel(col_expr, column.anon_label) + result_expr = _CompileLabel( + col_expr, + column.anon_label + if not column_is_repeated + else column._dedupe_label_anon_label, + ) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) @@ -2389,7 +2422,9 @@ class SQLCompiler(Compiled): [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] ) - def _display_froms_for_select(self, select, asfrom, lateral=False): + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): # utility method to help external dialects # get the correct from list for a select. # specifically the oracle dialect needs this feature @@ -2397,18 +2432,20 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + compile_state = select_stmt._compile_state_factory(select_stmt, self) + correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2416,7 +2453,7 @@ class SQLCompiler(Compiled): def visit_select( self, - select, + select_stmt, asfrom=False, fromhints=None, compound_index=0, @@ -2426,7 +2463,16 @@ class SQLCompiler(Compiled): **kwargs ): + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + toplevel = not self.stack + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = need_column_expressions = ( @@ -2445,7 +2491,7 @@ class SQLCompiler(Compiled): del kwargs["add_to_result_map"] froms = self._setup_select_stack( - select, entry, asfrom, lateral, compound_index + select_stmt, compile_state, entry, asfrom, lateral, compound_index ) column_clause_args = kwargs.copy() @@ -2455,23 +2501,25 @@ class SQLCompiler(Compiled): text = "SELECT " # we're off to a good start ! - if select._hints: - hint_text, byfrom = self._setup_select_hints(select) + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) if hint_text: text += hint_text + " " else: byfrom = None - if select._prefixes: - text += self._generate_prefixes(select, select._prefixes, **kwargs) + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) - text += self.get_select_precolumns(select, **kwargs) + text += self.get_select_precolumns(select_stmt, **kwargs) # the actual list of columns to print in the SELECT column list. inner_columns = [ c for c in [ self._label_select_column( - select, + select_stmt, column, populate_result_map, asfrom, @@ -2480,7 +2528,7 @@ class SQLCompiler(Compiled): column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column, repeated in select._columns_plus_names + for name, column, repeated in compile_state.columns_plus_names ] if c is not None ] @@ -2489,11 +2537,19 @@ class SQLCompiler(Compiled): # if this select is a compiler-generated wrapper, # rewrite the targeted columns in the result map + compile_state_wraps_for = select_wraps_for._compile_state_factory( + select_wraps_for, self, **kwargs + ) + translate = dict( zip( [ name - for (key, name, repeated) in select._columns_plus_names + for ( + key, + name, + repeated, + ) in compile_state.columns_plus_names ], [ name @@ -2501,7 +2557,7 @@ class SQLCompiler(Compiled): key, name, repeated, - ) in select_wraps_for._columns_plus_names + ) in compile_state_wraps_for.columns_plus_names ], ) ) @@ -2512,13 +2568,20 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, toplevel, kwargs + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ) - if select._statement_hints: + if select_stmt._statement_hints: per_dialect = [ ht - for (dialect_name, ht) in select._statement_hints + for (dialect_name, ht) in select_stmt._statement_hints if dialect_name in ("*", self.dialect.name) ] if per_dialect: @@ -2527,9 +2590,9 @@ class SQLCompiler(Compiled): if self.ctes and toplevel: text = self._render_cte_clause() + text - if select._suffixes: + if select_stmt._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs + select_stmt, select_stmt._suffixes, **kwargs ) self.stack.pop(-1) @@ -2552,7 +2615,7 @@ class SQLCompiler(Compiled): return hint_text, byfrom def _setup_select_stack( - self, select, entry, asfrom, lateral, compound_index + self, select, compile_state, entry, asfrom, lateral, compound_index ): correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] @@ -2563,8 +2626,8 @@ class SQLCompiler(Compiled): if select_0._is_select_container: select_0 = select_0.element numcols = len(select_0.selected_columns) - # numcols = len(select_0._columns_plus_names) - if len(select._columns_plus_names) != numcols: + + if len(compile_state.columns_plus_names) != numcols: raise exc.CompileError( "All selectables passed to " "CompoundSelect must have identical numbers of " @@ -2579,14 +2642,14 @@ class SQLCompiler(Compiled): ) if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2598,13 +2661,22 @@ class SQLCompiler(Compiled): "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, + "compile_state": compile_state, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, toplevel, kwargs + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ): text += ", ".join(inner_columns) @@ -2646,9 +2718,9 @@ class SQLCompiler(Compiled): else: text += self.default_from() - if select._whereclause is not None: - t = select._whereclause._compiler_dispatch( - self, from_linter=from_linter, **kwargs + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs ) if t: text += " \nWHERE " + t @@ -2659,15 +2731,17 @@ class SQLCompiler(Compiled): ): from_linter.warn() - if select._group_by_clause.clauses: + if select._group_by_clauses: text += self.group_by_clause(select, **kwargs) - if select._having is not None: - t = select._having._compiler_dispatch(self, **kwargs) + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) if t: text += " \nHAVING " + t - if select._order_by_clause.clauses: + if select._order_by_clauses: text += self.order_by_clause(select, **kwargs) if ( @@ -2718,7 +2792,9 @@ class SQLCompiler(Compiled): def group_by_clause(self, select, **kw): """allow dialects to customize how GROUP BY is rendered.""" - group_by = select._group_by_clause._compiler_dispatch(self, **kw) + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) if group_by: return " GROUP BY " + group_by else: @@ -2727,7 +2803,10 @@ class SQLCompiler(Compiled): def order_by_clause(self, select, **kw): """allow dialects to customize how ORDER BY is rendered.""" - order_by = select._order_by_clause._compiler_dispatch(self, **kw) + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if order_by: return " ORDER BY " + order_by else: @@ -2826,8 +2905,8 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): - compile_state = insert_stmt._compile_state_cls( - insert_stmt, self, isinsert=True, **kw + compile_state = insert_stmt._compile_state_factory( + insert_stmt, self, **kw ) insert_stmt = compile_state.statement @@ -2972,8 +3051,8 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): - compile_state = update_stmt._compile_state_cls( - update_stmt, self, isupdate=True, **kw + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw ) update_stmt = compile_state.statement @@ -3055,8 +3134,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if update_stmt._where_criteria: - t = self._generate_delimited_list( - update_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3099,8 +3178,8 @@ class SQLCompiler(Compiled): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, **kw): - compile_state = delete_stmt._compile_state_cls( - delete_stmt, self, isdelete=True, **kw + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw ) delete_stmt = compile_state.statement @@ -3158,8 +3237,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if delete_stmt._where_criteria: - t = self._generate_delimited_list( - delete_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3229,7 +3308,7 @@ class StrSQLCompiler(SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ self._label_select_column(None, c, True, False, {}) - for c in elements._select_iterables(returning_cols) + for c in base._select_iterables(returning_cols) ] return "RETURNING " + ", ".join(columns) |