summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-03-06 16:04:46 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-03-10 16:55:03 -0400
commit693938dd6fb2f3ee3e031aed4c62355ac97f3ceb (patch)
tree94701d7df1b7274151800efd6ca996e1f4203916 /lib/sqlalchemy/sql/compiler.py
parent851fb8f5a661c66ee76308181118369c8c4df9e0 (diff)
downloadsqlalchemy-693938dd6fb2f3ee3e031aed4c62355ac97f3ceb.tar.gz
Rework select(), CompoundSelect() in terms of CompileState
Continuation of I408e0b8be91fddd77cf279da97f55020871f75a9 - add an options() method to the base Generative construct. this will be where ORM options can go - Change Null, False_, True_ to be singletons, so that we aren't instantiating them and having to use isinstance. The previous issue with this was that they would produce dupe labels in SELECT statements. Apply the duplicate column logic, newly added in 1.4, to these objects as well as to non-apply-labels SELECT statements in general as a means of improving this. - create a revised system for generating ClauseList compilation constructs that simplfies up front creation to not actually use ClauseList; a simple tuple is rendered by the compiler using the same constrcution rules as what are used for ClauseList but without creating the actual object. Apply to Select, CompoundSelect, revise Update, Delete - Select, CompoundSelect get an initial CompileState implementation. All methods used only within compilation are moved here - refine update/insert/delete compile state to not require an outside boolean - refine and simplify Select._copy_internals - rework bind(), which is going away, to not use some of the internal traversal stuff - remove "autocommit", "for_update" parameters from Select, references #4643 - remove "autocommit" parameter from TextClause , references #4643 - add deprecation warnings for statement.execute(), engine.execute(), statement.scalar(), engine.scalar(). Fixes: #5193 Change-Id: I04ca0152b046fd42c5054ba10f37e43fc6e5a57b
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py183
1 files changed, 131 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 3ebcf24b0..c39f59e32 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -28,6 +28,7 @@ import contextlib
import itertools
import re
+from . import base
from . import coercions
from . import crud
from . import elements
@@ -1045,9 +1046,13 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]["selectable"]
+ compile_state = self.stack[-1]["compile_state"]
- with_cols, only_froms, only_cols = selectable._label_resolve_dict
+ (
+ with_cols,
+ only_froms,
+ only_cols,
+ ) = compile_state._label_resolve_dict
if within_columns_clause:
resolve_dict = only_froms
else:
@@ -1082,8 +1087,8 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- selectable = self.stack[-1]["selectable"]
- with_cols, only_froms, only_cols = selectable._label_resolve_dict
+ compile_state = self.stack[-1]["compile_state"]
+ with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
col = only_froms[element.element]
@@ -1313,6 +1318,24 @@ class SQLCompiler(Compiled):
if s
)
+ def _generate_delimited_and_list(self, clauses, **kw):
+
+ lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
+ operators.and_,
+ elements.True_._singleton,
+ elements.False_._singleton,
+ clauses,
+ )
+ if lcc == 1:
+ return clauses[0]._compiler_dispatch(self, **kw)
+ else:
+ separator = OPERATORS[operators.and_]
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in clauses)
+ if s
+ )
+
def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
@@ -1473,6 +1496,12 @@ class SQLCompiler(Compiled):
self, cs, asfrom=False, compound_index=0, **kwargs
):
toplevel = not self.stack
+
+ compile_state = cs._compile_state_factory(cs, self, **kwargs)
+
+ if toplevel:
+ self.compile_state = compile_state
+
entry = self._default_stack_entry if toplevel else self.stack[-1]
need_result_map = toplevel or (
compound_index == 0
@@ -1484,6 +1513,7 @@ class SQLCompiler(Compiled):
"correlate_froms": entry["correlate_froms"],
"asfrom_froms": entry["asfrom_froms"],
"selectable": cs,
+ "compile_state": compile_state,
"need_result_map_for_compound": need_result_map,
}
)
@@ -1665,7 +1695,6 @@ class SQLCompiler(Compiled):
from_linter=None,
**kw
):
-
if from_linter and operators.is_comparison(binary.operator):
from_linter.edges.update(
itertools.product(
@@ -2273,7 +2302,6 @@ class SQLCompiler(Compiled):
need_column_expressions=False,
):
"""produce labeled columns present in a select()."""
-
impl = column.type.dialect_impl(self.dialect)
if impl._has_column_expression and (
@@ -2349,7 +2377,12 @@ class SQLCompiler(Compiled):
or isinstance(column, functions.FunctionElement)
)
):
- result_expr = _CompileLabel(col_expr, column.anon_label)
+ result_expr = _CompileLabel(
+ col_expr,
+ column.anon_label
+ if not column_is_repeated
+ else column._dedupe_label_anon_label,
+ )
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
@@ -2389,7 +2422,9 @@ class SQLCompiler(Compiled):
[("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
)
- def _display_froms_for_select(self, select, asfrom, lateral=False):
+ def _display_froms_for_select(
+ self, select_stmt, asfrom, lateral=False, **kw
+ ):
# utility method to help external dialects
# get the correct from list for a select.
# specifically the oracle dialect needs this feature
@@ -2397,18 +2432,20 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
+ compile_state = select_stmt._compile_state_factory(select_stmt, self)
+
correlate_froms = entry["correlate_froms"]
asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
asfrom_froms
),
implicit_correlate_froms=(),
)
else:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms,
)
@@ -2416,7 +2453,7 @@ class SQLCompiler(Compiled):
def visit_select(
self,
- select,
+ select_stmt,
asfrom=False,
fromhints=None,
compound_index=0,
@@ -2426,7 +2463,16 @@ class SQLCompiler(Compiled):
**kwargs
):
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
toplevel = not self.stack
+
+ if toplevel:
+ self.compile_state = compile_state
+
entry = self._default_stack_entry if toplevel else self.stack[-1]
populate_result_map = need_column_expressions = (
@@ -2445,7 +2491,7 @@ class SQLCompiler(Compiled):
del kwargs["add_to_result_map"]
froms = self._setup_select_stack(
- select, entry, asfrom, lateral, compound_index
+ select_stmt, compile_state, entry, asfrom, lateral, compound_index
)
column_clause_args = kwargs.copy()
@@ -2455,23 +2501,25 @@ class SQLCompiler(Compiled):
text = "SELECT " # we're off to a good start !
- if select._hints:
- hint_text, byfrom = self._setup_select_hints(select)
+ if select_stmt._hints:
+ hint_text, byfrom = self._setup_select_hints(select_stmt)
if hint_text:
text += hint_text + " "
else:
byfrom = None
- if select._prefixes:
- text += self._generate_prefixes(select, select._prefixes, **kwargs)
+ if select_stmt._prefixes:
+ text += self._generate_prefixes(
+ select_stmt, select_stmt._prefixes, **kwargs
+ )
- text += self.get_select_precolumns(select, **kwargs)
+ text += self.get_select_precolumns(select_stmt, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c
for c in [
self._label_select_column(
- select,
+ select_stmt,
column,
populate_result_map,
asfrom,
@@ -2480,7 +2528,7 @@ class SQLCompiler(Compiled):
column_is_repeated=repeated,
need_column_expressions=need_column_expressions,
)
- for name, column, repeated in select._columns_plus_names
+ for name, column, repeated in compile_state.columns_plus_names
]
if c is not None
]
@@ -2489,11 +2537,19 @@ class SQLCompiler(Compiled):
# if this select is a compiler-generated wrapper,
# rewrite the targeted columns in the result map
+ compile_state_wraps_for = select_wraps_for._compile_state_factory(
+ select_wraps_for, self, **kwargs
+ )
+
translate = dict(
zip(
[
name
- for (key, name, repeated) in select._columns_plus_names
+ for (
+ key,
+ name,
+ repeated,
+ ) in compile_state.columns_plus_names
],
[
name
@@ -2501,7 +2557,7 @@ class SQLCompiler(Compiled):
key,
name,
repeated,
- ) in select_wraps_for._columns_plus_names
+ ) in compile_state_wraps_for.columns_plus_names
],
)
)
@@ -2512,13 +2568,20 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, toplevel, kwargs
+ text,
+ select_stmt,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
)
- if select._statement_hints:
+ if select_stmt._statement_hints:
per_dialect = [
ht
- for (dialect_name, ht) in select._statement_hints
+ for (dialect_name, ht) in select_stmt._statement_hints
if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
@@ -2527,9 +2590,9 @@ class SQLCompiler(Compiled):
if self.ctes and toplevel:
text = self._render_cte_clause() + text
- if select._suffixes:
+ if select_stmt._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs
+ select_stmt, select_stmt._suffixes, **kwargs
)
self.stack.pop(-1)
@@ -2552,7 +2615,7 @@ class SQLCompiler(Compiled):
return hint_text, byfrom
def _setup_select_stack(
- self, select, entry, asfrom, lateral, compound_index
+ self, select, compile_state, entry, asfrom, lateral, compound_index
):
correlate_froms = entry["correlate_froms"]
asfrom_froms = entry["asfrom_froms"]
@@ -2563,8 +2626,8 @@ class SQLCompiler(Compiled):
if select_0._is_select_container:
select_0 = select_0.element
numcols = len(select_0.selected_columns)
- # numcols = len(select_0._columns_plus_names)
- if len(select._columns_plus_names) != numcols:
+
+ if len(compile_state.columns_plus_names) != numcols:
raise exc.CompileError(
"All selectables passed to "
"CompoundSelect must have identical numbers of "
@@ -2579,14 +2642,14 @@ class SQLCompiler(Compiled):
)
if asfrom and not lateral:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
asfrom_froms
),
implicit_correlate_froms=(),
)
else:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms,
)
@@ -2598,13 +2661,22 @@ class SQLCompiler(Compiled):
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
+ "compile_state": compile_state,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, toplevel, kwargs
+ self,
+ text,
+ select,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
):
text += ", ".join(inner_columns)
@@ -2646,9 +2718,9 @@ class SQLCompiler(Compiled):
else:
text += self.default_from()
- if select._whereclause is not None:
- t = select._whereclause._compiler_dispatch(
- self, from_linter=from_linter, **kwargs
+ if select._where_criteria:
+ t = self._generate_delimited_and_list(
+ select._where_criteria, from_linter=from_linter, **kwargs
)
if t:
text += " \nWHERE " + t
@@ -2659,15 +2731,17 @@ class SQLCompiler(Compiled):
):
from_linter.warn()
- if select._group_by_clause.clauses:
+ if select._group_by_clauses:
text += self.group_by_clause(select, **kwargs)
- if select._having is not None:
- t = select._having._compiler_dispatch(self, **kwargs)
+ if select._having_criteria:
+ t = self._generate_delimited_and_list(
+ select._having_criteria, **kwargs
+ )
if t:
text += " \nHAVING " + t
- if select._order_by_clause.clauses:
+ if select._order_by_clauses:
text += self.order_by_clause(select, **kwargs)
if (
@@ -2718,7 +2792,9 @@ class SQLCompiler(Compiled):
def group_by_clause(self, select, **kw):
"""allow dialects to customize how GROUP BY is rendered."""
- group_by = select._group_by_clause._compiler_dispatch(self, **kw)
+ group_by = self._generate_delimited_list(
+ select._group_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
if group_by:
return " GROUP BY " + group_by
else:
@@ -2727,7 +2803,10 @@ class SQLCompiler(Compiled):
def order_by_clause(self, select, **kw):
"""allow dialects to customize how ORDER BY is rendered."""
- order_by = select._order_by_clause._compiler_dispatch(self, **kw)
+ order_by = self._generate_delimited_list(
+ select._order_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+
if order_by:
return " ORDER BY " + order_by
else:
@@ -2826,8 +2905,8 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
- compile_state = insert_stmt._compile_state_cls(
- insert_stmt, self, isinsert=True, **kw
+ compile_state = insert_stmt._compile_state_factory(
+ insert_stmt, self, **kw
)
insert_stmt = compile_state.statement
@@ -2972,8 +3051,8 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
- compile_state = update_stmt._compile_state_cls(
- update_stmt, self, isupdate=True, **kw
+ compile_state = update_stmt._compile_state_factory(
+ update_stmt, self, **kw
)
update_stmt = compile_state.statement
@@ -3055,8 +3134,8 @@ class SQLCompiler(Compiled):
text += " " + extra_from_text
if update_stmt._where_criteria:
- t = self._generate_delimited_list(
- update_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ t = self._generate_delimited_and_list(
+ update_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t
@@ -3099,8 +3178,8 @@ class SQLCompiler(Compiled):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, **kw):
- compile_state = delete_stmt._compile_state_cls(
- delete_stmt, self, isdelete=True, **kw
+ compile_state = delete_stmt._compile_state_factory(
+ delete_stmt, self, **kw
)
delete_stmt = compile_state.statement
@@ -3158,8 +3237,8 @@ class SQLCompiler(Compiled):
text += " " + extra_from_text
if delete_stmt._where_criteria:
- t = self._generate_delimited_list(
- delete_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ t = self._generate_delimited_and_list(
+ delete_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t
@@ -3229,7 +3308,7 @@ class StrSQLCompiler(SQLCompiler):
def returning_clause(self, stmt, returning_cols):
columns = [
self._label_select_column(None, c, True, False, {})
- for c in elements._select_iterables(returning_cols)
+ for c in base._select_iterables(returning_cols)
]
return "RETURNING " + ", ".join(columns)