diff options
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 1868 |
1 files changed, 1273 insertions, 595 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 4f6e3795e..6504449f1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -17,16 +17,26 @@ import collections from enum import Enum import itertools import typing +from typing import AbstractSet from typing import Any as TODO_Any from typing import Any +from typing import Callable +from typing import cast +from typing import Dict from typing import Iterable +from typing import Iterator from typing import List from typing import NamedTuple +from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import cache_key from . import coercions @@ -37,6 +47,9 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import is_column_element +from ._typing import is_select_statement +from ._typing import is_subquery +from ._typing import is_table from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -68,32 +81,80 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import DQLDMLClauseElement from .elements import GroupedElement -from .elements import Grouping from .elements import literal_column from .elements import TableValuedColumn from .elements import UnaryExpression +from .operators import OperatorType +from .visitors import _TraverseInternalsType from .visitors import InternalTraversal from .visitors import prefix_anon_map from .. import exc from .. import util +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) if TYPE_CHECKING: - from ._typing import _SelectIterable + from ._typing import _ColumnExpressionArgument + from ._typing import _FromClauseArgument + from ._typing import _JoinTargetArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypeEngineArgument + from .base import _AmbiguousTableNameMap + from .base import ExecutableOption from .base import ReadOnlyColumnCollection + from .cache_key import _CacheKeyTraversalType + from .compiler import SQLCompiler + from .dml import Delete + from .dml import Insert + from .dml import Update from .elements import NamedColumn + from .elements import TextClause + from .functions import Function + from .schema import Column from .schema import ForeignKey - from .schema import PrimaryKeyConstraint + from .schema import ForeignKeyConstraint + from .type_api import TypeEngine + from .util import ClauseAdapter + from .visitors import _CloneCallableType -class _OffsetLimitParam(BindParameter): +_ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"] + + +class _JoinTargetProtocol(Protocol): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + ... + + +_JoinTargetElement = Union["FromClause", _JoinTargetProtocol] +_OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol] + + +_SetupJoinsElement = Tuple[ + _JoinTargetElement, + Optional[_OnClauseElement], + Optional["FromClause"], + Dict[str, Any], +] + + +_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] + + +class _OffsetLimitParam(BindParameter[int]): inherit_cache = True @property - def _limit_offset_value(self): + def _limit_offset_value(self) -> Optional[int]: return self.effective_value @@ -114,11 +175,12 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): # sub-elements of returns_rows _is_from_clause = False + _is_select_base = False _is_select_statement = False _is_lateral = False @property - def selectable(self): + def selectable(self) -> ReturnsRows: return self @util.non_memoized_property @@ -133,8 +195,28 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """ raise NotImplementedError() + def is_derived_from(self, fromclause: FromClause) -> bool: + """Return ``True`` if this :class:`.ReturnsRows` is + 'derived' from the given :class:`.FromClause`. + + An example would be an Alias of a Table is derived from that Table. + + """ + raise NotImplementedError() + + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: + """Populate columns into an :class:`.AliasedReturnsRows` object.""" + + raise NotImplementedError() + + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: + """reset internal collections for an incoming column being added.""" + raise NotImplementedError() + @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[Any, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.ReturnsRows`. @@ -160,6 +242,9 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") + + class Selectable(ReturnsRows): """Mark a class as being selectable.""" @@ -167,10 +252,10 @@ class Selectable(ReturnsRows): is_selectable = True - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: raise NotImplementedError() - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -192,15 +277,21 @@ class Selectable(ReturnsRows): "functionality is available via the sqlalchemy.sql.visitors module.", ) @util.preload_module("sqlalchemy.sql.util") - def replace_selectable(self, old, alias): + def replace_selectable( + self: SelfSelectable, old: FromClause, alias: Alias + ) -> SelfSelectable: """Replace all occurrences of :class:`_expression.FromClause` 'old' with the given :class:`_expression.Alias` object, returning a copy of this :class:`_expression.FromClause`. """ - return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self) + return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa E501 + self + ) - def corresponding_column(self, column, require_embedded=False): + def corresponding_column( + self, column: ColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -242,19 +333,23 @@ SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes") class HasPrefixes: - _prefixes = () + _prefixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_prefixes_traverse_internals = [ + _has_prefixes_traverse_internals: _TraverseInternalsType = [ ("_prefixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "prefixes", ":meth:`_expression.HasPrefixes.prefix_with`", - ":paramref:`.HasPrefixes.prefix_with.*expr`", + ":paramref:`.HasPrefixes.prefix_with.*prefixes`", ) - def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes: + def prefix_with( + self: SelfHasPrefixes, + *prefixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasPrefixes: r"""Add one or more expressions following the statement keyword, i.e. SELECT, INSERT, UPDATE, or DELETE. Generative. @@ -272,49 +367,44 @@ class HasPrefixes: Multiple prefixes can be specified by multiple calls to :meth:`_expression.HasPrefixes.prefix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*prefixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the INSERT, UPDATE, or DELETE keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: optional string dialect name which will limit rendering of this prefix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_prefixes(expr, dialect) - return self - - def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) + return self SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes") class HasSuffixes: - _suffixes = () + _suffixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_suffixes_traverse_internals = [ + _has_suffixes_traverse_internals: _TraverseInternalsType = [ ("_suffixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "suffixes", ":meth:`_expression.HasSuffixes.suffix_with`", - ":paramref:`.HasSuffixes.suffix_with.*expr`", + ":paramref:`.HasSuffixes.suffix_with.*suffixes`", ) - def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes: + def suffix_with( + self: SelfHasSuffixes, + *suffixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasSuffixes: r"""Add one or more expressions following the statement as a whole. This is used to support backend-specific suffix keywords on @@ -328,44 +418,39 @@ class HasSuffixes: Multiple suffixes can be specified by multiple calls to :meth:`_expression.HasSuffixes.suffix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*suffixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the target clause. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: Optional string dialect name which will limit rendering of this suffix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_suffixes(expr, dialect) - return self - - def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) + return self SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints") class HasHints: - _hints = util.immutabledict() - _statement_hints = () + _hints: util.immutabledict[ + Tuple[FromClause, str], str + ] = util.immutabledict() + _statement_hints: Tuple[Tuple[str, str], ...] = () - _has_hints_traverse_internals = [ + _has_hints_traverse_internals: _TraverseInternalsType = [ ("_statement_hints", InternalTraversal.dp_statement_hint_list), ("_hints", InternalTraversal.dp_table_hint_list), ] - def with_statement_hint(self, text, dialect_name="*"): + def with_statement_hint( + self: SelfHasHints, text: str, dialect_name: str = "*" + ) -> SelfHasHints: """Add a statement hint to this :class:`_expression.Select` or other selectable object. @@ -389,11 +474,14 @@ class HasHints: MySQL optimizer hints """ - return self.with_hint(None, text, dialect_name) + return self._with_hint(None, text, dialect_name) @_generative def with_hint( - self: SelfHasHints, selectable, text, dialect_name="*" + self: SelfHasHints, + selectable: _FromClauseArgument, + text: str, + dialect_name: str = "*", ) -> SelfHasHints: r"""Add an indexing or other executional context hint for the given selectable to this :class:`_expression.Select` or other selectable @@ -429,6 +517,15 @@ class HasHints: :meth:`_expression.Select.with_statement_hint` """ + + return self._with_hint(selectable, text, dialect_name) + + def _with_hint( + self: SelfHasHints, + selectable: Optional[_FromClauseArgument], + text: str, + dialect_name: str, + ) -> SelfHasHints: if selectable is None: self._statement_hints += ((dialect_name, text),) else: @@ -443,6 +540,9 @@ class HasHints: return self +SelfFromClause = TypeVar("SelfFromClause", bound="FromClause") + + class FromClause(roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -473,6 +573,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _is_clone_of: Optional[FromClause] + _columns: ColumnCollection[Any, Any] + schema: Optional[str] = None """Define the 'schema' attribute for this :class:`_expression.FromClause`. @@ -488,7 +590,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -504,7 +606,13 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return Select(self) - def join(self, right, onclause=None, isouter=False, full=False): + def join( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + isouter: bool = False, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`. @@ -550,7 +658,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, isouter, full) - def outerjoin(self, right, onclause=None, full=False): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`, with the "isouter" flag set to @@ -596,7 +709,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, True, full) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an alias of this :class:`_expression.FromClause`. E.g.:: @@ -617,35 +732,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Alias._construct(self, name) - @util.preload_module("sqlalchemy.sql.sqltypes") - def table_valued(self): - """Return a :class:`_sql.TableValuedColumn` object for this - :class:`_expression.FromClause`. - - A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that - represents a complete row in a table. Support for this construct is - backend dependent, and is supported in various forms by backends - such as PostgreSQL, Oracle and SQL Server. - - E.g.:: - - >>> from sqlalchemy import select, column, func, table - >>> a = table("a", column("id"), column("x"), column("y")) - >>> stmt = select(func.row_to_json(a.table_valued())) - >>> print(stmt) - SELECT row_to_json(a) AS row_to_json_1 - FROM a - - .. versionadded:: 1.4.0b2 - - .. seealso:: - - :ref:`tutorial_functions` - in the :ref:`unified_tutorial` - - """ - return TableValuedColumn(self, type_api.TABLEVALUE) - - def tablesample(self, sampling, name=None, seed=None): + def tablesample( + self, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: """Return a TABLESAMPLE alias of this :class:`_expression.FromClause`. The return value is the :class:`_expression.TableSample` @@ -661,7 +753,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return TableSample._construct(self, sampling, name, seed) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is 'derived' from the given ``FromClause``. @@ -673,7 +765,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): # contained elements. return fromclause in self._cloned_set - def _is_lexical_equivalent(self, other): + def _is_lexical_equivalent(self, other: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` and the other represent the same lexical identity. @@ -681,9 +773,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): if they are the same via annotation identity. """ - return self._cloned_set.intersection(other._cloned_set) + return bool(self._cloned_set.intersection(other._cloned_set)) - @util.non_memoized_property + @util.ro_non_memoized_property def description(self) -> str: """A brief description of this :class:`_expression.FromClause`. @@ -692,13 +784,15 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c ) @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -796,7 +890,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self.foreign_keys - def _reset_column_collection(self): + def _reset_column_collection(self) -> None: """Reset the attributes linked to the ``FromClause.c`` attribute. This collection is separate from all the other memoized things @@ -817,7 +911,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def _select_iterable(self) -> _SelectIterable: return self.c - def _init_collections(self): + def _init_collections(self) -> None: assert "_columns" not in self.__dict__ assert "primary_key" not in self.__dict__ assert "foreign_keys" not in self.__dict__ @@ -827,10 +921,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self.foreign_keys = set() # type: ignore @property - def _cols_populated(self): + def _cols_populated(self) -> bool: return "_columns" in self.__dict__ - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: """Called on subclasses to establish the .c collection. Each implementation has a different way of establishing @@ -838,7 +932,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: """Given a column added to the .c collection of an underlying selectable, produce the local version of that column, assuming this selectable ultimately should proxy this column. @@ -865,15 +959,60 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ self._reset_column_collection() - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: return self.alias(name=name) + if TYPE_CHECKING: + + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[FromGrouping, Self]: + ... + class NamedFromClause(FromClause): + """A :class:`.FromClause` that has a name. + + Examples include tables, subqueries, CTEs, aliased tables. + + .. versionadded:: 2.0 + + """ + named_with_column = True name: str + @util.preload_module("sqlalchemy.sql.sqltypes") + def table_valued(self) -> TableValuedColumn[Any]: + """Return a :class:`_sql.TableValuedColumn` object for this + :class:`_expression.FromClause`. + + A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that + represents a complete row in a table. Support for this construct is + backend dependent, and is supported in various forms by backends + such as PostgreSQL, Oracle and SQL Server. + + E.g.:: + + >>> from sqlalchemy import select, column, func, table + >>> a = table("a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + SELECT row_to_json(a) AS row_to_json_1 + FROM a + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :ref:`tutorial_functions` - in the :ref:`unified_tutorial` + + """ + return TableValuedColumn(self, type_api.TABLEVALUE) + class SelectLabelStyle(Enum): """Label style constants that may be passed to @@ -992,7 +1131,7 @@ class Join(roles.DMLTableRole, FromClause): __visit_name__ = "join" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("left", InternalTraversal.dp_clauseelement), ("right", InternalTraversal.dp_clauseelement), ("onclause", InternalTraversal.dp_clauseelement), @@ -1002,7 +1141,20 @@ class Join(roles.DMLTableRole, FromClause): _is_join = True - def __init__(self, left, right, onclause=None, isouter=False, full=False): + left: FromClause + right: FromClause + onclause: Optional[ColumnElement[bool]] + isouter: bool + full: bool + + def __init__( + self, + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ): """Construct a new :class:`_expression.Join`. The usual entrypoint here is the :func:`_expression.join` @@ -1010,11 +1162,23 @@ class Join(roles.DMLTableRole, FromClause): :class:`_expression.FromClause` object. """ + + # when deannotate was removed here, callcounts went up for ORM + # compilation of eager joins, since there were more comparisons of + # annotated objects. test_orm.py -> test_fetch_results + # was therefore changed to show a more real-world use case, where the + # compilation is cached; there's no change in post-cache callcounts. + # callcounts for a single compilation in that particular test + # that includes about eight joins about 1100 extra fn calls, from + # 29200 -> 30373 + self.left = coercions.expect( - roles.FromClauseRole, left, deannotate=True + roles.FromClauseRole, + left, ) self.right = coercions.expect( - roles.FromClauseRole, right, deannotate=True + roles.FromClauseRole, + right, ).self_group() if onclause is None: @@ -1029,7 +1193,7 @@ class Join(roles.DMLTableRole, FromClause): self.isouter = isouter self.full = full - @property + @util.ro_non_memoized_property def description(self) -> str: return "Join object on %s(%d) and %s(%d)" % ( self.left.description, @@ -1038,7 +1202,7 @@ class Join(roles.DMLTableRole, FromClause): id(self.right), ) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: return ( # use hash() to ensure direct comparison to annotated works # as well @@ -1047,7 +1211,10 @@ class Join(roles.DMLTableRole, FromClause): or self.right.is_derived_from(fromclause) ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> FromGrouping: + ... return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") @@ -1055,7 +1222,7 @@ class Join(roles.DMLTableRole, FromClause): sqlutil = util.preloaded.sql_util columns = [c for c in self.left.c] + [c for c in self.right.c] - self.primary_key.extend( + self.primary_key.extend( # type: ignore sqlutil.reduce_columns( (c for c in columns if c.primary_key), self.onclause ) @@ -1063,11 +1230,13 @@ class Join(roles.DMLTableRole, FromClause): self._columns._populate_separate_keys( (col._tq_key_label, col) for col in columns ) - self.foreign_keys.update( + self.foreign_keys.update( # type: ignore itertools.chain(*[col.foreign_keys for col in columns]) ) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # see Select._copy_internals() for similar concept # here we pre-clone "left" and "right" so that we can @@ -1100,12 +1269,14 @@ class Join(roles.DMLTableRole, FromClause): self._reset_memoizations() - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(Join, self)._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _match_primaries(self, left, right): + def _match_primaries( + self, left: FromClause, right: FromClause + ) -> ColumnElement[bool]: if isinstance(left, Join): left_right = left.right else: @@ -1114,8 +1285,15 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _join_condition( - cls, a, b, a_subset=None, consider_as_foreign_keys=None - ): + cls, + a: FromClause, + b: FromClause, + *, + a_subset: Optional[FromClause] = None, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> ColumnElement[bool]: """Create a join condition between two tables or selectables. See sqlalchemy.sql.util.join_condition() for full docs. @@ -1151,7 +1329,15 @@ class Join(roles.DMLTableRole, FromClause): return and_(*crit) @classmethod - def _can_join(cls, left, right, consider_as_foreign_keys=None): + def _can_join( + cls, + left: FromClause, + right: FromClause, + *, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> bool: if isinstance(left, Join): left_right = left.right else: @@ -1169,20 +1355,31 @@ class Join(roles.DMLTableRole, FromClause): @classmethod @util.preload_module("sqlalchemy.sql.util") def _joincond_scan_left_right( - cls, a, a_subset, b, consider_as_foreign_keys - ): + cls, + a: FromClause, + a_subset: Optional[FromClause], + b: FromClause, + consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]], + ) -> collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ]: sql_util = util.preloaded.sql_util a = coercions.expect(roles.FromClauseRole, a) b = coercions.expect(roles.FromClauseRole, b) - constraints = collections.defaultdict(list) + constraints: collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ] = collections.defaultdict(list) for left in (a_subset, a): if left is None: continue for fk in sorted( - b.foreign_keys, key=lambda fk: fk.parent._creation_order + b.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1202,7 +1399,8 @@ class Join(roles.DMLTableRole, FromClause): constraints[fk.constraint].append((col, fk.parent)) if left is not b: for fk in sorted( - left.foreign_keys, key=lambda fk: fk.parent._creation_order + left.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1309,7 +1507,8 @@ class Join(roles.DMLTableRole, FromClause): @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: - return [self] + self.left._from_objects + self.right._from_objects + self_list: List[FromClause] = [self] + return self_list + self.left._from_objects + self.right._from_objects class NoInit: @@ -1327,6 +1526,14 @@ class NoInit: ) +class LateralFromClause(NamedFromClause): + """mark a FROM clause as being able to render directly as LATERAL""" + + +_SelfAliasedReturnsRows = TypeVar( + "_SelfAliasedReturnsRows", bound="AliasedReturnsRows" +) + # FromClause -> # AliasedReturnsRows # -> Alias only for FromClause @@ -1335,6 +1542,8 @@ class NoInit: # -> Lateral -> FromClause, but we accept SelectBase # w/ non-deprecated coercion # -> TableSample -> only for FromClause + + class AliasedReturnsRows(NoInit, NamedFromClause): """Base class of aliases against tables, subqueries, and other selectables.""" @@ -1343,24 +1552,21 @@ class AliasedReturnsRows(NoInit, NamedFromClause): _supports_derived_columns = False - element: ClauseElement + element: ReturnsRows - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ] @classmethod - def _construct(cls, *arg, **kw): + def _construct( + cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any + ) -> _SelfAliasedReturnsRows: obj = cls.__new__(cls) obj._init(*arg, **kw) return obj - @classmethod - def _factory(cls, returnsrows, name=None): - """Base factory method. Subclasses need to provide this.""" - raise NotImplementedError() - def _init(self, selectable, name=None): self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self @@ -1378,11 +1584,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): name = _anonymous_label.safe_construct(id(self), name or "anon") self.name = name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(AliasedReturnsRows, self)._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - @property + def _populate_column_collection(self): + self.element._generate_fromclause_column_proxies(self) + + @util.ro_non_memoized_property def description(self) -> str: name = self.name if isinstance(name, _anonymous_label): @@ -1395,15 +1604,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): """Legacy for dialects that are referring to Alias.original.""" return self.element - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if fromclause in self._cloned_set: return True return self.element.is_derived_from(fromclause) - def _populate_column_collection(self): - self.element._generate_fromclause_column_proxies(self) - - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: existing_element = self.element super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) @@ -1420,7 +1628,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return [self] -class Alias(roles.DMLTableRole, AliasedReturnsRows): +class FromClauseAlias(AliasedReturnsRows): + element: FromClause + + +class Alias(roles.DMLTableRole, FromClauseAlias): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1445,13 +1657,18 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): element: FromClause @classmethod - def _factory(cls, selectable, name=None, flat=False): + def _factory( + cls, + selectable: FromClause, + name: Optional[str] = None, + flat: bool = False, + ) -> NamedFromClause: return coercions.expect( roles.FromClauseRole, selectable, allow_select=True ).alias(name=name, flat=flat) -class TableValuedAlias(Alias): +class TableValuedAlias(LateralFromClause, Alias): """An alias against a "table valued" SQL function. This construct provides for a SQL function that returns columns @@ -1480,7 +1697,7 @@ class TableValuedAlias(Alias): _render_derived_w_types = False joins_implicitly = False - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ("_tableval_type", InternalTraversal.dp_type), @@ -1526,7 +1743,9 @@ class TableValuedAlias(Alias): return TableValuedColumn(self, self._tableval_type) - def alias(self, name=None): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> TableValuedAlias: """Return a new alias of this :class:`_sql.TableValuedAlias`. This creates a distinct FROM object that will be distinguished @@ -1547,7 +1766,7 @@ class TableValuedAlias(Alias): return tva - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_sql.TableValuedAlias` with the lateral flag set, so that it renders as LATERAL. @@ -1619,7 +1838,7 @@ class TableValuedAlias(Alias): return new_alias -class Lateral(AliasedReturnsRows): +class Lateral(FromClauseAlias, LateralFromClause): """Represent a LATERAL subquery. This object is constructed from the :func:`_expression.lateral` module @@ -1644,13 +1863,17 @@ class Lateral(AliasedReturnsRows): inherit_cache = True @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, + ) -> LateralFromClause: return coercions.expect( roles.FromClauseRole, selectable, explicit_subquery=True ).lateral(name=name) -class TableSample(AliasedReturnsRows): +class TableSample(FromClauseAlias): """Represent a TABLESAMPLE clause. This object is constructed from the :func:`_expression.tablesample` module @@ -1668,13 +1891,22 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" - _traverse_internals = AliasedReturnsRows._traverse_internals + [ - ("sampling", InternalTraversal.dp_clauseelement), - ("seed", InternalTraversal.dp_clauseelement), - ] + _traverse_internals: _TraverseInternalsType = ( + AliasedReturnsRows._traverse_internals + + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + ) @classmethod - def _factory(cls, selectable, sampling, name=None, seed=None): + def _factory( + cls, + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1721,7 +1953,7 @@ class CTE( __visit_name__ = "cte" - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), @@ -1736,7 +1968,12 @@ class CTE( element: HasCTE @classmethod - def _factory(cls, selectable, name=None, recursive=False): + def _factory( + cls, + selectable: HasCTE, + name: Optional[str] = None, + recursive: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -1775,7 +2012,9 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -1814,6 +2053,10 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union()" + return CTE._construct( self.element.union(*other), name=self.name, @@ -1839,6 +2082,11 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union_all()" + return CTE._construct( self.element.union_all(*other), name=self.name, @@ -1865,23 +2113,229 @@ class _CTEOpts(NamedTuple): nesting: bool -class HasCTE(roles.HasCTERole, ClauseElement): +class _ColumnsPlusNames(NamedTuple): + required_label_name: Optional[str] + """ + string label name, if non-None, must be rendered as a + label, i.e. "AS <name>" + """ + + proxy_key: Optional[str] + """ + proxy_key that is to be part of the result map for this + col. this is also the key in a fromclause.c or + select.selected_columns collection + """ + + fallback_label_name: Optional[str] + """ + name that can be used to render an "AS <name>" when + we have to render a label even though + required_label_name was not given + """ + + column: Union[ColumnElement[Any], TextClause] + """ + the ColumnElement itself + """ + + repeated: bool + """ + True if this is a duplicate of a previous column + in the list of columns + """ + + +class SelectsRows(ReturnsRows): + """Sub-base of ReturnsRows for elements that deliver rows + directly, namely SELECT and INSERT/UPDATE/DELETE..RETURNING""" + + _label_style: SelectLabelStyle = LABEL_STYLE_NONE + + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[_ColumnsPlusNames]: + """Generate column names as rendered in a SELECT statement by + the compiler. + + This is distinct from the _column_naming_convention generator that's + intended for population of .c collections and similar, which has + different rules. the collection returned here calls upon the + _column_naming_convention as well. + + """ + cols = self._all_selected_columns + + key_naming_convention = SelectState._column_naming_convention( + self._label_style + ) + + names = {} + + result: List[_ColumnsPlusNames] = [] + result_append = result.append + + table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + label_style_none = self._label_style is LABEL_STYLE_NONE + + # a counter used for "dedupe" labels, which have double underscores + # in them and are never referred by name; they only act + # as positional placeholders. they need only be unique within + # the single columns clause they're rendered within (required by + # some dbs such as mysql). So their anon identity is tracked against + # a fixed counter rather than hash() identity. + dedupe_hash = 1 + + for c in cols: + repeated = False + + if not c._render_label_in_columns_clause: + effective_name = ( + required_label_name + ) = fallback_label_name = None + elif label_style_none: + if TYPE_CHECKING: + assert is_column_element(c) + + effective_name = required_label_name = None + fallback_label_name = c._non_anon_label or c._anon_name_label + else: + if TYPE_CHECKING: + assert is_column_element(c) + + if table_qualified: + required_label_name = ( + effective_name + ) = fallback_label_name = c._tq_label + else: + effective_name = fallback_label_name = c._non_anon_label + required_label_name = None + + if effective_name is None: + # it seems like this could be _proxy_key and we would + # not need _expression_label but it isn't + # giving us a clue when to use anon_label instead + expr_label = c._expression_label + if expr_label is None: + repeated = c._anon_name_label in names + names[c._anon_name_label] = c + effective_name = required_label_name = None + + if repeated: + # here, "required_label_name" is sent as + # "None" and "fallback_label_name" is sent. + if table_qualified: + fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) + dedupe_hash += 1 + else: + fallback_label_name = c._dedupe_anon_label_idx( + dedupe_hash + ) + dedupe_hash += 1 + else: + fallback_label_name = c._anon_name_label + else: + required_label_name = ( + effective_name + ) = fallback_label_name = expr_label + + if effective_name is not None: + if TYPE_CHECKING: + assert is_column_element(c) + + if effective_name in names: + # when looking to see if names[name] is the same column as + # c, use hash(), so that an annotated version of the column + # is seen as the same as the non-annotated + if hash(names[effective_name]) != hash(c): + + # different column under the same name. apply + # disambiguating label + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._anon_tq_label + else: + required_label_name = ( + fallback_label_name + ) = c._anon_name_label + + if anon_for_dupe_key and required_label_name in names: + # here, c._anon_tq_label is definitely unique to + # that column identity (or annotated version), so + # this should always be true. + # this is also an infrequent codepath because + # you need two levels of duplication to be here + assert hash(names[required_label_name]) == hash(c) + + # the column under the disambiguating label is + # already present. apply the "dedupe" label to + # subsequent occurrences of the column so that the + # original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[required_label_name] = c + elif anon_for_dupe_key: + # same column under the same name. apply the "dedupe" + # label so that the original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[effective_name] = c + + result_append( + _ColumnsPlusNames( + required_label_name, + key_naming_convention(c), + fallback_label_name, + c, + repeated, + ) + ) + + return result + + +class HasCTE(roles.HasCTERole, SelectsRows): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 """ - _has_ctes_traverse_internals = [ + _has_ctes_traverse_internals: _TraverseInternalsType = [ ("_independent_ctes", InternalTraversal.dp_clauseelement_list), ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), ] - _independent_ctes = () - _independent_ctes_opts = () + _independent_ctes: Tuple[CTE, ...] = () + _independent_ctes_opts: Tuple[_CTEOpts, ...] = () @_generative - def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE: + def add_cte( + self: SelfHasCTE, *ctes: CTE, nest_here: bool = False + ) -> SelfHasCTE: r"""Add one or more :class:`_sql.CTE` constructs to this statement. This method will associate the given :class:`_sql.CTE` constructs with @@ -1985,7 +2439,12 @@ class HasCTE(roles.HasCTERole, ClauseElement): self._independent_ctes_opts += (opt,) return self - def cte(self, name=None, recursive=False, nesting=False): + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -2293,10 +2752,12 @@ class Subquery(AliasedReturnsRows): inherit_cache = True - element: Select + element: SelectBase @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, selectable: SelectBase, name: Optional[str] = None + ) -> Subquery: """Return a :class:`.Subquery` object.""" return coercions.expect( roles.SelectStatementRole, selectable @@ -2335,11 +2796,13 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] element: FromClause - def __init__(self, element): + def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) def _init_collections(self): @@ -2361,11 +2824,13 @@ class FromGrouping(GroupedElement, FromClause): def foreign_keys(self): return self.element.foreign_keys - def is_derived_from(self, element): - return self.element.is_derived_from(element) + def is_derived_from(self, fromclause: FromClause) -> bool: + return self.element.is_derived_from(fromclause) - def alias(self, **kw): - return FromGrouping(self.element.alias(**kw)) + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromGrouping: + return NamedFromGrouping(self.element.alias(name=name, flat=flat)) def _anonymous_fromclause(self, **kw): return FromGrouping(self.element._anonymous_fromclause(**kw)) @@ -2385,6 +2850,16 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] +class NamedFromGrouping(FromGrouping, NamedFromClause): + """represent a grouping of a named FROM clause + + .. versionadded:: 2.0 + + """ + + inherit_cache = True + + class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. @@ -2417,7 +2892,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): __visit_name__ = "table" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ( "columns", InternalTraversal.dp_fromclause_canonical_column_collection, @@ -2434,15 +2909,17 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): doesn't support having a primary key or column -level defaults, so implicit returning doesn't apply.""" - _autoincrement_column = None - """No PK or default support so no autoincrement column.""" + @util.ro_memoized_property + def _autoincrement_column(self) -> Optional[ColumnClause[Any]]: + """No PK or default support so no autoincrement column.""" + return None - def __init__(self, name, *columns, **kw): + def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): super(TableClause, self).__init__() self.name = name self._columns = DedupeColumnCollection() - self.primary_key = ColumnSet() - self.foreign_keys = set() + self.primary_key = ColumnSet() # type: ignore + self.foreign_keys = set() # type: ignore for c in columns: self.append_column(c) @@ -2466,23 +2943,23 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... - def __str__(self): + def __str__(self) -> str: if self.schema is not None: return self.schema + "." + self.name else: return self.name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: pass - def _init_collections(self): + def _init_collections(self) -> None: pass - @util.memoized_property + @util.ro_memoized_property def description(self) -> str: return self.name - def append_column(self, c, **kw): + def append_column(self, c: ColumnClause[Any]) -> None: existing = c.table if existing is not None and existing is not self: raise exc.ArgumentError( @@ -2494,7 +2971,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): c.table = self @util.preload_module("sqlalchemy.sql.dml") - def insert(self): + def insert(self) -> Insert: """Generate an :func:`_expression.insert` construct against this :class:`_expression.TableClause`. @@ -2505,10 +2982,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): See :func:`_expression.insert` for argument and usage information. """ + return util.preloaded.sql_dml.Insert(self) @util.preload_module("sqlalchemy.sql.dml") - def update(self): + def update(self) -> Update: """Generate an :func:`_expression.update` construct against this :class:`_expression.TableClause`. @@ -2524,7 +3002,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): ) @util.preload_module("sqlalchemy.sql.dml") - def delete(self): + def delete(self) -> Delete: """Generate a :func:`_expression.delete` construct against this :class:`_expression.TableClause`. @@ -2543,13 +3021,18 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): class ForUpdateArg(ClauseElement): - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("of", InternalTraversal.dp_clauseelement_list), ("nowait", InternalTraversal.dp_boolean), ("read", InternalTraversal.dp_boolean), ("skip_locked", InternalTraversal.dp_boolean), ] + of: Optional[Sequence[ClauseElement]] + nowait: bool + read: bool + skip_locked: bool + @classmethod def _from_argument(cls, with_for_update): if isinstance(with_for_update, ForUpdateArg): @@ -2606,7 +3089,7 @@ class ForUpdateArg(ClauseElement): SelfValues = typing.TypeVar("SelfValues", bound="Values") -class Values(Generative, NamedFromClause): +class Values(Generative, LateralFromClause): """Represent a ``VALUES`` construct that can be used as a FROM element in a statement. @@ -2619,28 +3102,42 @@ class Values(Generative, NamedFromClause): __visit_name__ = "values" - _data = () + _data: Tuple[List[Tuple[Any, ...]], ...] = () - _traverse_internals = [ + _unnamed: bool + _traverse_internals: _TraverseInternalsType = [ ("_column_args", InternalTraversal.dp_clauseelement_list), ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] - def __init__(self, *columns, name=None, literal_binds=False): + def __init__( + self, + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, + ): super(Values, self).__init__() self._column_args = columns - self.name = name + if name is None: + self._unnamed = True + self.name = _anonymous_label.safe_construct(id(self), "anon") + else: + self._unnamed = False + self.name = name self.literal_binds = literal_binds - self.named_with_column = self.name is not None + self.named_with_column = not self._unnamed @property def _column_types(self): return [col.type for col in self._column_args] @_generative - def alias(self: SelfValues, name, **kw) -> SelfValues: + def alias( + self: SelfValues, name: Optional[str] = None, flat: bool = False + ) -> SelfValues: + """Return a new :class:`_expression.Values` construct that is a copy of this one with the given name. @@ -2655,12 +3152,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.alias` """ - self.name = name - self.named_with_column = self.name is not None + non_none_name: str + + if name is None: + non_none_name = _anonymous_label.safe_construct(id(self), "anon") + else: + non_none_name = name + + self.name = non_none_name + self.named_with_column = True + self._unnamed = False return self @_generative - def lateral(self: SelfValues, name=None) -> SelfValues: + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_expression.Values` with the lateral flag set, so that it renders as LATERAL. @@ -2670,13 +3175,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.lateral` """ + non_none_name: str + + if name is None: + non_none_name = self.name + else: + non_none_name = name + self._is_lateral = True - if name is not None: - self.name = name + self.name = non_none_name + self._unnamed = False return self @_generative - def data(self: SelfValues, values) -> SelfValues: + def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues: """Return a new :class:`_expression.Values` construct, adding the given data to the data list. @@ -2694,7 +3206,7 @@ class Values(Generative, NamedFromClause): self._data += (values,) return self - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: for c in self._column_args: self._columns.add(c) c.table = self @@ -2727,32 +3239,16 @@ class SelectBase( """ - _is_select_statement = True + _is_select_base = True is_select = True - def _generate_fromclause_column_proxies( - self, fromclause: FromClause - ) -> None: - raise NotImplementedError() + _label_style: SelectLabelStyle = LABEL_STYLE_NONE def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: self._reset_memoizations() - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - raise NotImplementedError() - - def set_label_style( - self: SelfSelectBase, label_style: SelectLabelStyle - ) -> SelfSelectBase: - raise NotImplementedError() - - def get_label_style(self) -> SelectLabelStyle: - raise NotImplementedError() - - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -2797,7 +3293,7 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -2819,7 +3315,7 @@ class SelectBase( """ - return self.selected_columns + return self.selected_columns.as_readonly() @util.deprecated_property( "1.4", @@ -2841,6 +3337,26 @@ class SelectBase( def columns(self): return self.c + def get_label_style(self) -> SelectLabelStyle: + """ + Retrieve the current label style. + + Implemented by subclasses. + + """ + raise NotImplementedError() + + def set_label_style( + self: SelfSelectBase, style: SelectLabelStyle + ) -> SelfSelectBase: + """Return a new selectable with the specified label style. + + Implemented by subclasses. + + """ + + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.select` method is deprecated " @@ -2857,6 +3373,9 @@ class SelectBase( def _implicit_subquery(self): return self.subquery() + def _scalar_type(self) -> TypeEngine[Any]: + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.as_scalar` " @@ -2926,7 +3445,7 @@ class SelectBase( """ return self.scalar_subquery().label(name) - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -2941,11 +3460,7 @@ class SelectBase( """ return Lateral._factory(self, name) - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - return [self] - - def subquery(self, name=None): + def subquery(self, name: Optional[str] = None) -> Subquery: """Return a subquery of this :class:`_expression.SelectBase`. A subquery is from a SQL perspective a parenthesized, named @@ -2995,7 +3510,9 @@ class SelectBase( raise NotImplementedError() - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> Subquery: """Return a named subquery against this :class:`_expression.SelectBase`. @@ -3023,7 +3540,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "select_statement_grouping" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] _is_select_container = True @@ -3053,13 +3572,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def self_group(self, against=None): + def self_group(self: Self, against: Optional[OperatorType] = None) -> Self: + ... return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - return self.element._generate_columns_plus_names(anon_for_dupe_key) + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return self.element._generate_columns_plus_names(anon_for_dupe_key) def _generate_fromclause_column_proxies( self, subquery: FromClause @@ -3070,8 +3590,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that the embedded SELECT statement returns in its result set, not including @@ -3112,25 +3632,30 @@ class GenerativeSelect(SelectBase): """ - _order_by_clauses = () - _group_by_clauses = () - _limit_clause = None - _offset_clause = None - _fetch_clause = None - _fetch_clause_options = None - _for_update_arg = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None + _fetch_clause: Optional[ColumnElement[Any]] = None + _fetch_clause_options: Optional[Dict[str, bool]] = None + _for_update_arg: Optional[ForUpdateArg] = None - def __init__(self, _label_style=LABEL_STYLE_DEFAULT): + def __init__(self, _label_style: SelectLabelStyle = LABEL_STYLE_DEFAULT): self._label_style = _label_style @_generative def with_for_update( self: SelfGenerativeSelect, - nowait=False, - read=False, - of=None, - skip_locked=False, - key_share=False, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ) -> SelfGenerativeSelect: """Specify a ``FOR UPDATE`` clause for this :class:`_expression.GenerativeSelect`. @@ -3241,20 +3766,25 @@ class GenerativeSelect(SelectBase): return self @property - def _group_by_clause(self): + def _group_by_clause(self) -> ClauseList: """ClauseList access to group_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._group_by_clauses ) @property - def _order_by_clause(self): + def _order_by_clause(self) -> ClauseList: """ClauseList access to order_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._order_by_clauses ) - def _offset_or_limit_clause(self, element, name=None, type_=None): + def _offset_or_limit_clause( + self, + element: Union[int, _ColumnExpressionArgument[Any]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, + ) -> ColumnElement[Any]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -3265,7 +3795,21 @@ class GenerativeSelect(SelectBase): roles.LimitOffsetRole, element, name=name, type_=type_ ) - def _offset_or_limit_clause_asint(self, clause, attrname): + @overload + def _offset_or_limit_clause_asint( + self, clause: ColumnElement[Any], attrname: str + ) -> NoReturn: + ... + + @overload + def _offset_or_limit_clause_asint( + self, clause: Optional[_OffsetLimitParam], attrname: str + ) -> Optional[int]: + ... + + def _offset_or_limit_clause_asint( + self, clause: Optional[ColumnElement[Any]], attrname: str + ) -> Union[NoReturn, Optional[int]]: """Convert the "offset or limit" clause of a select construct to an integer. @@ -3286,7 +3830,7 @@ class GenerativeSelect(SelectBase): return util.asint(value) @property - def _limit(self): + def _limit(self) -> Optional[int]: """Get an integer value for the limit. This should only be used by code that cannot support a limit as a BindParameter or other custom clause as it will throw an exception if the limit @@ -3295,14 +3839,14 @@ class GenerativeSelect(SelectBase): """ return self._offset_or_limit_clause_asint(self._limit_clause, "limit") - def _simple_int_clause(self, clause): + def _simple_int_clause(self, clause: ClauseElement) -> bool: """True if the clause is a simple integer, False if it is not present or is a SQL expression. """ return isinstance(clause, _OffsetLimitParam) @property - def _offset(self): + def _offset(self) -> Optional[int]: """Get an integer value for the offset. This should only be used by code that cannot support an offset as a BindParameter or other custom clause as it will throw an exception if the @@ -3314,7 +3858,7 @@ class GenerativeSelect(SelectBase): ) @property - def _has_row_limiting_clause(self): + def _has_row_limiting_clause(self) -> bool: return ( self._limit_clause is not None or self._offset_clause is not None @@ -3322,7 +3866,10 @@ class GenerativeSelect(SelectBase): ) @_generative - def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect: + def limit( + self: SelfGenerativeSelect, + limit: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given LIMIT criterion applied. @@ -3356,7 +3903,10 @@ class GenerativeSelect(SelectBase): @_generative def fetch( - self: SelfGenerativeSelect, count, with_ties=False, percent=False + self: SelfGenerativeSelect, + count: Union[int, _ColumnExpressionArgument[int]], + with_ties: bool = False, + percent: bool = False, ) -> SelfGenerativeSelect: """Return a new selectable with the given FETCH FIRST criterion applied. @@ -3408,7 +3958,10 @@ class GenerativeSelect(SelectBase): return self @_generative - def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect: + def offset( + self: SelfGenerativeSelect, + offset: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given OFFSET criterion applied. @@ -3438,7 +3991,11 @@ class GenerativeSelect(SelectBase): @_generative @util.preload_module("sqlalchemy.sql.util") - def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect: + def slice( + self: SelfGenerativeSelect, + start: int, + stop: int, + ) -> SelfGenerativeSelect: """Apply LIMIT / OFFSET to this statement based on a slice. The start and stop indices behave like the argument to Python's @@ -3485,7 +4042,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def order_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of ORDER BY criteria applied. @@ -3522,7 +4081,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def group_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of GROUP BY criterion applied. @@ -3567,6 +4128,15 @@ class CompoundSelectState(CompileState): return d, d, d +class _CompoundSelectKeyword(Enum): + UNION = "UNION" + UNION_ALL = "UNION ALL" + EXCEPT = "EXCEPT" + EXCEPT_ALL = "EXCEPT ALL" + INTERSECT = "INTERSECT" + INTERSECT_ALL = "INTERSECT ALL" + + class CompoundSelect(HasCompileState, GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -3590,7 +4160,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect): __visit_name__ = "compound_select" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("selects", InternalTraversal.dp_clauseelement_list), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), @@ -3602,17 +4172,16 @@ class CompoundSelect(HasCompileState, GenerativeSelect): ("keyword", InternalTraversal.dp_string), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals - UNION = util.symbol("UNION") - UNION_ALL = util.symbol("UNION ALL") - EXCEPT = util.symbol("EXCEPT") - EXCEPT_ALL = util.symbol("EXCEPT ALL") - INTERSECT = util.symbol("INTERSECT") - INTERSECT_ALL = util.symbol("INTERSECT ALL") + selects: List[SelectBase] _is_from_container = True _auto_correlate = False - def __init__(self, keyword, *selects): + def __init__( + self, + keyword: _CompoundSelectKeyword, + *selects: _SelectStatementForCompoundArgument, + ): self.keyword = keyword self.selects = [ coercions.expect(roles.CompoundElementRole, s).self_group( @@ -3624,36 +4193,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect): GenerativeSelect.__init__(self) @classmethod - def _create_union(cls, *selects, **kwargs): - return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + def _create_union( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod - def _create_union_all(cls, *selects): - return CompoundSelect(CompoundSelect.UNION_ALL, *selects) + def _create_union_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod - def _create_except(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT, *selects) + def _create_except( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod - def _create_except_all(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects) + def _create_except_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod - def _create_intersect(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT, *selects) + def _create_intersect( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod - def _create_intersect_all(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects) + def _create_intersect_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: return self.selects[0]._scalar_type() - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> GroupedElement: return SelectStatementGrouping(self) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: for s in self.selects: if s.is_derived_from(fromclause): return True @@ -3675,7 +4258,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: # this is a slightly hacky thing - the union exports a # column that resembles just that of the *first* selectable. @@ -3716,8 +4301,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -3739,6 +4324,11 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0].selected_columns +# backwards compat +for elem in _CompoundSelectKeyword: + setattr(CompoundSelect, elem.name, elem) + + @CompileState.plugin_for("default", "select") class SelectState(util.MemoizedSlots, CompileState): __slots__ = ( @@ -3758,10 +4348,12 @@ class SelectState(util.MemoizedSlots, CompileState): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Select) -> SelectState: + def get_plugin_class(cls, statement: Executable) -> Type[SelectState]: ... - def __init__(self, statement, compiler, **kw): + def __init__( + self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + ): self.statement = statement self.from_clauses = statement._from_obj @@ -3778,14 +4370,16 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) @classmethod - def _plugin_not_implemented(cls): + def _plugin_not_implemented(cls) -> NoReturn: raise NotImplementedError( "The default SELECT construct without plugins does not " "implement this method." ) @classmethod - def get_column_descriptions(cls, statement): + def get_column_descriptions( + cls, statement: Select + ) -> List[Dict[str, Any]]: return [ { "name": name, @@ -3798,11 +4392,13 @@ class SelectState(util.MemoizedSlots, CompileState): ] @classmethod - def from_statement(cls, statement, from_statement): + def from_statement( + cls, statement: Select, from_statement: ReturnsRows + ) -> Any: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement): + def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -3810,7 +4406,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @classmethod - def _column_naming_convention(cls, label_style): + def _column_naming_convention( + cls, label_style: SelectLabelStyle + ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL dedupe = label_style is not LABEL_STYLE_NONE @@ -3850,7 +4448,8 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement): + def _get_froms(self, statement: Select) -> List[FromClause]: + ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} return self._normalize_froms( @@ -3876,10 +4475,10 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def _normalize_froms( cls, - iterable_of_froms, - check_statement=None, - ambiguous_table_name_map=None, - ): + iterable_of_froms: Iterable[FromClause], + check_statement: Optional[Select] = None, + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what would actually render in the FROM clause of a SELECT. @@ -3888,12 +4487,12 @@ class SelectState(util.MemoizedSlots, CompileState): etc. """ - seen = set() - froms = [] + seen: Set[FromClause] = set() + froms: List[FromClause] = [] for item in iterable_of_froms: - if item._is_subquery and item.element is check_statement: + if is_subquery(item) and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) @@ -3923,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState): ) for item in froms for fr in item._from_objects - if fr._is_table + if is_table(fr) and fr.schema and fr.name not in ambiguous_table_name_map ) @@ -3931,8 +4530,10 @@ class SelectState(util.MemoizedSlots, CompileState): return froms def _get_display_froms( - self, explicit_correlate_froms=None, implicit_correlate_froms=None - ): + self, + explicit_correlate_froms: Optional[Sequence[FromClause]] = None, + implicit_correlate_froms: Optional[Sequence[FromClause]] = None, + ) -> List[FromClause]: """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -3998,25 +4599,33 @@ class SelectState(util.MemoizedSlots, CompileState): return froms - def _memoized_attr__label_resolve_dict(self): - with_cols = dict( - (c._tq_label or c.key, c) + def _memoized_attr__label_resolve_dict( + self, + ) -> Tuple[ + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + ]: + with_cols: Dict[str, ColumnElement[Any]] = dict( + (c._tq_label or c.key, c) # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve ) - only_froms = dict( - (c.key, c) + only_froms: Dict[str, ColumnElement[Any]] = dict( + (c.key, c) # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve ) - only_cols = with_cols.copy() + only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) return with_cols, only_froms, only_cols @classmethod - def determine_last_joined_entity(cls, stmt): + def determine_last_joined_entity( + cls, stmt: Select + ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] else: @@ -4026,8 +4635,16 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement: Select) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args, raw_columns): + def _setup_joins( + self, + args: Tuple[_SetupJoinsElement, ...], + raw_columns: List[_ColumnsClauseElement], + ) -> None: for (right, onclause, left, flags) in args: + if TYPE_CHECKING: + if onclause is not None: + assert isinstance(onclause, ColumnElement) + isouter = flags["isouter"] full = flags["full"] @@ -4043,6 +4660,16 @@ class SelectState(util.MemoizedSlots, CompileState): left ) + # these assertions can be made here, as if the right/onclause + # contained ORM elements, the select() statement would have been + # upgraded to an ORM select, and this method would not be called; + # orm.context.ORMSelectCompileState._join() would be + # used instead. + if TYPE_CHECKING: + assert isinstance(right, FromClause) + if onclause is not None: + assert isinstance(onclause, ColumnElement) + if replace_from_obj_index is not None: # splice into an existing element in the # self._from_obj list @@ -4062,15 +4689,19 @@ class SelectState(util.MemoizedSlots, CompileState): + self.from_clauses[replace_from_obj_index + 1 :] ) else: - + assert left is not None self.from_clauses = self.from_clauses + ( Join(left, right, onclause, isouter=isouter, full=full), ) @util.preload_module("sqlalchemy.sql.util") def _join_determine_implicit_left_side( - self, raw_columns, left, right, onclause - ): + self, + raw_columns: List[_ColumnsClauseElement], + left: Optional[FromClause], + right: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], + ) -> Tuple[Optional[FromClause], Optional[int]]: """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4079,13 +4710,13 @@ class SelectState(util.MemoizedSlots, CompileState): sql_util = util.preloaded.sql_util - replace_from_obj_index = None + replace_from_obj_index: Optional[int] = None from_clauses = self.from_clauses if from_clauses: - indexes = sql_util.find_left_clause_to_join_from( + indexes: List[int] = sql_util.find_left_clause_to_join_from( from_clauses, right, onclause ) @@ -4138,15 +4769,17 @@ class SelectState(util.MemoizedSlots, CompileState): return left, replace_from_obj_index @util.preload_module("sqlalchemy.sql.util") - def _join_place_explicit_left_side(self, left): - replace_from_obj_index = None + def _join_place_explicit_left_side( + self, left: FromClause + ) -> Optional[int]: + replace_from_obj_index: Optional[int] = None sql_util = util.preloaded.sql_util from_clauses = list(self.statement._iterate_from_elements()) if from_clauses: - indexes = sql_util.find_left_clause_that_matches_given( + indexes: List[int] = sql_util.find_left_clause_that_matches_given( self.from_clauses, left ) else: @@ -4171,7 +4804,13 @@ class SelectState(util.MemoizedSlots, CompileState): class _SelectFromElements: - def _iterate_from_elements(self): + __slots__ = () + + _raw_columns: List[_ColumnsClauseElement] + _where_criteria: Tuple[ColumnElement[Any], ...] + _from_obj: Tuple[FromClause, ...] + + def _iterate_from_elements(self) -> Iterator[FromClause]: # note this does not include elements # in _setup_joins @@ -4195,28 +4834,58 @@ class _SelectFromElements: yield element +Self_MemoizedSelectEntities = TypeVar("Self_MemoizedSelectEntities", bound=Any) + + class _MemoizedSelectEntities( cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): + """represents partial state from a Select object, for the case + where Select.columns() has redefined the set of columns/entities the + statement will be SELECTing from. This object represents + the entities from the SELECT before that transformation was applied, + so that transformations that were made in terms of the SELECT at that + time, such as join() as well as options(), can access the correct context. + + In previous SQLAlchemy versions, this wasn't needed because these + constructs calculated everything up front, like when you called join() + or options(), it did everything to figure out how that would translate + into specific SQL constructs that would be ready to send directly to the + SQL compiler when needed. But as of + 1.4, all of that stuff is done in the compilation phase, during the + "compile state" portion of the process, so that the work can all be + cached. So it needs to be able to resolve joins/options2 based on what + the list of entities was when those methods were called. + + + """ + __visit_name__ = "memoized_select_entities" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("_setup_joins", InternalTraversal.dp_setup_join_tuple), ("_with_options", InternalTraversal.dp_executable_options), ] + _is_clone_of: Optional[ClauseElement] + _raw_columns: List[_ColumnsClauseElement] + _setup_joins: Tuple[_SetupJoinsElement, ...] + _with_options: Tuple[ExecutableOption, ...] + _annotations = util.EMPTY_DICT - def _clone(self, **kw): + def _clone( + self: Self_MemoizedSelectEntities, **kw: Any + ) -> Self_MemoizedSelectEntities: c = self.__class__.__new__(self.__class__) c.__dict__ = {k: v for k, v in self.__dict__.items()} c._is_clone_of = self.__dict__.get("_is_clone_of", self) - return c + return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt): + def _generate_for_statement(cls, select_stmt: Select) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4224,12 +4893,10 @@ class _MemoizedSelectEntities( self._with_options = select_stmt._with_options select_stmt._memoized_select_entities += (self,) - select_stmt._raw_columns = ( - select_stmt._setup_joins - ) = select_stmt._with_options = () + select_stmt._raw_columns = [] + select_stmt._setup_joins = select_stmt._with_options = () -# TODO: use pep-673 when feasible SelfSelect = typing.TypeVar("SelfSelect", bound="Select") @@ -4258,9 +4925,11 @@ class Select( __visit_name__ = "select" - _setup_joins: Tuple[TODO_Any, ...] = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () _memoized_select_entities: Tuple[TODO_Any, ...] = () + _raw_columns: List[_ColumnsClauseElement] + _distinct = False _distinct_on: Tuple[ColumnElement[Any], ...] = () _correlate: Tuple[FromClause, ...] = () @@ -4269,12 +4938,12 @@ class Select( _having_criteria: Tuple[ColumnElement[Any], ...] = () _from_obj: Tuple[FromClause, ...] = () _auto_correlate = True - + _is_select_statement = True _compile_options: CacheableOptions = ( SelectState.default_select_compile_options ) - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ( @@ -4306,12 +4975,14 @@ class Select( + Executable._executable_traverse_internals ) - _cache_key_traversal = _traverse_internals + [ + _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [ ("_compile_options", InternalTraversal.dp_has_cache_key) ] + _compile_state_factory: Type[SelectState] + @classmethod - def _create_raw_select(cls, **kw) -> "Select": + def _create_raw_select(cls, **kw: Any) -> Select: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4330,6 +5001,12 @@ class Select( :func:`_sql.select` function. """ + things = [ + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) + for ent in entities + ] self._raw_columns = [ coercions.expect( @@ -4340,7 +5017,7 @@ class Select( GenerativeSelect.__init__(self) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: elem = self._raw_columns[0] cols = list(elem._select_iterable) return cols[0].type @@ -4446,7 +5123,12 @@ class Select( @_generative def join( - self: SelfSelect, target, onclause=None, *, isouter=False, full=False + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4505,17 +5187,32 @@ class Select( :meth:`_expression.Select.outerjoin` """ # noqa: E501 - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None + self._setup_joins += ( - (target, onclause, None, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + None, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin_from(self, from_, target, onclause=None, *, full=False): + def outerjoin_from( + self: SelfSelect, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select` object's criterion and apply generatively, returning the newly resulting @@ -4531,12 +5228,12 @@ class Select( @_generative def join_from( self: SelfSelect, - from_, - target, - onclause=None, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, *, - isouter=False, - full=False, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4586,18 +5283,31 @@ class Select( from_ = coercions.expect( roles.FromClauseRole, from_, apply_propagate_attrs=self ) - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None self._setup_joins += ( - (target, onclause, from_, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + from_, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin(self, target, onclause=None, *, full=False): + def outerjoin( + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: """Create a left outer join. Parameters are the same as that of :meth:`_expression.Select.join`. @@ -4634,7 +5344,7 @@ class Select( """ return self.join(target, onclause=onclause, isouter=True, full=full) - def get_final_froms(self): + def get_final_froms(self) -> Sequence[FromClause]: """Compute the final displayed list of :class:`_expression.FromClause` elements. @@ -4671,6 +5381,7 @@ class Select( :attr:`_sql.Select.columns_clause_froms` """ + return self._compile_state_factory(self, None)._get_display_froms() @util.deprecated_property( @@ -4678,7 +5389,7 @@ class Select( "The :attr:`_expression.Select.froms` attribute is moved to " "the :meth:`_expression.Select.get_final_froms` method.", ) - def froms(self): + def froms(self) -> Sequence[FromClause]: """Return the displayed list of :class:`_expression.FromClause` elements. @@ -4687,7 +5398,7 @@ class Select( return self.get_final_froms() @property - def columns_clause_froms(self): + def columns_clause_froms(self) -> List[FromClause]: """Return the set of :class:`_expression.FromClause` objects implied by the columns clause of this SELECT statement. @@ -4720,7 +5431,7 @@ class Select( return iter(self._all_selected_columns) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if self in fromclause._cloned_set: return True @@ -4729,7 +5440,9 @@ class Select( return True return False - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -4786,13 +5499,15 @@ class Select( def get_children(self, **kwargs): return itertools.chain( super(Select, self).get_children( - omit_attrs=["_from_obj", "_correlate", "_correlate_except"] + omit_attrs=("_from_obj", "_correlate", "_correlate_except") ), self._iterate_from_elements(), ) @_generative - def add_columns(self: SelfSelect, *columns) -> SelfSelect: + def add_columns( + self: SelfSelect, *columns: _ColumnsClauseArgument + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -4816,7 +5531,9 @@ class Select( ] return self - def _set_entities(self, entities): + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument] + ) -> None: self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, ent, apply_propagate_attrs=self @@ -4830,7 +5547,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self, column): + def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -4847,7 +5564,9 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns(self, only_synonyms=True): + def reduce_columns( + self: SelfSelect, only_synonyms: bool = True + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -4880,7 +5599,9 @@ class Select( @_generative def with_only_columns( - self: SelfSelect, *columns, maintain_column_froms=False + self: SelfSelect, + *columns: _ColumnsClauseArgument, + maintain_column_froms: bool = False, ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -4941,7 +5662,9 @@ class Select( self._assert_no_memoizations() if maintain_column_froms: - self.select_from.non_generative(self, *self.columns_clause_froms) + self.select_from.non_generative( # type: ignore + self, *self.columns_clause_froms + ) # then memoize the FROMs etc. _MemoizedSelectEntities._generate_for_statement(self) @@ -4974,7 +5697,9 @@ class Select( _whereclause = whereclause @_generative - def where(self: SelfSelect, *whereclause) -> SelfSelect: + def where( + self: SelfSelect, *whereclause: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -4984,24 +5709,33 @@ class Select( assert isinstance(self._where_criteria, tuple) for criterion in whereclause: - where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) self._where_criteria += (where_criteria,) return self @_generative - def having(self: SelfSelect, having) -> SelfSelect: + def having( + self: SelfSelect, *having: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its HAVING clause, joined to the existing clause via AND, if any. """ - self._having_criteria += ( - coercions.expect(roles.WhereHavingRole, having), - ) + + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) return self @_generative - def distinct(self: SelfSelect, *expr) -> SelfSelect: + def distinct( + self: SelfSelect, *expr: _ColumnExpressionArgument[Any] + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct which will apply DISTINCT to its columns clause. @@ -5023,7 +5757,9 @@ class Select( return self @_generative - def select_from(self: SelfSelect, *froms) -> SelfSelect: + def select_from( + self: SelfSelect, *froms: _FromClauseArgument + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with the given FROM expression(s) merged into its list of FROM objects. @@ -5067,7 +5803,10 @@ class Select( return self @_generative - def correlate(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5106,10 +5845,10 @@ class Select( none of its FROM entries, and all will render unconditionally in the local FROM clause. - :param \*fromclauses: a list of one or more - :class:`_expression.FromClause` - constructs, or other compatible constructs (i.e. ORM-mapped - classes) to become part of the correlate collection. + :param \*fromclauses: one or more :class:`.FromClause` or other + FROM-compatible construct such as an ORM mapped entity to become part + of the correlate collection; alternatively pass a single value + ``None`` to remove all existing correlations. .. seealso:: @@ -5119,8 +5858,16 @@ class Select( """ + # tests failing when we try to change how these + # arguments are passed + self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate()" + ) self._correlate = () else: self._correlate = self._correlate + tuple( @@ -5129,7 +5876,10 @@ class Select( return self @_generative - def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate_except( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will omit the given FROM clauses from the auto-correlation process. @@ -5141,9 +5891,9 @@ class Select( all other FROM elements remain subject to normal auto-correlation behaviors. - If ``None`` is passed, the :class:`_expression.Select` - object will correlate - all of its FROM entries. + If ``None`` is passed, or no arguments are passed, + the :class:`_expression.Select` object will correlate all of its + FROM entries. :param \*fromclauses: a list of one or more :class:`_expression.FromClause` @@ -5159,16 +5909,22 @@ class Select( """ self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate_except()" + ) self._correlate_except = () else: self._correlate_except = (self._correlate_except or ()) + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) + return self - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5214,18 +5970,22 @@ class Select( # generates the actual names used in the SELECT string. that # method is more complex because it also renders columns that are # fully ambiguous, e.g. same column more than once. - conv = SelectState._column_naming_convention(self._label_style) + conv = cast( + "Callable[[Any], str]", + SelectState._column_naming_convention(self._label_style), + ) - return ColumnCollection( + cc: ColumnCollection[str, ColumnElement[Any]] = ColumnCollection( [ (conv(c), c) for c in self._all_selected_columns if is_column_element(c) ] - ).as_readonly() + ) + return cc.as_readonly() @HasMemoized.memoized_attribute - def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]: + def _all_selected_columns(self) -> _SelectIterable: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -5234,173 +5994,9 @@ class Select( self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - """Generate column names as rendered in a SELECT statement by - the compiler. - - This is distinct from the _column_naming_convention generator that's - intended for population of .c collections and similar, which has - different rules. the collection returned here calls upon the - _column_naming_convention as well. - - """ - cols = self._all_selected_columns - - key_naming_convention = SelectState._column_naming_convention( - self._label_style - ) - - names = {} - - result = [] - result_append = result.append - - table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL - label_style_none = self._label_style is LABEL_STYLE_NONE - - # a counter used for "dedupe" labels, which have double underscores - # in them and are never referred by name; they only act - # as positional placeholders. they need only be unique within - # the single columns clause they're rendered within (required by - # some dbs such as mysql). So their anon identity is tracked against - # a fixed counter rather than hash() identity. - dedupe_hash = 1 - - for c in cols: - repeated = False - - if not c._render_label_in_columns_clause: - effective_name = ( - required_label_name - ) = fallback_label_name = None - elif label_style_none: - effective_name = required_label_name = None - fallback_label_name = c._non_anon_label or c._anon_name_label - else: - if table_qualified: - required_label_name = ( - effective_name - ) = fallback_label_name = c._tq_label - else: - effective_name = fallback_label_name = c._non_anon_label - required_label_name = None - - if effective_name is None: - # it seems like this could be _proxy_key and we would - # not need _expression_label but it isn't - # giving us a clue when to use anon_label instead - expr_label = c._expression_label - if expr_label is None: - repeated = c._anon_name_label in names - names[c._anon_name_label] = c - effective_name = required_label_name = None - - if repeated: - # here, "required_label_name" is sent as - # "None" and "fallback_label_name" is sent. - if table_qualified: - fallback_label_name = ( - c._dedupe_anon_tq_label_idx(dedupe_hash) - ) - dedupe_hash += 1 - else: - fallback_label_name = c._dedupe_anon_label_idx( - dedupe_hash - ) - dedupe_hash += 1 - else: - fallback_label_name = c._anon_name_label - else: - required_label_name = ( - effective_name - ) = fallback_label_name = expr_label - - if effective_name is not None: - if effective_name in names: - # when looking to see if names[name] is the same column as - # c, use hash(), so that an annotated version of the column - # is seen as the same as the non-annotated - if hash(names[effective_name]) != hash(c): - - # different column under the same name. apply - # disambiguating label - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._anon_tq_label - else: - required_label_name = ( - fallback_label_name - ) = c._anon_name_label - - if anon_for_dupe_key and required_label_name in names: - # here, c._anon_tq_label is definitely unique to - # that column identity (or annotated version), so - # this should always be true. - # this is also an infrequent codepath because - # you need two levels of duplication to be here - assert hash(names[required_label_name]) == hash(c) - - # the column under the disambiguating label is - # already present. apply the "dedupe" label to - # subsequent occurrences of the column so that the - # original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[required_label_name] = c - elif anon_for_dupe_key: - # same column under the same name. apply the "dedupe" - # label so that the original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[effective_name] = c - - result_append( - ( - # string label name, if non-None, must be rendered as a - # label, i.e. "AS <name>" - required_label_name, - # proxy_key that is to be part of the result map for this - # col. this is also the key in a fromclause.c or - # select.selected_columns collection - key_naming_convention(c), - # name that can be used to render an "AS <name>" when - # we have to render a label even though - # required_label_name was not given - fallback_label_name, - # the ColumnElement itself - c, - # True if this is a duplicate of a previous column - # in the list of columns - repeated, - ) - ) - - return result - - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: """Generate column proxies to place in the exported ``.c`` collection of a subquery.""" @@ -5418,7 +6014,7 @@ class Select( c, repeated, ) in (self._generate_columns_plus_names(False)) - if not c._is_text_clause + if is_column_element(c) ] subquery._columns._populate_separate_keys(prox) @@ -5428,7 +6024,10 @@ class Select( self._order_by_clause.clauses ) - def self_group(self, against=None): + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[SelectStatementGrouping, Self]: + ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -5445,7 +6044,9 @@ class Select( else: return SelectStatementGrouping(self) - def union(self, *other, **kwargs): + def union( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -5460,9 +6061,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union(self, *other, **kwargs) + return CompoundSelect._create_union(self, *other) - def union_all(self, *other, **kwargs): + def union_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5477,9 +6080,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union_all(self, *other, **kwargs) + return CompoundSelect._create_union_all(self, *other) - def except_(self, *other, **kwargs): + def except_( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -5490,13 +6095,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except(self, *other, **kwargs) + return CompoundSelect._create_except(self, *other) - def except_all(self, *other, **kwargs): + def except_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5507,13 +6111,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except_all(self, *other, **kwargs) + return CompoundSelect._create_except_all(self, *other) - def intersect(self, *other, **kwargs): + def intersect( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -5528,9 +6131,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect(self, *other, **kwargs) + return CompoundSelect._create_intersect(self, *other) - def intersect_all(self, *other, **kwargs): + def intersect_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5545,13 +6150,17 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect_all(self, *other, **kwargs) + return CompoundSelect._create_intersect_all(self, *other) -SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect") +SelfScalarSelect = typing.TypeVar( + "SelfScalarSelect", bound="ScalarSelect[Any]" +) -class ScalarSelect(roles.InElementRole, Generative, Grouping): +class ScalarSelect( + roles.InElementRole, Generative, GroupedElement, ColumnElement[_T] +): """Represent a scalar subquery. @@ -5570,15 +6179,33 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - _from_objects = [] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _from_objects: List[FromClause] = [] _is_from_container = True - _is_implicitly_boolean = False + if not TYPE_CHECKING: + _is_implicitly_boolean = False inherit_cache = True + element: SelectBase + def __init__(self, element): self.element = element self.type = element._scalar_type() + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {"element": self.element, "type": self.type} + + def __setstate__(self, state): + self.element = state["element"] + self.type = state["type"] + @property def columns(self): raise exc.InvalidRequestError( @@ -5590,19 +6217,39 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): c = columns @_generative - def where(self: SelfScalarSelect, crit) -> SelfScalarSelect: + def where( + self: SelfScalarSelect, crit: _ColumnExpressionArgument[bool] + ) -> SelfScalarSelect: """Apply a WHERE clause to the SELECT statement referred to by this :class:`_expression.ScalarSelect`. """ - self.element = self.element.where(crit) + self.element = cast(Select, self.element).where(crit) return self - def self_group(self, **kwargs): + @overload + def self_group( + self: ScalarSelect[Any], against: Optional[OperatorType] = None + ) -> ScalarSelect[Any]: + ... + + @overload + def self_group( + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + return self @_generative - def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect: + def correlate( + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5631,12 +6278,13 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate(*fromclauses) + self.element = cast(Select, self.element).correlate(*fromclauses) return self @_generative def correlate_except( - self: SelfScalarSelect, *fromclauses + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will omit the given FROM @@ -5668,11 +6316,16 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate_except(*fromclauses) + self.element = cast(Select, self.element).correlate_except( + *fromclauses + ) return self -class Exists(UnaryExpression[_T]): +SelfExists = TypeVar("SelfExists", bound="Exists") + + +class Exists(UnaryExpression[bool]): """Represent an ``EXISTS`` clause. See :func:`_sql.exists` for a description of usage. @@ -5682,10 +6335,14 @@ class Exists(UnaryExpression[_T]): """ - _from_objects = () inherit_cache = True - def __init__(self, __argument=None): + def __init__( + self, + __argument: Optional[ + Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + ] = None, + ): if __argument is None: s = Select(literal_column("*")).scalar_subquery() elif isinstance(__argument, (SelectBase, ScalarSelect)): @@ -5701,12 +6358,16 @@ class Exists(UnaryExpression[_T]): wraps_column_expression=True, ) + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _regroup(self, fn): element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -5726,7 +6387,10 @@ class Exists(UnaryExpression[_T]): return Select(self) - def correlate(self, *fromclause): + def correlate( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5736,11 +6400,14 @@ class Exists(UnaryExpression[_T]): """ e = self._clone() e.element = self._regroup( - lambda element: element.correlate(*fromclause) + lambda element: element.correlate(*fromclauses) ) return e - def correlate_except(self, *fromclause): + def correlate_except( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5751,11 +6418,11 @@ class Exists(UnaryExpression[_T]): e = self._clone() e.element = self._regroup( - lambda element: element.correlate_except(*fromclause) + lambda element: element.correlate_except(*fromclauses) ) return e - def select_from(self, *froms): + def select_from(self: SelfExists, *froms: FromClause) -> SelfExists: """Return a new :class:`_expression.Exists` construct, applying the given expression to the :meth:`_expression.Select.select_from` @@ -5772,7 +6439,9 @@ class Exists(UnaryExpression[_T]): e.element = self._regroup(lambda element: element.select_from(*froms)) return e - def where(self, *clause): + def where( + self: SelfExists, *clause: _ColumnExpressionArgument[bool] + ) -> SelfExists: """Return a new :func:`_expression.exists` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -5824,7 +6493,7 @@ class TextualSelect(SelectBase): _label_style = LABEL_STYLE_NONE - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals @@ -5842,8 +6511,8 @@ class TextualSelect(SelectBase): ] self.positional = positional - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5868,6 +6537,13 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_readonly() + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return Select._generate_columns_plus_names( + # self, anon_for_dupe_key=anon_for_dupe_key + # ) + @util.non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.column_args @@ -5880,7 +6556,9 @@ class TextualSelect(SelectBase): @_generative def bindparams( - self: SelfTextualSelect, *binds, **bind_as_values + self: SelfTextualSelect, + *binds: BindParameter[Any], + **bind_as_values: Any, ) -> SelfTextualSelect: self.element = self.element.bindparams(*binds, **bind_as_values) return self |