diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-03-11 18:08:03 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-03-11 18:08:03 +0000 |
commit | 2aa9a8043b4982d4d7b53e8b11371ea27fccd09c (patch) | |
tree | 0fc1f0ddd3a6defdda5888ee48bd4e69bd162c4c /lib/sqlalchemy/sql | |
parent | 59ca6e5fcc6974ea1fac82d05157aa58e550b332 (diff) | |
parent | 693938dd6fb2f3ee3e031aed4c62355ac97f3ceb (diff) | |
download | sqlalchemy-2aa9a8043b4982d4d7b53e8b11371ea27fccd09c.tar.gz |
Merge "Rework select(), CompoundSelect() in terms of CompileState"
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 88 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 183 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 138 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 706 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 5 |
7 files changed, 596 insertions, 546 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index b61c7dc5e..262dc4a0e 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -47,6 +47,17 @@ class Immutable(object): pass +class SingletonConstant(Immutable): + def __new__(cls, *arg, **kw): + return cls._singleton + + @classmethod + def _create_singleton(cls): + obj = object.__new__(cls) + obj.__init__() + cls._singleton = obj + + class HasMemoized(object): def _reset_memoizations(self): self._memoized_property.expire_instance(self) @@ -60,7 +71,19 @@ class HasMemoized(object): def _from_objects(*elements): - return itertools.chain(*[element._from_objects for element in elements]) + return itertools.chain.from_iterable( + [element._from_objects for element in elements] + ) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain.from_iterable( + [c._select_iterable for c in elements] + ) def _generative(fn): @@ -421,6 +444,19 @@ class CompileState(object): __slots__ = ("statement",) + @classmethod + def _create(cls, statement, compiler, **kw): + # factory construction. + + # specific CompileState classes here will look for + # "plugins" in the given statement. From there they will invoke + # the appropriate plugin constructor if one is found and return + # the alternate CompileState object. + + c = cls.__new__(cls) + c.__init__(statement, compiler, **kw) + return c + def __init__(self, statement, compiler, **kw): self.statement = statement @@ -434,11 +470,44 @@ class Generative(object): s.__dict__ = self.__dict__.copy() return s + def options(self, *options): + """Apply options to this statement. + + In the general sense, options are any kind of Python object + that can be interpreted by the SQL compiler for the statement. + These options can be consumed by specific dialects or specific kinds + of compilers. + + The most commonly known kind of option are the ORM level options + that apply "eager load" and other loading behaviors to an ORM + query. However, options can theoretically be used for many other + purposes. + + For background on specific kinds of options for specific kinds of + statements, refer to the documentation for those option objects. + + .. versionchanged:: 1.4 - added :meth:`.Generative.options` to + Core statement objects towards the goal of allowing unified + Core / ORM querying capabilities. + + .. seealso:: + + :ref:`deferred_options` - refers to options specific to the usage + of ORM queries + + :ref:`relationship_loader_options` - refers to options specific + to the usage of ORM queries + + """ + self._options += options + class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_cls = CompileState + _compile_state_factory = CompileState._create + + _compile_state_plugin = None class Executable(Generative): @@ -511,6 +580,13 @@ class Executable(Generative): """ return self._execution_options + @util.deprecated_20( + ":meth:`.Executable.execute`", + alternative="All statement execution in SQLAlchemy 2.0 is performed " + "by the :meth:`.Connection.execute` method of :class:`.Connection`, " + "or in the ORM by the :meth:`.Session.execute` method of " + ":class:`.Session`.", + ) def execute(self, *multiparams, **params): """Compile and execute this :class:`.Executable`.""" e = self.bind @@ -524,6 +600,14 @@ class Executable(Generative): raise exc.UnboundExecutionError(msg) return e._execute_clauseelement(self, multiparams, params) + @util.deprecated_20( + ":meth:`.Executable.scalar`", + alternative="All statement execution in SQLAlchemy 2.0 is performed " + "by the :meth:`.Connection.execute` method of :class:`.Connection`, " + "or in the ORM by the :meth:`.Session.execute` method of " + ":class:`.Session`; the :meth:`.Result.scalar` method can then be " + "used to return a scalar result.", + ) def scalar(self, *multiparams, **params): """Compile and execute this :class:`.Executable`, returning the result's scalar representation. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b37c46216..1f183b5c1 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ import contextlib import itertools import re +from . import base from . import coercions from . import crud from . import elements @@ -1046,9 +1047,13 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]["selectable"] + compile_state = self.stack[-1]["compile_state"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict if within_columns_clause: resolve_dict = only_froms else: @@ -1083,8 +1088,8 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - selectable = self.stack[-1]["selectable"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + compile_state = self.stack[-1]["compile_state"] + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: col = only_froms[element.element] @@ -1314,6 +1319,24 @@ class SQLCompiler(Compiled): if s ) + def _generate_delimited_and_list(self, clauses, **kw): + + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: @@ -1474,6 +1497,12 @@ class SQLCompiler(Compiled): self, cs, asfrom=False, compound_index=0, **kwargs ): toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] need_result_map = toplevel or ( compound_index == 0 @@ -1485,6 +1514,7 @@ class SQLCompiler(Compiled): "correlate_froms": entry["correlate_froms"], "asfrom_froms": entry["asfrom_froms"], "selectable": cs, + "compile_state": compile_state, "need_result_map_for_compound": need_result_map, } ) @@ -1666,7 +1696,6 @@ class SQLCompiler(Compiled): from_linter=None, **kw ): - if from_linter and operators.is_comparison(binary.operator): from_linter.edges.update( itertools.product( @@ -2274,7 +2303,6 @@ class SQLCompiler(Compiled): need_column_expressions=False, ): """produce labeled columns present in a select().""" - impl = column.type.dialect_impl(self.dialect) if impl._has_column_expression and ( @@ -2350,7 +2378,12 @@ class SQLCompiler(Compiled): or isinstance(column, functions.FunctionElement) ) ): - result_expr = _CompileLabel(col_expr, column.anon_label) + result_expr = _CompileLabel( + col_expr, + column.anon_label + if not column_is_repeated + else column._dedupe_label_anon_label, + ) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) @@ -2390,7 +2423,9 @@ class SQLCompiler(Compiled): [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] ) - def _display_froms_for_select(self, select, asfrom, lateral=False): + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): # utility method to help external dialects # get the correct from list for a select. # specifically the oracle dialect needs this feature @@ -2398,18 +2433,20 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + compile_state = select_stmt._compile_state_factory(select_stmt, self) + correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2417,7 +2454,7 @@ class SQLCompiler(Compiled): def visit_select( self, - select, + select_stmt, asfrom=False, fromhints=None, compound_index=0, @@ -2427,7 +2464,16 @@ class SQLCompiler(Compiled): **kwargs ): + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + toplevel = not self.stack + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = need_column_expressions = ( @@ -2446,7 +2492,7 @@ class SQLCompiler(Compiled): del kwargs["add_to_result_map"] froms = self._setup_select_stack( - select, entry, asfrom, lateral, compound_index + select_stmt, compile_state, entry, asfrom, lateral, compound_index ) column_clause_args = kwargs.copy() @@ -2456,23 +2502,25 @@ class SQLCompiler(Compiled): text = "SELECT " # we're off to a good start ! - if select._hints: - hint_text, byfrom = self._setup_select_hints(select) + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) if hint_text: text += hint_text + " " else: byfrom = None - if select._prefixes: - text += self._generate_prefixes(select, select._prefixes, **kwargs) + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) - text += self.get_select_precolumns(select, **kwargs) + text += self.get_select_precolumns(select_stmt, **kwargs) # the actual list of columns to print in the SELECT column list. inner_columns = [ c for c in [ self._label_select_column( - select, + select_stmt, column, populate_result_map, asfrom, @@ -2481,7 +2529,7 @@ class SQLCompiler(Compiled): column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column, repeated in select._columns_plus_names + for name, column, repeated in compile_state.columns_plus_names ] if c is not None ] @@ -2490,11 +2538,19 @@ class SQLCompiler(Compiled): # if this select is a compiler-generated wrapper, # rewrite the targeted columns in the result map + compile_state_wraps_for = select_wraps_for._compile_state_factory( + select_wraps_for, self, **kwargs + ) + translate = dict( zip( [ name - for (key, name, repeated) in select._columns_plus_names + for ( + key, + name, + repeated, + ) in compile_state.columns_plus_names ], [ name @@ -2502,7 +2558,7 @@ class SQLCompiler(Compiled): key, name, repeated, - ) in select_wraps_for._columns_plus_names + ) in compile_state_wraps_for.columns_plus_names ], ) ) @@ -2513,13 +2569,20 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, toplevel, kwargs + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ) - if select._statement_hints: + if select_stmt._statement_hints: per_dialect = [ ht - for (dialect_name, ht) in select._statement_hints + for (dialect_name, ht) in select_stmt._statement_hints if dialect_name in ("*", self.dialect.name) ] if per_dialect: @@ -2528,9 +2591,9 @@ class SQLCompiler(Compiled): if self.ctes and toplevel: text = self._render_cte_clause() + text - if select._suffixes: + if select_stmt._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs + select_stmt, select_stmt._suffixes, **kwargs ) self.stack.pop(-1) @@ -2553,7 +2616,7 @@ class SQLCompiler(Compiled): return hint_text, byfrom def _setup_select_stack( - self, select, entry, asfrom, lateral, compound_index + self, select, compile_state, entry, asfrom, lateral, compound_index ): correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] @@ -2564,8 +2627,8 @@ class SQLCompiler(Compiled): if select_0._is_select_container: select_0 = select_0.element numcols = len(select_0.selected_columns) - # numcols = len(select_0._columns_plus_names) - if len(select._columns_plus_names) != numcols: + + if len(compile_state.columns_plus_names) != numcols: raise exc.CompileError( "All selectables passed to " "CompoundSelect must have identical numbers of " @@ -2580,14 +2643,14 @@ class SQLCompiler(Compiled): ) if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2599,13 +2662,22 @@ class SQLCompiler(Compiled): "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, + "compile_state": compile_state, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, toplevel, kwargs + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ): text += ", ".join(inner_columns) @@ -2647,9 +2719,9 @@ class SQLCompiler(Compiled): else: text += self.default_from() - if select._whereclause is not None: - t = select._whereclause._compiler_dispatch( - self, from_linter=from_linter, **kwargs + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs ) if t: text += " \nWHERE " + t @@ -2660,15 +2732,17 @@ class SQLCompiler(Compiled): ): from_linter.warn() - if select._group_by_clause.clauses: + if select._group_by_clauses: text += self.group_by_clause(select, **kwargs) - if select._having is not None: - t = select._having._compiler_dispatch(self, **kwargs) + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) if t: text += " \nHAVING " + t - if select._order_by_clause.clauses: + if select._order_by_clauses: text += self.order_by_clause(select, **kwargs) if ( @@ -2719,7 +2793,9 @@ class SQLCompiler(Compiled): def group_by_clause(self, select, **kw): """allow dialects to customize how GROUP BY is rendered.""" - group_by = select._group_by_clause._compiler_dispatch(self, **kw) + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) if group_by: return " GROUP BY " + group_by else: @@ -2728,7 +2804,10 @@ class SQLCompiler(Compiled): def order_by_clause(self, select, **kw): """allow dialects to customize how ORDER BY is rendered.""" - order_by = select._order_by_clause._compiler_dispatch(self, **kw) + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if order_by: return " ORDER BY " + order_by else: @@ -2827,8 +2906,8 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): - compile_state = insert_stmt._compile_state_cls( - insert_stmt, self, isinsert=True, **kw + compile_state = insert_stmt._compile_state_factory( + insert_stmt, self, **kw ) insert_stmt = compile_state.statement @@ -2973,8 +3052,8 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): - compile_state = update_stmt._compile_state_cls( - update_stmt, self, isupdate=True, **kw + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw ) update_stmt = compile_state.statement @@ -3056,8 +3135,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if update_stmt._where_criteria: - t = self._generate_delimited_list( - update_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3100,8 +3179,8 @@ class SQLCompiler(Compiled): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, **kw): - compile_state = delete_stmt._compile_state_cls( - delete_stmt, self, isdelete=True, **kw + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw ) delete_stmt = compile_state.statement @@ -3159,8 +3238,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if delete_stmt._where_criteria: - t = self._generate_delimited_list( - delete_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3230,7 +3309,7 @@ class StrSQLCompiler(SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ self._label_select_column(None, c, True, False, {}) - for c in elements._select_iterables(returning_cols) + for c in base._select_iterables(returning_cols) ] return "RETURNING " + ", ".join(columns) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 171a2cc2c..2349bfd03 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -37,6 +37,18 @@ class DMLState(CompileState): isdelete = False isinsert = False + @classmethod + def _create_insert(cls, statement, compiler, **kw): + return DMLState(statement, compiler, isinsert=True, **kw) + + @classmethod + def _create_update(cls, statement, compiler, **kw): + return DMLState(statement, compiler, isupdate=True, **kw) + + @classmethod + def _create_delete(cls, statement, compiler, **kw): + return DMLState(statement, compiler, isdelete=True, **kw) + def __init__( self, statement, @@ -181,8 +193,6 @@ class UpdateBase( _hints = util.immutabledict() named_with_column = False - _compile_state_cls = DMLState - @classmethod def _constructor_20_deprecations(cls, fn_name, clsname, names): @@ -717,6 +727,8 @@ class Insert(ValuesBase): _supports_multi_parameters = True + _compile_state_factory = DMLState._create_insert + select = None include_insert_from_select_defaults = False @@ -915,6 +927,8 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" + _compile_state_factory = DMLState._create_update + _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -1153,6 +1167,8 @@ class Delete(DMLWhereBase, UpdateBase): __visit_name__ = "delete" + _compile_state_factory = DMLState._create_delete + _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index bb68e8a7e..843732f67 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -31,6 +31,7 @@ from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT +from .base import SingletonConstant from .coercions import _document_text_coercion from .traversals import _copy_internals from .traversals import _get_children @@ -338,7 +339,7 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def _copy_internals(self, **kw): + def _copy_internals(self, omit_attrs=(), **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -358,12 +359,15 @@ class ClauseElement( for attrname, obj, meth in _copy_internals.run_generated_dispatch( self, traverse_internals, "_generated_copy_internals_traversal" ): + if attrname in omit_attrs: + continue + if obj is not None: result = meth(self, obj, **kw) if result is not None: setattr(self, attrname, result) - def get_children(self, omit_attrs=None, **kw): + def get_children(self, omit_attrs=(), **kw): r"""Return immediate child :class:`.Traversible` elements of this :class:`.Traversible`. @@ -385,7 +389,7 @@ class ClauseElement( for attrname, obj, meth in _get_children.run_generated_dispatch( self, traverse_internals, "_generated_get_children_traversal" ): - if obj is None or omit_attrs and attrname in omit_attrs: + if obj is None or attrname in omit_attrs: continue result.extend(meth(obj, **kw)) return result @@ -942,7 +946,8 @@ class ColumnElement( @util.memoized_property def _dedupe_label_anon_label(self): - return self._anon_label(getattr(self, "_label", "anon") + "_") + label = getattr(self, "_label", None) or "anon" + return self._anon_label(label + "_") class WrapsColumnExpression(object): @@ -1482,6 +1487,8 @@ class TextClause( ) _is_implicitly_boolean = False + _render_label_in_columns_clause = False + def __and__(self, other): # support use in select.where(), query.filter() return and_(self, other) @@ -1514,14 +1521,6 @@ class TextClause( @classmethod @util.deprecated_params( - autocommit=( - "0.6", - "The :paramref:`.text.autocommit` parameter is deprecated and " - "will be removed in a future release. Please use the " - ":paramref:`.Connection.execution_options.autocommit` parameter " - "in conjunction with the :meth:`.Executable.execution_options` " - "method.", - ), bindparams=( "0.9", "The :paramref:`.text.bindparams` parameter " @@ -1537,7 +1536,7 @@ class TextClause( ) @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") def _create_text( - self, text, bind=None, bindparams=None, typemap=None, autocommit=None + self, text, bind=None, bindparams=None, typemap=None, ): r"""Construct a new :class:`.TextClause` clause, representing a textual SQL string directly. @@ -1614,9 +1613,6 @@ class TextClause( to specify bind parameters; they will be compiled to their engine-specific format. - :param autocommit: whether or not to set the "autocommit" execution - option for this :class:`.TextClause` object. - :param bind: an optional connection or engine to be used for this text query. @@ -1651,8 +1647,6 @@ class TextClause( stmt = stmt.bindparams(*bindparams) if typemap: stmt = stmt.columns(**typemap) - if autocommit is not None: - stmt = stmt.execution_options(autocommit=autocommit) return stmt @@ -1922,7 +1916,7 @@ class TextClause( return self -class Null(roles.ConstExprRole, ColumnElement): +class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -1945,7 +1939,10 @@ class Null(roles.ConstExprRole, ColumnElement): return Null() -class False_(roles.ConstExprRole, ColumnElement): +Null._create_singleton() + + +class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -2002,7 +1999,10 @@ class False_(roles.ConstExprRole, ColumnElement): return False_() -class True_(roles.ConstExprRole, ColumnElement): +False_._create_singleton() + + +class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -2067,6 +2067,9 @@ class True_(roles.ConstExprRole, ColumnElement): return True_() +True_._create_singleton() + + class ClauseList( roles.InElementRole, roles.OrderByRole, @@ -2108,6 +2111,17 @@ class ClauseList( ] self._is_implicitly_boolean = operators.is_boolean(self.operator) + @classmethod + def _construct_raw(cls, operator, clauses=None): + self = cls.__new__(cls) + self.clauses = clauses if clauses else [] + self.group = True + self.operator = operator + self.group_contents = True + self._tuple_values = False + self._is_implicitly_boolean = False + return self + def __iter__(self): return iter(self.clauses) @@ -2153,46 +2167,58 @@ class BooleanClauseList(ClauseList, ColumnElement): ) @classmethod - def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): - + def _process_clauses_for_boolean( + cls, operator, continue_on, skip_on, clauses + ): has_continue_on = None - special_elements = (continue_on, skip_on) - convert_clauses = [] - - for clause in util.coerce_generator_arg(clauses): - clause = coercions.expect(roles.WhereHavingRole, clause) - # elements that are not the continue/skip are the most - # common, try to have only one isinstance() call for that case. - if not isinstance(clause, special_elements): - convert_clauses.append(clause) - elif isinstance(clause, skip_on): - # instance of skip_on, e.g. and_(x, y, False, z), cancels - # the rest out - return clause.self_group(against=operators._asbool) - elif has_continue_on is None: + convert_clauses = [] + for clause in clauses: + if clause is continue_on: # instance of continue_on, like and_(x, y, True, z), store it # if we didn't find one already, we will use it if there # are no other expressions here. has_continue_on = clause + elif clause is skip_on: + # instance of skip_on, e.g. and_(x, y, False, z), cancels + # the rest out + convert_clauses = [clause] + break + else: + convert_clauses.append(clause) + + if not convert_clauses and has_continue_on is not None: + convert_clauses = [has_continue_on] lcc = len(convert_clauses) if lcc > 1: + against = operator + else: + against = operators._asbool + return lcc, [c.self_group(against=against) for c in convert_clauses] + + @classmethod + def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): + + lcc, convert_clauses = cls._process_clauses_for_boolean( + operator, + continue_on, + skip_on, + [ + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(clauses) + ], + ) + + if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw( - operator, - [c.self_group(against=operator) for c in convert_clauses], - ) + return cls._construct_raw(operator, convert_clauses) elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. - return convert_clauses[0].self_group(against=operators._asbool) - elif not lcc and has_continue_on is not None: - # no elements but we had a "continue", just return the continue - # as a boolean element, discard the operator. - return has_continue_on.self_group(against=operators._asbool) + return convert_clauses[0] else: # no elements period. deprecated use case. return an empty # ClauseList construct that generates nothing unless it has @@ -2203,7 +2229,9 @@ class BooleanClauseList(ClauseList, ColumnElement): "%(name)s() construct, use %(name)s(%(continue_on)s, *args)." % { "name": operator.__name__, - "continue_on": "True" if continue_on is True_ else "False", + "continue_on": "True" + if continue_on is True_._singleton + else "False", } ) return cls._construct_raw(operator) @@ -2277,7 +2305,9 @@ class BooleanClauseList(ClauseList, ColumnElement): :func:`.or_` """ - return cls._construct(operators.and_, True_, False_, *clauses) + return cls._construct( + operators.and_, True_._singleton, False_._singleton, *clauses + ) @classmethod def or_(cls, *clauses): @@ -2328,7 +2358,9 @@ class BooleanClauseList(ClauseList, ColumnElement): :func:`.and_` """ - return cls._construct(operators.or_, False_, True_, *clauses) + return cls._construct( + operators.or_, False_._singleton, True_._singleton, *clauses + ) @property def _select_iterable(self): @@ -4532,14 +4564,6 @@ class quoted_name(util.MemoizedSlots, util.text_type): return str.__repr__(self) -def _select_iterables(elements): - """expand tables into individual columns in the - given list of column expressions. - - """ - return itertools.chain(*[c._select_iterable for c in elements]) - - def _find_columns(clause): """locate Column objects within the given expression.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7ea0e2ad3..780648df0 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -84,6 +84,7 @@ __all__ = [ from .base import _from_objects # noqa +from .base import _select_iterables # noqa from .base import ColumnCollection # noqa from .base import Executable # noqa from .base import PARSE_AUTOCOMMIT # noqa @@ -92,7 +93,6 @@ from .dml import Insert # noqa from .dml import Update # noqa from .dml import UpdateBase # noqa from .dml import ValuesBase # noqa -from .elements import _select_iterables # noqa from .elements import _truncated_label # noqa from .elements import between # noqa from .elements import BinaryExpression # noqa diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5536b27bc..45b9e7f9d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -26,16 +26,18 @@ from .base import _cloned_intersection from .base import _expand_cloned from .base import _from_objects from .base import _generative +from .base import _select_iterables from .base import ColumnCollection from .base import ColumnSet +from .base import CompileState from .base import DedupeColumnCollection from .base import Executable from .base import Generative +from .base import HasCompileState from .base import HasMemoized from .base import Immutable from .coercions import _document_text_coercion from .elements import _anonymous_label -from .elements import _select_iterables from .elements import and_ from .elements import BindParameter from .elements import ClauseElement @@ -44,7 +46,6 @@ from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column -from .elements import True_ from .elements import UnaryExpression from .visitors import InternalTraversal from .. import exc @@ -1994,54 +1995,6 @@ class ForUpdateArg(ClauseElement): ("skip_locked", InternalTraversal.dp_boolean), ] - @classmethod - def parse_legacy_select(self, arg): - """Parse the for_update argument of :func:`.select`. - - :param mode: Defines the lockmode to use. - - ``None`` - translates to no lockmode - - ``'update'`` - translates to ``FOR UPDATE`` - (standard SQL, supported by most dialects) - - ``'nowait'`` - translates to ``FOR UPDATE NOWAIT`` - (supported by Oracle, PostgreSQL 8.1 upwards) - - ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL), - and ``FOR SHARE`` (for PostgreSQL) - - ``'read_nowait'`` - translates to ``FOR SHARE NOWAIT`` - (supported by PostgreSQL). ``FOR SHARE`` and - ``FOR SHARE NOWAIT`` (PostgreSQL). - - """ - if arg in (None, False): - return None - - nowait = read = False - if arg == "nowait": - nowait = True - elif arg == "read": - read = True - elif arg == "read_nowait": - read = nowait = True - elif arg is not True: - raise exc.ArgumentError("Unknown for_update argument: %r" % arg) - - return ForUpdateArg(read=read, nowait=nowait) - - @property - def legacy_for_update_value(self): - if self.read and not self.nowait: - return "read" - elif self.read and self.nowait: - return "read_nowait" - elif self.nowait: - return "nowait" - else: - return True - def __eq__(self, other): return ( isinstance(other, ForUpdateArg) @@ -2261,24 +2214,6 @@ class SelectBase( """ return Lateral._factory(self, name) - @_generative - @util.deprecated( - "0.6", - message="The :meth:`.SelectBase.autocommit` method is deprecated, " - "and will be removed in a future release. Please use the " - "the :paramref:`.Connection.execution_options.autocommit` " - "parameter in conjunction with the " - ":meth:`.Executable.execution_options` method.", - ) - def autocommit(self): - """return a new selectable with the 'autocommit' flag set to - True. - """ - - self._execution_options = self._execution_options.union( - {"autocommit": True} - ) - def _generate(self): """Override the default _generate() method to also clear out exported collections.""" @@ -2458,8 +2393,8 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ - _order_by_clause = ClauseList() - _group_by_clause = ClauseList() + _order_by_clauses = () + _group_by_clauses = () _limit_clause = None _offset_clause = None _for_update_arg = None @@ -2467,56 +2402,25 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def __init__( self, use_labels=False, - for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, - autocommit=None, ): self.use_labels = use_labels - if for_update is not False: - self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update) - - if autocommit is not None: - util.warn_deprecated( - "The select.autocommit parameter is deprecated and will be " - "removed in a future release. Please refer to the " - "Select.execution_options.autocommit` parameter." - ) - self._execution_options = self._execution_options.union( - {"autocommit": autocommit} - ) if limit is not None: - self._limit_clause = self._offset_or_limit_clause(limit) + self.limit.non_generative(self, limit) if offset is not None: - self._offset_clause = self._offset_or_limit_clause(offset) - self._bind = bind + self.offset.non_generative(self, offset) if order_by is not None: - self._order_by_clause = ClauseList( - *util.to_list(order_by), - _literal_as_text_role=roles.OrderByRole - ) + self.order_by.non_generative(self, *util.to_list(order_by)) if group_by is not None: - self._group_by_clause = ClauseList( - *util.to_list(group_by), _literal_as_text_role=roles.ByOfRole - ) + self.group_by.non_generative(self, *util.to_list(group_by)) - @property - def for_update(self): - """Provide legacy dialect support for the ``for_update`` attribute. - """ - if self._for_update_arg is not None: - return self._for_update_arg.legacy_for_update_value - else: - return None - - @for_update.setter - def for_update(self, value): - self._for_update_arg = ForUpdateArg.parse_legacy_select(value) + self._bind = bind @_generative def with_for_update( @@ -2565,14 +2469,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): on Oracle and PostgreSQL dialects or ``FOR SHARE SKIP LOCKED`` if ``read=True`` is also specified. - .. versionadded:: 1.1.0 - :param key_share: boolean, will render ``FOR NO KEY UPDATE``, or if combined with ``read=True`` will render ``FOR KEY SHARE``, on the PostgreSQL dialect. - .. versionadded:: 1.1.0 - """ self._for_update_arg = ForUpdateArg( nowait=nowait, @@ -2596,6 +2496,20 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ self.use_labels = True + @property + def _group_by_clause(self): + """ClauseList access to group_by_clauses for legacy dialects""" + return ClauseList._construct_raw( + operators.comma_op, self._group_by_clauses + ) + + @property + def _order_by_clause(self): + """ClauseList access to order_by_clauses for legacy dialects""" + return ClauseList._construct_raw( + operators.comma_op, self._order_by_clauses + ) + def _offset_or_limit_clause(self, element, name=None, type_=None): """Convert the given value to an "offset or limit" clause. @@ -2727,12 +2641,11 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ if len(clauses) == 1 and clauses[0] is None: - self._order_by_clause = ClauseList() + self._order_by_clauses = () else: - if getattr(self, "_order_by_clause", None) is not None: - clauses = list(self._order_by_clause) + list(clauses) - self._order_by_clause = ClauseList( - *clauses, _literal_as_text_role=roles.OrderByRole + self._order_by_clauses += tuple( + coercions.expect(roles.OrderByRole, clause) + for clause in clauses ) @_generative @@ -2755,20 +2668,24 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ if len(clauses) == 1 and clauses[0] is None: - self._group_by_clause = ClauseList() + self._group_by_clauses = () else: - if getattr(self, "_group_by_clause", None) is not None: - clauses = list(self._group_by_clause) + list(clauses) - self._group_by_clause = ClauseList( - *clauses, _literal_as_text_role=roles.ByOfRole + self._group_by_clauses += tuple( + coercions.expect(roles.ByOfRole, clause) for clause in clauses ) - @property + +class CompoundSelectState(CompileState): + @util.memoized_property def _label_resolve_dict(self): - raise NotImplementedError() + # TODO: this is hacky and slow + hacky_subquery = self.statement.subquery() + hacky_subquery.named_with_column = False + d = dict((c.key, c) for c in hacky_subquery.c) + return d, d, d -class CompoundSelect(GenerativeSelect): +class CompoundSelect(HasCompileState, GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -2790,13 +2707,14 @@ class CompoundSelect(GenerativeSelect): """ __visit_name__ = "compound_select" + _compile_state_factory = CompoundSelectState._create _traverse_internals = [ ("selects", InternalTraversal.dp_clauseelement_list), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), - ("_order_by_clause", InternalTraversal.dp_clauseelement), - ("_group_by_clause", InternalTraversal.dp_clauseelement), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_list), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_list), ("_for_update_arg", InternalTraversal.dp_clauseelement), ("keyword", InternalTraversal.dp_string), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals @@ -2822,14 +2740,6 @@ class CompoundSelect(GenerativeSelect): GenerativeSelect.__init__(self, **kwargs) - @SelectBase._memoized_property - def _label_resolve_dict(self): - # TODO: this is hacky and slow - hacky_subquery = self.subquery() - hacky_subquery.named_with_column = False - d = dict((c.key, c) for c in hacky_subquery.c) - return d, d, d - @classmethod def _create_union(cls, *selects, **kwargs): r"""Return a ``UNION`` of multiple selectables. @@ -3001,6 +2911,7 @@ class CompoundSelect(GenerativeSelect): """ return self.selects[0].selected_columns + @property def bind(self): if self._bind: return self._bind @@ -3011,11 +2922,10 @@ class CompoundSelect(GenerativeSelect): else: return None - def _set_bind(self, bind): + @bind.setter + def bind(self, bind): self._bind = bind - bind = property(bind, _set_bind) - class DeprecatedSelectGenerations(object): @util.deprecated( @@ -3135,8 +3045,130 @@ class DeprecatedSelectGenerations(object): self.select_from.non_generative(self, fromclause) +class SelectState(CompileState): + def __init__(self, statement, compiler, **kw): + self.statement = statement + self.froms = self._get_froms(statement) + + self.columns_plus_names = statement._generate_columns_plus_names(True) + + def _get_froms(self, statement): + froms = [] + seen = set() + + for item in statement._iterate_from_elements(): + if item._is_subquery and item.element is statement: + raise exc.InvalidRequestError( + "select() construct refers to itself as a FROM" + ) + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + return froms + + def _get_display_froms( + self, explicit_correlate_froms=None, implicit_correlate_froms=None + ): + """Return the full list of 'from' clauses to be displayed. + + Takes into account a set of existing froms which may be + rendered in the FROM clause of enclosing selects; this Select + may want to leave those absent if it is automatically + correlating. + + """ + froms = self.froms + + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) + ) + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + + if self.statement._correlate: + to_correlate = self.statement._correlate + if to_correlate: + froms = [ + f + for f in froms + if f + not in _cloned_intersection( + _cloned_intersection( + froms, explicit_correlate_froms or () + ), + to_correlate, + ) + ] + + if self.statement._correlate_except is not None: + + froms = [ + f + for f in froms + if f + not in _cloned_difference( + _cloned_intersection( + froms, explicit_correlate_froms or () + ), + self.statement._correlate_except, + ) + ] + + if ( + self.statement._auto_correlate + and implicit_correlate_froms + and len(froms) > 1 + ): + + froms = [ + f + for f in froms + if f + not in _cloned_intersection(froms, implicit_correlate_froms) + ] + + if not len(froms): + raise exc.InvalidRequestError( + "Select statement '%s" + "' returned no FROM clauses " + "due to auto-correlation; " + "specify correlate(<tables>) " + "to control correlation " + "manually." % self + ) + + return froms + + @util.memoized_property + def _label_resolve_dict(self): + with_cols = dict( + (c._resolve_label or c._label or c.key, c) + for c in _select_iterables(self.statement._raw_columns) + if c._allow_label_resolve + ) + only_froms = dict( + (c.key, c) + for c in _select_iterables(self.froms) + if c._allow_label_resolve + ) + only_cols = with_cols.copy() + for key, value in only_froms.items(): + with_cols.setdefault(key, value) + + return with_cols, only_froms, only_cols + + class Select( - HasPrefixes, HasSuffixes, DeprecatedSelectGenerations, GenerativeSelect + HasPrefixes, + HasSuffixes, + HasCompileState, + DeprecatedSelectGenerations, + GenerativeSelect, ): """Represents a ``SELECT`` statement. @@ -3144,28 +3176,29 @@ class Select( __visit_name__ = "select" + _compile_state_factory = SelectState._create + _hints = util.immutabledict() _statement_hints = () _distinct = False _distinct_on = () _correlate = () _correlate_except = None + _where_criteria = () + _having_criteria = () + _from_obj = () + _auto_correlate = True + _memoized_property = SelectBase._memoized_property _traverse_internals = ( [ - ("_from_obj", InternalTraversal.dp_fromclause_ordered_set), + ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_raw_columns", InternalTraversal.dp_clauseelement_list), - ("_whereclause", InternalTraversal.dp_clauseelement), - ("_having", InternalTraversal.dp_clauseelement), - ( - "_order_by_clause.clauses", - InternalTraversal.dp_clauseelement_list, - ), - ( - "_group_by_clause.clauses", - InternalTraversal.dp_clauseelement_list, - ), + ("_where_criteria", InternalTraversal.dp_clauseelement_list), + ("_having_criteria", InternalTraversal.dp_clauseelement_list), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), ( "_correlate_except", @@ -3220,13 +3253,6 @@ class Select( for ent in util.to_list(entities) ] - # this should all go away once Select is converted to have - # default state at the class level - self._auto_correlate = True - self._from_obj = util.OrderedSet() - self._whereclause = None - self._having = None - GenerativeSelect.__init__(self) return self @@ -3244,24 +3270,6 @@ class Select( else: return Select._create_select(*entities) - @util.deprecated_params( - autocommit=( - "0.6", - "The :paramref:`.select.autocommit` parameter is deprecated " - "and will be removed in a future release. Please refer to " - "the :paramref:`.Connection.execution_options.autocommit` " - "parameter in conjunction with the the " - ":meth:`.Executable.execution_options` method in order to " - "affect the autocommit behavior for a statement.", - ), - for_update=( - "0.9", - "The :paramref:`.select.for_update` parameter is deprecated and " - "will be removed in a future release. Please refer to the " - ":meth:`.Select.with_for_update` to specify the " - "structure of the ``FOR UPDATE`` clause.", - ), - ) def __init__( self, columns=None, @@ -3332,8 +3340,6 @@ class Select( :meth:`.Select.select_from` - full description of explicit FROM clause specification. - :param autocommit: legacy autocommit parameter. - :param bind=None: an :class:`~.Engine` or :class:`~.Connection` instance to which the @@ -3369,25 +3375,6 @@ class Select( :meth:`.Select.distinct` - :param for_update=False: - when ``True``, applies ``FOR UPDATE`` to the end of the - resulting statement. - - ``for_update`` accepts various string values interpreted by - specific backends, including: - - * ``"read"`` - on MySQL, translates to ``LOCK IN SHARE MODE``; - on PostgreSQL, translates to ``FOR SHARE``. - * ``"nowait"`` - on PostgreSQL and Oracle, translates to - ``FOR UPDATE NOWAIT``. - * ``"read_nowait"`` - on PostgreSQL, translates to - ``FOR SHARE NOWAIT``. - - .. seealso:: - - :meth:`.Select.with_for_update` - improved API for - specifying the ``FOR UPDATE`` clause. - :param group_by: a list of :class:`.ClauseElement` objects which will comprise the ``GROUP BY`` clause of the resulting select. This parameter @@ -3469,23 +3456,15 @@ class Select( ) self._auto_correlate = correlate + if distinct is not False: - self._distinct = True - if not isinstance(distinct, bool): - self._distinct_on = tuple( - [ - coercions.expect(roles.ByOfRole, e) - for e in util.to_list(distinct) - ] - ) + if distinct is True: + self.distinct.non_generative(self) + else: + self.distinct.non_generative(self, *util.to_list(distinct)) if from_obj is not None: - self._from_obj = util.OrderedSet( - coercions.expect(roles.FromClauseRole, f) - for f in util.to_list(from_obj) - ) - else: - self._from_obj = util.OrderedSet() + self.select_from.non_generative(self, *util.to_list(from_obj)) try: cols_present = bool(columns) @@ -3499,26 +3478,17 @@ class Select( ) if cols_present: - self._raw_columns = [] - for c in columns: - c = coercions.expect(roles.ColumnsClauseRole, c) - self._raw_columns.append(c) + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, c,) for c in columns + ] else: self._raw_columns = [] if whereclause is not None: - self._whereclause = coercions.expect( - roles.WhereHavingRole, whereclause - ).self_group(against=operators._asbool) - else: - self._whereclause = None + self.where.non_generative(self, whereclause) if having is not None: - self._having = coercions.expect( - roles.WhereHavingRole, having - ).self_group(against=operators._asbool) - else: - self._having = None + self.having.non_generative(self, having) if prefixes: self._setup_prefixes(prefixes) @@ -3528,119 +3498,27 @@ class Select( GenerativeSelect.__init__(self, **kwargs) - @property - def _froms(self): - # current roadblock to caching is two tests that test that the - # SELECT can be compiled to a string, then a Table is created against - # columns, then it can be compiled again and works. this is somewhat - # valid as people make select() against declarative class where - # columns don't have their Table yet and perhaps some operations - # call upon _froms and cache it too soon. - froms = [] - seen = set() - - for item in itertools.chain( - _from_objects(*self._raw_columns), - _from_objects(self._whereclause) - if self._whereclause is not None - else (), - self._from_obj, - ): - if item._is_subquery and item.element is self: - raise exc.InvalidRequestError( - "select() construct refers to itself as a FROM" - ) - if not seen.intersection(item._cloned_set): - froms.append(item) - seen.update(item._cloned_set) - - return froms - - def _get_display_froms( - self, explicit_correlate_froms=None, implicit_correlate_froms=None - ): - """Return the full list of 'from' clauses to be displayed. - - Takes into account a set of existing froms which may be - rendered in the FROM clause of enclosing selects; this Select - may want to leave those absent if it is automatically - correlating. - - """ - froms = self._froms - - toremove = set( - itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] - - if self._correlate: - to_correlate = self._correlate - if to_correlate: - froms = [ - f - for f in froms - if f - not in _cloned_intersection( - _cloned_intersection( - froms, explicit_correlate_froms or () - ), - to_correlate, - ) - ] - - if self._correlate_except is not None: - - froms = [ - f - for f in froms - if f - not in _cloned_difference( - _cloned_intersection( - froms, explicit_correlate_froms or () - ), - self._correlate_except, - ) - ] - - if ( - self._auto_correlate - and implicit_correlate_froms - and len(froms) > 1 - ): - - froms = [ - f - for f in froms - if f - not in _cloned_intersection(froms, implicit_correlate_froms) - ] - - if not len(froms): - raise exc.InvalidRequestError( - "Select statement '%s" - "' returned no FROM clauses " - "due to auto-correlation; " - "specify correlate(<tables>) " - "to control correlation " - "manually." % self - ) - - return froms - def _scalar_type(self): elem = self._raw_columns[0] cols = list(elem._select_iterable) return cols[0].type + def _iterate_from_elements(self): + return itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in self._raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in self._where_criteria] + ), + self._from_obj, + ) + @property def froms(self): """Return the displayed list of FromClause elements.""" - return self._get_display_froms() + return self._compile_state_factory(self, None)._get_display_froms() def with_statement_hint(self, text, dialect_name="*"): """add a statement hint to this :class:`.Select`. @@ -3705,18 +3583,6 @@ class Select( else: self._hints = self._hints.union({(selectable, dialect_name): text}) - @_memoized_property.method - def locate_all_froms(self): - """return a Set of all FromClause elements referenced by this Select. - - This set is a superset of that returned by the ``froms`` property, - which is specifically for those FromClause elements that would - actually be rendered. - - """ - froms = self._froms - return froms + list(_from_objects(*froms)) - @property def inner_columns(self): """an iterator of all ColumnElement expressions which would @@ -3725,29 +3591,11 @@ class Select( """ return _select_iterables(self._raw_columns) - @_memoized_property - def _label_resolve_dict(self): - with_cols = dict( - (c._resolve_label or c._label or c.key, c) - for c in _select_iterables(self._raw_columns) - if c._allow_label_resolve - ) - only_froms = dict( - (c.key, c) - for c in _select_iterables(self.froms) - if c._allow_label_resolve - ) - only_cols = with_cols.copy() - for key, value in only_froms.items(): - with_cols.setdefault(key, value) - - return with_cols, only_froms, only_cols - def is_derived_from(self, fromclause): if self in fromclause._cloned_set: return True - for f in self.locate_all_froms(): + for f in self._iterate_from_elements(): if f.is_derived_from(fromclause): return True return False @@ -3758,40 +3606,24 @@ class Select( # objects # 1. keep a dictionary of the froms we've cloned, and what - # they've become. This is consulted later when we derive - # additional froms from "whereclause" and the columns clause, - # which may still reference the uncloned parent table. - # as of 0.7.4 we also put the current version of _froms, which - # gets cleared on each generation. previously we were "baking" - # _froms into self._from_obj. + # they've become. This allows us to ensure the same cloned from + # is used when other items such as columns are "cloned" all_the_froms = list( itertools.chain( _from_objects(*self._raw_columns), - _from_objects(self._whereclause) - if self._whereclause is not None - else (), + _from_objects(*self._where_criteria), ) ) new_froms = {f: clone(f, **kw) for f in all_the_froms} - # copy FROM collections - self._from_obj = util.OrderedSet( - clone(f, **kw) for f in self._from_obj - ).union(f for f in new_froms.values() if isinstance(f, Join)) - - self._correlate = set(clone(f, **kw) for f in self._correlate) - if self._correlate_except: - self._correlate_except = set( - clone(f, **kw) for f in self._correlate_except - ) + # 2. copy FROM collections. + self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple( + f for f in new_froms.values() if isinstance(f, Join) + ) - # 4. clone other things. The difficulty here is that Column - # objects are usually not altered by a straight clone because they - # are dependent on the FROM cloning we just did above in order to - # be targeted correctly, or a new FROM we have might be a JOIN - # object which doesn't have its own columns. so give the cloner a - # hint. + # 3. clone everything else, making sure we use columns + # corresponding to the froms we just made. def replace(obj, **kw): if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) @@ -3799,25 +3631,16 @@ class Select( kw["replace"] = replace - # TODO: I'd still like to try to leverage the traversal data - self._raw_columns = [clone(c, **kw) for c in self._raw_columns] - for attr in ( - "_limit_clause", - "_offset_clause", - "_whereclause", - "_having", - "_order_by_clause", - "_group_by_clause", - "_for_update_arg", - ): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr), **kw)) + super(Select, self)._copy_internals( + clone=clone, omit_attrs=("_from_obj",), **kw + ) self._reset_memoizations() def get_children(self, **kwargs): - # TODO: define "get_children" traversal items separately? - return self._froms + super(Select, self).get_children( + return list(set(self._iterate_from_elements())) + super( + Select, self + ).get_children( omit_attrs=["_from_obj", "_correlate", "_correlate_except"] ) @@ -3838,7 +3661,7 @@ class Select( self._reset_memoizations() self._raw_columns = self._raw_columns + [ - coercions.expect(roles.ColumnsClauseRole, column) + coercions.expect(roles.ColumnsClauseRole, column,) for column in columns ] @@ -3889,7 +3712,7 @@ class Select( util.preloaded.sql_util.reduce_columns( self.inner_columns, only_synonyms=only_synonyms, - *(self._whereclause,) + tuple(self._from_obj) + *(self._where_criteria + self._from_obj) ) ) @@ -3965,12 +3788,18 @@ class Select( self._reset_memoizations() rc = [] for c in columns: - c = coercions.expect(roles.ColumnsClauseRole, c) + c = coercions.expect(roles.ColumnsClauseRole, c,) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) self._raw_columns = rc + @property + def _whereclause(self): + """Legacy, return the WHERE clause as a :class:`.BooleanClauseList`""" + + return and_(*self._where_criteria) + @_generative def where(self, whereclause): """return a new select() construct with the given expression added to @@ -3978,8 +3807,9 @@ class Select( """ - self._reset_memoizations() - self._whereclause = and_(True_._ifnone(self._whereclause), whereclause) + self._where_criteria += ( + coercions.expect(roles.WhereHavingRole, whereclause), + ) @_generative def having(self, having): @@ -3987,8 +3817,9 @@ class Select( its HAVING clause, joined to the existing clause via AND, if any. """ - self._reset_memoizations() - self._having = and_(True_._ifnone(self._having), having) + self._having_criteria += ( + coercions.expect(roles.WhereHavingRole, having), + ) @_generative def distinct(self, *expr): @@ -4001,16 +3832,17 @@ class Select( """ if expr: - expr = [coercions.expect(roles.ByOfRole, e) for e in expr] self._distinct = True - self._distinct_on = self._distinct_on + tuple(expr) + self._distinct_on = self._distinct_on + tuple( + coercions.expect(roles.ByOfRole, e) for e in expr + ) else: self._distinct = True @_generative - def select_from(self, fromclause): + def select_from(self, *froms): r"""return a new :func:`.select` construct with the - given FROM expression + given FROM expression(s) merged into its list of FROM objects. E.g.:: @@ -4039,9 +3871,11 @@ class Select( select([func.count('*')]).select_from(table1) """ - self._reset_memoizations() - fromclause = coercions.expect(roles.FromClauseRole, fromclause) - self._from_obj = self._from_obj.union([fromclause]) + + self._from_obj += tuple( + coercions.expect(roles.FromClauseRole, fromclause) + for fromclause in froms + ) @_generative def correlate(self, *fromclauses): @@ -4198,12 +4032,17 @@ class Select( names = {} def name_for_col(c): - if c._label is None or not c._render_label_in_columns_clause: + if not c._render_label_in_columns_clause: return (None, c, False) + elif c._label is None: + repeated = c.anon_label in names + names[c.anon_label] = c + return (None, c, repeated) - repeated = False name = c._label + repeated = False + if name in names: # when looking to see if names[name] is the same column as # c, use hash(), so that an annotated version of the column @@ -4242,13 +4081,11 @@ class Select( return [name_for_col(c) for c in cols] else: # repeated name logic only for use labels at the moment - return [(None, c, False) for c in cols] - - @_memoized_property - def _columns_plus_names(self): - """generate label names plus columns to render in a SELECT.""" + same_cols = set() - return self._generate_columns_plus_names(True) + return [ + (None, c, c in same_cols or same_cols.add(c)) for c in cols + ] def _generate_fromclause_column_proxies(self, subquery): """generate column proxies to place in the exported .c collection @@ -4340,29 +4177,34 @@ class Select( """ return CompoundSelect._create_intersect_all(self, other, **kwargs) + @property def bind(self): if self._bind: return self._bind - froms = self._froms - if not froms: - for c in self._raw_columns: - e = c.bind - if e: - self._bind = e - return e - else: - e = list(froms)[0].bind + + for item in self._iterate_from_elements(): + if item._is_subquery and item.element is self: + raise exc.InvalidRequestError( + "select() construct refers to itself as a FROM" + ) + + e = item.bind if e: self._bind = e return e + else: + break - return None + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e - def _set_bind(self, bind): + @bind.setter + def bind(self, bind): self._bind = bind - bind = property(bind, _set_bind) - class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 2a4f7ebb6..1fcc2d023 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -459,6 +459,11 @@ class _CopyInternals(InternalTraversal): def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): return [clone(clause, **kw) for clause in element] + def visit_clauseelement_unordered_set( + self, parent, element, clone=_clone, **kw + ): + return {clone(clause, **kw) for clause in element} + def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw): return [ tuple(clone(tup_elem, **kw) for tup_elem in elem) |