diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-13 13:37:11 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-15 21:38:29 -0400 |
commit | 6acf5d2fca4a988a77481b82662174e8015a6b37 (patch) | |
tree | 73e2868a51b8b7ac46d7b3b7f9562c1d011f6e1b /lib/sqlalchemy/sql/elements.py | |
parent | 35f82173e04b3209e07fcfc0606a7614108d018e (diff) | |
download | sqlalchemy-6acf5d2fca4a988a77481b82662174e8015a6b37.tar.gz |
pep-484 - SQL column operations
note we are taking out the
ColumnOperartors[SQLCoreOperations] thing; not really clear
why that was needed and at the moment it seems I was likely
confused.
Change-Id: I834b75f9b44f91b97e29f2e1a7b1029bd910e0a1
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 1284 |
1 files changed, 793 insertions, 491 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 08d632afd..fdb3fc8bb 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -12,20 +12,28 @@ from __future__ import annotations +from decimal import Decimal +from enum import IntEnum import itertools import operator import re import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence -from typing import Text as typing_Text +from typing import Set +from typing import Tuple as typing_Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -34,10 +42,15 @@ from . import operators from . import roles from . import traversals from . import type_api +from ._typing import has_schema_attr +from ._typing import is_named_from_clause +from ._typing import is_quoted_name +from ._typing import is_tuple_type from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative +from .base import _NoArg from .base import Executable from .base import HasMemoized from .base import Immutable @@ -57,30 +70,47 @@ from .. import exc from .. import inspection from .. import util from ..util.langhelpers import TypingOnly +from ..util.typing import Literal if typing.TYPE_CHECKING: - from decimal import Decimal - + from ._typing import _ColumnExpression + from ._typing import _PropagateAttrsType + from ._typing import _TypeEngineArgument + from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler + from .functions import FunctionElement from .operators import OperatorType + from .schema import Column + from .schema import DefaultGenerator + from .schema import ForeignKey from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows from .selectable import Select - from .sqltypes import Boolean # noqa + from .selectable import TableClause + from .sqltypes import Boolean + from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect from ..engine import Engine from ..engine.base import _CompiledCacheType - from ..engine.base import _SchemaTranslateMapType - + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _SchemaTranslateMapType + from ..engine.interfaces import CacheStats + from ..engine.result import Result -_NUMERIC = Union[complex, "Decimal"] +_NUMERIC = Union[complex, Decimal] +_NUMBER = Union[complex, int, Decimal] _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") -_ST = TypeVar("_ST", bound="typing_Text") + +_NMT = TypeVar("_NMT", bound="_NUMBER") def literal(value, type_=None): @@ -210,28 +240,27 @@ class CompilerElement(Visitable): """ - if not dialect: + if dialect is None: if bind: dialect = bind.dialect + elif self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() else: - if self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() - else: - url = util.preloaded.engine_url - dialect = url.URL.create( - self.stringify_dialect - ).get_dialect()() + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() return self._compiler(dialect, **kw) - def _compiler(self, dialect, **kw): + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" return dialect.statement_compiler(dialect, self, **kw) - def __str__(self): + def __str__(self) -> str: return str(self.compile()) @@ -253,16 +282,17 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs = util.immutabledict() + _propagate_attrs: _PropagateAttrsType = util.immutabledict() """like annotations, however these propagate outwards liberally as SQL constructs are built, and are set up at construction time. """ - _from_objects = [] - bind = None - description = None - _is_clone_of = None + @util.memoized_property + def description(self) -> Optional[str]: + return None + + _is_clone_of: Optional[ClauseElement] = None is_clause_element = True is_selectable = False @@ -281,10 +311,25 @@ class ClauseElement( _is_singleton_constant = False _is_immutable = False - _order_by_label_element = None + @property + def _order_by_label_element(self) -> Optional[Label[Any]]: + return None _cache_key_traversal = None + negation_clause: ClauseElement + + if typing.TYPE_CHECKING: + + def get_children( + self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + ) -> Iterable[ClauseElement]: + ... + + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _set_propagate_attrs(self, values): # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a @@ -295,7 +340,7 @@ class ClauseElement( self._propagate_attrs = util.immutabledict(values) return self - def _clone(self: SelfClauseElement, **kw) -> SelfClauseElement: + def _clone(self: SelfClauseElement, **kw: Any) -> SelfClauseElement: """Create a shallow copy of this ClauseElement. This method may be used by a generative API. Its also used as @@ -357,7 +402,7 @@ class ClauseElement( """ s = util.column_set() - f = self + f: Optional[ClauseElement] = self # note this creates a cycle, asserted in test_memusage. however, # turning this into a plain @property adds tends of thousands of method @@ -383,16 +428,26 @@ class ClauseElement( return d def _execute_on_connection( - self, connection, distilled_params, execution_options, _force=False - ): + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + _force: bool = False, + ) -> Result: if _force or self.supports_execution: + if TYPE_CHECKING: + assert isinstance(self, Executable) return connection._execute_clauseelement( self, distilled_params, execution_options ) else: raise exc.ObjectNotExecutableError(self) - def unique_params(self, *optionaldict, **kwargs): + def unique_params( + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -402,11 +457,13 @@ class ClauseElement( used. """ - return self._replace_params(True, optionaldict, kwargs) + return self._replace_params(True, __optionaldict, kwargs) def params( - self, *optionaldict: Dict[str, Any], **kwargs: Any - ) -> ClauseElement: + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -421,33 +478,32 @@ class ClauseElement( {'foo':7} """ - return self._replace_params(False, optionaldict, kwargs) + return self._replace_params(False, __optionaldict, kwargs) def _replace_params( - self, + self: SelfClauseElement, unique: bool, optionaldict: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - ) -> ClauseElement: + ) -> SelfClauseElement: - if len(optionaldict) == 1: - kwargs.update(optionaldict[0]) - elif len(optionaldict) > 1: - raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument" - ) + if optionaldict: + kwargs.update(optionaldict) - def visit_bindparam(bind): + def visit_bindparam(bind: BindParameter[Any]) -> None: if bind.key in kwargs: bind.value = kwargs[bind.key] bind.required = False if unique: bind._convert_to_unique() - return cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, + return cast( + SelfClauseElement, + cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ), ) def compare(self, other, **kw): @@ -501,18 +557,26 @@ class ClauseElement( def _compile_w_cache( self, dialect: Dialect, - compiled_cache: Optional[_CompiledCacheType] = None, - column_keys: Optional[List[str]] = None, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], for_executemany: bool = False, schema_translate_map: Optional[_SchemaTranslateMapType] = None, **kw: Any, - ): + ) -> typing_Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + elem_cache_key: Optional[CacheKey] + if compiled_cache is not None and dialect._supports_statement_cache: elem_cache_key = self._generate_cache_key() else: elem_cache_key = None - if elem_cache_key: + if elem_cache_key is not None: + if TYPE_CHECKING: + assert compiled_cache is not None + cache_key, extracted_params = elem_cache_key key = ( dialect, @@ -564,7 +628,7 @@ class ClauseElement( else: return self._negate() - def _negate(self): + def _negate(self) -> ClauseElement: return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv ) @@ -605,6 +669,9 @@ class DQLDMLClauseElement(ClauseElement): ) -> SQLCompiler: ... + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + ... + class CompilerColumnElement( roles.DMLColumnRole, @@ -621,9 +688,7 @@ class CompilerColumnElement( __slots__ = () -class SQLCoreOperations( - Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly -): +class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): __slots__ = () # annotations for comparison methods @@ -631,173 +696,186 @@ class SQLCoreOperations( # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: + _propagate_attrs: _PropagateAttrsType + def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def op( self, - opstring: Any, + opstring: str, precedence: int = 0, is_comparison: bool = False, - return_type: Optional[ - Union[Type["TypeEngine[_OPT]"], "TypeEngine[_OPT]"] - ] = None, - python_impl=None, - ) -> Callable[[Any], "BinaryExpression[_OPT]"]: + return_type: Optional[_TypeEngineArgument[_OPT]] = None, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... def bool_op( - self, opstring: Any, precedence: int = 0, python_impl=None - ) -> Callable[[Any], "BinaryExpression[bool]"]: + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[bool]]: ... - def __and__(self, other: Any) -> "BooleanClauseList": + def __and__(self, other: Any) -> BooleanClauseList: ... - def __or__(self, other: Any) -> "BooleanClauseList": + def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> "UnaryExpression[_T]": + def __invert__(self) -> ColumnElement[_T]: ... - def __lt__(self, other: Any) -> "ColumnElement[bool]": + def __lt__(self, other: Any) -> ColumnElement[bool]: ... - def __le__(self, other: Any) -> "ColumnElement[bool]": + def __le__(self, other: Any) -> ColumnElement[bool]: ... - def __eq__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def __ne__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def __gt__(self, other: Any) -> "ColumnElement[bool]": + def __gt__(self, other: Any) -> ColumnElement[bool]: ... - def __ge__(self, other: Any) -> "ColumnElement[bool]": + def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> "UnaryExpression[_T]": + def __neg__(self) -> UnaryExpression[_T]: ... - def __contains__(self, other: Any) -> "ColumnElement[bool]": + def __contains__(self, other: Any) -> ColumnElement[bool]: ... - def __getitem__(self, index: Any) -> "ColumnElement": + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... @overload - def concat( - self: "SQLCoreOperations[_ST]", other: Any - ) -> "ColumnElement[_ST]": + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @overload - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def like(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def ilike(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def in_( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_like( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def not_ilike( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def is_(self, other: Any) -> "BinaryExpression[bool]": + def is_(self, other: Any) -> BinaryExpression[bool]: ... - def is_not(self, other: Any) -> "BinaryExpression[bool]": + def is_not(self, other: Any) -> BinaryExpression[bool]: ... def startswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... def endswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... - def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]": + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def match(self, other: Any, **kwargs) -> "ColumnElement[bool]": + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... - def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]": + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnElement[bool]: ... def regexp_replace( - self, pattern, replacement, flags=None - ) -> "ColumnElement": + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnElement[str]: ... - def desc(self) -> "UnaryExpression[_T]": + def desc(self) -> UnaryExpression[_T]: ... - def asc(self) -> "UnaryExpression[_T]": + def asc(self) -> UnaryExpression[_T]: ... - def nulls_first(self) -> "UnaryExpression[_T]": + def nulls_first(self) -> UnaryExpression[_T]: ... - def nulls_last(self) -> "UnaryExpression[_T]": + def nulls_last(self) -> UnaryExpression[_T]: ... - def collate(self, collation) -> "CollationClause": + def collate(self, collation: str) -> CollationClause: ... def between( - self, cleft, cright, symmetric=False - ) -> "ColumnElement[bool]": + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> BinaryExpression[bool]: ... - def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]": + def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: ... - def any_(self) -> "CollectionAggregate": + def any_(self) -> CollectionAggregate[Any]: ... - def all_(self) -> "CollectionAggregate": + def all_(self) -> CollectionAggregate[Any]: ... # numeric overloads. These need more tweaking @@ -807,179 +885,173 @@ class SQLCoreOperations( @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", - other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", + self: _SQO[str], other: Any, - ) -> "ColumnElement[_NUMERIC]": + ) -> ColumnElement[str]: ... - @overload - def __add__( - self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]", - other: Any, - ) -> "ColumnElement[_ST]": + def __add__(self, other: Any) -> ColumnElement[Any]: ... - def __add__(self, other: Any) -> "ColumnElement": + @overload + def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]": + def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self, other: Any) -> ColumnElement[Any]: ... @overload def __sub__( - self: "SQLCoreOperations[_NT]", - other: "Union[SQLCoreOperations[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rsub__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __mul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rmul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __mod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rmod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... @overload def __truediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rtruediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __floordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rfloordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... _SQO = SQLCoreOperations +SelfColumnElement = TypeVar("SelfColumnElement", bound="ColumnElement[Any]") + class ColumnElement( roles.ColumnArgumentOrKeyRole, roles.StatementOptionRole, roles.WhereHavingRole, - roles.BinaryElementRole, + roles.BinaryElementRole[_T], roles.OrderByRole, roles.ColumnsClauseRole, roles.LimitOffsetRole, @@ -987,7 +1059,6 @@ class ColumnElement( roles.DDLConstraintColumnRole, roles.DDLExpressionRole, SQLCoreOperations[_T], - operators.ColumnOperators[SQLCoreOperations], DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the @@ -1069,28 +1140,37 @@ class ColumnElement( __visit_name__ = "column_element" - primary_key = False - foreign_keys = [] - _proxies = () + primary_key: bool = False + _is_clone_of: Optional[ColumnElement[_T]] - _tq_label = None - """The named label that can be used to target - this column in a result set in a "table qualified" context. + @util.memoized_property + def foreign_keys(self) -> Iterable[ForeignKey]: + return [] - This label is almost always the label used when - rendering <expr> AS <label> in a SELECT statement when using - the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the legacy - ORM ``Query`` object uses as well. + @util.memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: + return [] - For a regular Column bound to a Table, this is typically the label - <tablename>_<columnname>. For other constructs, different rules - may apply, such as anonymized labels and others. + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: + """The named label that can be used to target + this column in a result set in a "table qualified" context. - .. versionchanged:: 1.4.21 renamed from ``._label`` + This label is almost always the label used when + rendering <expr> AS <label> in a SELECT statement when using + the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the + legacy ORM ``Query`` object uses as well. - """ + For a regular Column bound to a Table, this is typically the label + <tablename>_<columnname>. For other constructs, different rules + may apply, such as anonymized labels and others. + + .. versionchanged:: 1.4.21 renamed from ``._label`` + + """ + return None - key = None + key: Optional[str] = None """The 'key' that in some circumstances refers to this object in a Python namespace. @@ -1101,7 +1181,7 @@ class ColumnElement( """ @HasMemoized.memoized_attribute - def _tq_key_label(self): + def _tq_key_label(self) -> Optional[str]: """A label-based version of 'key' that in some circumstances refers to this object in a Python namespace. @@ -1119,17 +1199,17 @@ class ColumnElement( return self._proxy_key @property - def _key_label(self): + def _key_label(self) -> Optional[str]: """legacy; renamed to _tq_key_label""" return self._tq_key_label @property - def _label(self): + def _label(self) -> Optional[str]: """legacy; renamed to _tq_label""" return self._tq_label @property - def _non_anon_label(self): + def _non_anon_label(self) -> Optional[str]: """the 'name' that naturally applies this element when rendered in SQL. @@ -1184,9 +1264,23 @@ class ColumnElement( _is_implicitly_boolean = False - _alt_names = () + _alt_names: Sequence[str] = () - def self_group(self, against=None): + @overload + def self_group( + self: ColumnElement[bool], against: Optional[OperatorType] = None + ) -> ColumnElement[bool]: + ... + + @overload + def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity @@ -1197,18 +1291,32 @@ class ColumnElement( else: return self - def _negate(self): + @overload + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: + ... + + @overload + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: + ... + + def _negate(self) -> ColumnElement[Any]: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return AsBoolean(self, operators.is_false, operators.is_true) else: - return super(ColumnElement, self)._negate() + return cast("UnaryExpression[_T]", super()._negate()) - @util.memoized_property - def type(self) -> "TypeEngine[_T]": - return type_api.NULLTYPE + type: TypeEngine[_T] + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + # used for delayed setup of + # type_api + return type_api.NULLTYPE @HasMemoized.memoized_attribute - def comparator(self) -> "TypeEngine.Comparator[_T]": + def comparator(self) -> TypeEngine.Comparator[_T]: try: comparator_factory = self.type.comparator_factory except AttributeError as err: @@ -1219,7 +1327,7 @@ class ColumnElement( else: return comparator_factory(self) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: return getattr(self.comparator, key) except AttributeError as err: @@ -1236,16 +1344,22 @@ class ColumnElement( self, op: operators.OperatorType, *other: Any, - **kwargs, - ) -> "ColumnElement": - return op(self.comparator, *other, **kwargs) + **kwargs: Any, + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 def reverse_operate( - self, op: operators.OperatorType, other: Any, **kwargs - ) -> "ColumnElement": - return op(other, self.comparator, **kwargs) + self, op: operators.OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 - def _bind_param(self, operator, obj, type_=None, expanding=False): + def _bind_param( + self, + operator: operators.OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: return BindParameter( None, obj, @@ -1257,7 +1371,7 @@ class ColumnElement( ) @property - def expression(self): + def expression(self) -> ColumnElement[Any]: """Return a column expression. Part of the inspection interface; returns self. @@ -1266,39 +1380,39 @@ class ColumnElement( return self @property - def _select_iterable(self): + def _select_iterable(self) -> Iterable[ColumnElement[Any]]: return (self,) @util.memoized_property - def base_columns(self): - return util.column_set(c for c in self.proxy_set if not c._proxies) + def base_columns(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(c for c in self.proxy_set if not c._proxies) @util.memoized_property - def proxy_set(self): - s = util.column_set([self]) - for c in self._proxies: - s.update(c.proxy_set) - return s + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset([self]).union( + itertools.chain.from_iterable(c.proxy_set for c in self._proxies) + ) - def _uncached_proxy_set(self): + def _uncached_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: """An 'uncached' version of proxy set. This is so that we can read annotations from the list of columns without breaking the caching of the above proxy_set. """ - s = util.column_set([self]) - for c in self._proxies: - s.update(c._uncached_proxy_set()) - return s + return frozenset([self]).union( + itertools.chain.from_iterable( + c._uncached_proxy_set() for c in self._proxies + ) + ) - def shares_lineage(self, othercolumn): + def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: """Return True if the given :class:`_expression.ColumnElement` has a common ancestor to this :class:`_expression.ColumnElement`.""" return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - def _compare_name_for_result(self, other): + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: """Return True if the given column element compares to this one when targeting within a result row.""" @@ -1309,9 +1423,9 @@ class ColumnElement( ) @HasMemoized.memoized_attribute - def _proxy_key(self): + def _proxy_key(self) -> Optional[str]: if self._annotations and "proxy_key" in self._annotations: - return self._annotations["proxy_key"] + return cast(str, self._annotations["proxy_key"]) name = self.key if not name: @@ -1327,7 +1441,7 @@ class ColumnElement( return name @HasMemoized.memoized_attribute - def _expression_label(self): + def _expression_label(self) -> Optional[str]: """a suggested label to use in the case that the column has no name, which should be used if possible as the explicit 'AS <label>' where this expression would normally have an anon label. @@ -1340,18 +1454,18 @@ class ColumnElement( if getattr(self, "name", None) is not None: return None elif self._annotations and "proxy_key" in self._annotations: - return self._annotations["proxy_key"] + return cast(str, self._annotations["proxy_key"]) else: return None def _make_proxy( self, - selectable, + selectable: FromClause, name: Optional[str] = None, - key=None, - name_is_truncatable=False, - **kw, - ): + key: Optional[str] = None, + name_is_truncatable: bool = False, + **kw: Any, + ) -> typing_Tuple[str, ColumnClause[_T]]: """Create a new :class:`_expression.ColumnElement` representing this :class:`_expression.ColumnElement` as it appears in the select list of a descending selectable. @@ -1364,7 +1478,7 @@ class ColumnElement( else: key = name - co = ColumnClause( + co: ColumnClause[_T] = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable else name, @@ -1376,9 +1490,10 @@ class ColumnElement( co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) + assert key is not None return key, co - def cast(self, type_): + def cast(self, type_: TypeEngine[_T]) -> Cast[_T]: """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. This is a shortcut to the :func:`_expression.cast` function. @@ -1406,7 +1521,9 @@ class ColumnElement( """ return Label(name, self, self.type) - def _anon_label(self, seed, add_hash=None) -> "_anonymous_label": + def _anon_label( + self, seed: Optional[str], add_hash: Optional[int] = None + ) -> _anonymous_label: while self._is_clone_of is not None: self = self._is_clone_of @@ -1441,7 +1558,7 @@ class ColumnElement( return _anonymous_label.safe_construct(hash_value, seed or "anon") @util.memoized_property - def _anon_name_label(self) -> "_anonymous_label": + def _anon_name_label(self) -> str: """Provides a constant 'anonymous label' for this ColumnElement. This is a label() expression which will be named at compile time. @@ -1462,7 +1579,7 @@ class ColumnElement( return self._anon_label(name) @util.memoized_property - def _anon_key_label(self): + def _anon_key_label(self) -> _anonymous_label: """Provides a constant 'anonymous key label' for this ColumnElement. Compare to ``anon_label``, except that the "key" of the column, @@ -1478,25 +1595,23 @@ class ColumnElement( """ return self._anon_label(self._proxy_key) - @property - @util.deprecated( + @util.deprecated_property( "1.4", "The :attr:`_expression.ColumnElement.anon_label` attribute is now " "private, and the public accessor is deprecated.", ) - def anon_label(self): + def anon_label(self) -> str: return self._anon_name_label - @property - @util.deprecated( + @util.deprecated_property( "1.4", "The :attr:`_expression.ColumnElement.anon_key_label` attribute is " "now private, and the public accessor is deprecated.", ) - def anon_key_label(self): + def anon_key_label(self) -> str: return self._anon_key_label - def _dedupe_anon_label_idx(self, idx): + def _dedupe_anon_label_idx(self, idx: int) -> str: """label to apply to a column that is anon labeled, but repeated in the SELECT, so that we have to make an "extra anon" label that disambiguates it from the previous appearance. @@ -1520,20 +1635,20 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) @util.memoized_property - def _anon_tq_label(self): + def _anon_tq_label(self) -> _anonymous_label: return self._anon_label(getattr(self, "_tq_label", None)) @util.memoized_property - def _anon_tq_key_label(self): + def _anon_tq_key_label(self) -> _anonymous_label: return self._anon_label(getattr(self, "_tq_key_label", None)) - def _dedupe_anon_tq_label_idx(self, idx): + def _dedupe_anon_tq_label_idx(self, idx: int) -> _anonymous_label: label = getattr(self, "_tq_label", None) or "anon" return self._anon_label(label, add_hash=idx) -class WrapsColumnExpression: +class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special labeling behavior for an expression that already has a name. @@ -1548,25 +1663,27 @@ class WrapsColumnExpression: """ @property - def wrapped_column_expression(self): + def wrapped_column_expression(self) -> ColumnElement[_T]: raise NotImplementedError() - @property - def _tq_label(self): + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: wce = self.wrapped_column_expression if hasattr(wce, "_tq_label"): return wce._tq_label else: return None - _label = _tq_label + @property + def _label(self) -> Optional[str]: + return self._tq_label @property - def _non_anon_label(self): + def _non_anon_label(self) -> Optional[str]: return None - @property - def _anon_name_label(self): + @util.non_memoized_property + def _anon_name_label(self) -> str: wce = self.wrapped_column_expression # this logic tries to get the WrappedColumnExpression to render @@ -1578,9 +1695,9 @@ class WrapsColumnExpression: return nal elif hasattr(wce, "_anon_name_label"): return wce._anon_name_label - return super(WrapsColumnExpression, self)._anon_name_label + return super()._anon_name_label - def _dedupe_anon_label_idx(self, idx): + def _dedupe_anon_label_idx(self, idx: int) -> str: wce = self.wrapped_column_expression nal = wce._non_anon_label if nal: @@ -1589,7 +1706,7 @@ class WrapsColumnExpression: return self._dedupe_anon_tq_label_idx(idx) -SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter") +SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") class BindParameter(roles.InElementRole, ColumnElement[_T]): @@ -1614,7 +1731,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): __visit_name__ = "bindparam" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("key", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("callable", InternalTraversal.dp_plain_dict), @@ -1622,7 +1739,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ] key: str - type: TypeEngine + type: TypeEngine[_T] _is_crud = False _is_bind_parameter = True @@ -1634,23 +1751,23 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): def __init__( self, - key, - value=NO_ARG, - type_=None, - unique=False, - required=NO_ARG, - quote=None, - callable_=None, - expanding=False, - isoutparam=False, - literal_execute=False, - _compared_to_operator=None, - _compared_to_type=None, - _is_crud=False, + key: Optional[str], + value: Any = _NoArg.NO_ARG, + type_: Optional[_TypeEngineArgument[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, + _compared_to_operator: Optional[OperatorType] = None, + _compared_to_type: Optional[TypeEngine[Any]] = None, + _is_crud: bool = False, ): - if required is NO_ARG: - required = value is NO_ARG and callable_ is None - if value is NO_ARG: + if required is _NoArg.NO_ARG: + required = value is _NoArg.NO_ARG and callable_ is None + if value is _NoArg.NO_ARG: value = None if quote is not None: @@ -1713,12 +1830,19 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): self.type = type_api._resolve_value_to_type(check_value) elif isinstance(type_, type): self.type = type_() - elif type_._is_tuple_type and value: - if expanding: - check_value = value[0] + elif is_tuple_type(type_): + if value: + if expanding: + check_value = value[0] + else: + check_value = value + cast( + "BindParameter[typing_Tuple[Any, ...]]", self + ).type = type_._resolve_values_to_types(check_value) else: - check_value = value - self.type = type_._resolve_values_to_types(check_value) + cast( + "BindParameter[typing_Tuple[Any, ...]]", self + ).type = type_ else: self.type = type_ @@ -1791,7 +1915,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): return c def _clone( - self: SelfBindParameter, maintain_key=False, **kw + self: SelfBindParameter, maintain_key: bool = False, **kw: Any ) -> SelfBindParameter: c = ClauseElement._clone(self, **kw) if not maintain_key and self.unique: @@ -1865,7 +1989,9 @@ class TypeClause(DQLDMLClauseElement): __visit_name__ = "typeclause" - _traverse_internals = [("type", InternalTraversal.dp_type)] + _traverse_internals: _TraverseInternalsType = [ + ("type", InternalTraversal.dp_type) + ] def __init__(self, type_): self.type = type_ @@ -1882,7 +2008,7 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.BinaryElementRole, + roles.BinaryElementRole[Any], roles.InElementRole, Executable, DQLDMLClauseElement, @@ -1909,7 +2035,7 @@ class TextClause( __visit_name__ = "textclause" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), ("text", InternalTraversal.dp_string), ] @@ -1923,7 +2049,9 @@ class TextClause( _render_label_in_columns_clause = False - _hide_froms = () + @property + def _hide_froms(self) -> Iterable[FromClause]: + return () def __and__(self, other): # support use in select.where(), query.filter() @@ -1935,12 +2063,13 @@ class TextClause( # help in those cases where text() is # interpreted in a column expression situation - key = _label = None + key: Optional[str] = None + _label: Optional[str] = None _allow_label_resolve = False - def __init__(self, text): - self._bindparams = {} + def __init__(self, text: str): + self._bindparams: Dict[str, BindParameter[Any]] = {} def repl(m): self._bindparams[m.group(1)] = BindParameter(m.group(1)) @@ -1952,7 +2081,9 @@ class TextClause( @_generative def bindparams( - self: SelfTextClause, *binds, **names_to_values + self: SelfTextClause, + *binds: BindParameter[Any], + **names_to_values: Any, ) -> SelfTextClause: """Establish the values and/or types of bound parameters within this :class:`_expression.TextClause` construct. @@ -2205,7 +2336,7 @@ class TextClause( else col for col in cols ] - keyed_input_cols = [ + keyed_input_cols: List[ColumnClause[Any]] = [ ColumnClause(key, type_) for key, type_ in types.items() ] @@ -2230,7 +2361,7 @@ class TextClause( return self -class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): +class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -2240,23 +2371,26 @@ class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): __visit_name__ = "null" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: Null @util.memoized_property def type(self): return type_api.NULLTYPE @classmethod - def _instance(cls): + def _instance(cls) -> Null: """Return a constant :class:`.Null` construct.""" - return Null() + return Null._singleton Null._create_singleton() -class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class False_( + SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] +): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -2265,24 +2399,25 @@ class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): """ __visit_name__ = "false" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: False_ @util.memoized_property def type(self): return type_api.BOOLEANTYPE - def _negate(self): - return True_() + def _negate(self) -> True_: + return True_._singleton @classmethod - def _instance(cls): - return False_() + def _instance(cls) -> False_: + return False_._singleton False_._create_singleton() -class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -2292,14 +2427,15 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): __visit_name__ = "true" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: True_ @util.memoized_property def type(self): return type_api.BOOLEANTYPE - def _negate(self): - return False_() + def _negate(self) -> False_: + return False_._singleton @classmethod def _ifnone(cls, other): @@ -2309,8 +2445,8 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): return other @classmethod - def _instance(cls): - return True_() + def _instance(cls) -> True_: + return True_._singleton True_._create_singleton() @@ -2333,18 +2469,18 @@ class ClauseList( _is_clause_list = True - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clauses", InternalTraversal.dp_clauseelement_list), ("operator", InternalTraversal.dp_operator), ] def __init__( self, - *clauses, - operator=operators.comma_op, - group=True, - group_contents=True, - _flatten_sub_clauses=False, + *clauses: _ColumnExpression[Any], + operator: OperatorType = operators.comma_op, + group: bool = True, + group_contents: bool = True, + _flatten_sub_clauses: bool = False, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, ): self.operator = operator @@ -2405,8 +2541,8 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): @@ -2465,7 +2601,14 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return lcc, [c.self_group(against=against) for c in convert_clauses] @classmethod - def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): + def _construct( + cls, + operator: OperatorType, + continue_on: Any, + skip_on: Any, + *clauses: _ColumnExpression[Any], + **kw: Any, + ) -> BooleanClauseList: lcc, convert_clauses = cls._process_clauses_for_boolean( operator, continue_on, @@ -2479,11 +2622,11 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) + return cls._construct_raw(operator, convert_clauses) # type: ignore[no-any-return] # noqa E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. - return convert_clauses[0] + return convert_clauses[0] # type: ignore[no-any-return] # noqa E501 else: # no elements period. deprecated use case. return an empty # ClauseList construct that generates nothing unless it has @@ -2500,7 +2643,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): }, version="1.4", ) - return cls._construct_raw(operator) + return cls._construct_raw(operator) # type: ignore[no-any-return] # noqa E501 @classmethod def _construct_for_whereclause(cls, clauses): @@ -2540,7 +2683,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return self @classmethod - def and_(cls, *clauses): + def and_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: r"""Produce a conjunction of expressions joined by ``AND``. See :func:`_sql.and_` for full documentation. @@ -2550,7 +2693,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): ) @classmethod - def or_(cls, *clauses): + def or_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: """Produce a conjunction of expressions joined by ``OR``. See :func:`_sql.or_` for full documentation. @@ -2577,19 +2720,27 @@ and_ = BooleanClauseList.and_ or_ = BooleanClauseList.or_ -class Tuple(ClauseList, ColumnElement): +class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): """Represent a SQL tuple.""" __visit_name__ = "tuple" - _traverse_internals = ClauseList._traverse_internals + [] + _traverse_internals: _TraverseInternalsType = ( + ClauseList._traverse_internals + [] + ) + + type: TupleType @util.preload_module("sqlalchemy.sql.sqltypes") - def __init__(self, *clauses, types=None): + def __init__( + self, + *clauses: _ColumnExpression[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, + ): sqltypes = util.preloaded.sql_sqltypes if types is None: - clauses = [ + init_clauses = [ coercions.expect(roles.ExpressionElementRole, c) for c in clauses ] @@ -2599,7 +2750,7 @@ class Tuple(ClauseList, ColumnElement): "Wrong number of elements for %d-tuple: %r " % (len(types), clauses) ) - clauses = [ + init_clauses = [ coercions.expect( roles.ExpressionElementRole, c, @@ -2608,8 +2759,8 @@ class Tuple(ClauseList, ColumnElement): for typ, c in zip(types, clauses) ] - self.type = sqltypes.TupleType(*[arg.type for arg in clauses]) - super(Tuple, self).__init__(*clauses) + self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses]) + super(Tuple, self).__init__(*init_clauses) @property def _select_iterable(self): @@ -2672,7 +2823,7 @@ class Case(ColumnElement[_T]): __visit_name__ = "case" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("value", InternalTraversal.dp_clauseelement), ("whens", InternalTraversal.dp_clauseelement_tuples), ("else_", InternalTraversal.dp_clauseelement), @@ -2681,13 +2832,24 @@ class Case(ColumnElement[_T]): # for case(), the type is derived from the whens. so for the moment # users would have to cast() the case to get a specific type - def __init__(self, *whens, value=None, else_=None): + whens: List[typing_Tuple[ColumnElement[bool], ColumnElement[_T]]] + else_: Optional[ColumnElement[_T]] + value: Optional[ColumnElement[Any]] - whens = coercions._expression_collection_was_a_list( + def __init__( + self, + *whens: Union[ + typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, + ): + + new_whens: Iterable[Any] = coercions._expression_collection_was_a_list( "whens", "case", whens ) try: - whens = util.dictlike_iteritems(whens) + new_whens = util.dictlike_iteritems(new_whens) except TypeError: pass @@ -2700,7 +2862,7 @@ class Case(ColumnElement[_T]): ).self_group(), coercions.expect(roles.ExpressionElementRole, r), ) - for (c, r) in whens + for (c, r) in new_whens ] if whenlist: @@ -2713,7 +2875,7 @@ class Case(ColumnElement[_T]): else: self.value = coercions.expect(roles.ExpressionElementRole, value) - self.type = type_ + self.type = cast(_T, type_) self.whens = whenlist if else_ is not None: @@ -2721,14 +2883,14 @@ class Case(ColumnElement[_T]): else: self.else_ = None - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain(*[x._from_objects for x in self.get_children()]) ) -class Cast(WrapsColumnExpression, ColumnElement[_T]): +class Cast(WrapsColumnExpression[_T]): """Represent a ``CAST`` expression. :class:`.Cast` is produced using the :func:`.cast` factory function, @@ -2754,12 +2916,20 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): __visit_name__ = "cast" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clause", InternalTraversal.dp_clauseelement), ("typeclause", InternalTraversal.dp_clauseelement), ] - def __init__(self, expression, type_): + clause: ColumnElement[Any] + type: TypeEngine[_T] + typeclause: TypeClause + + def __init__( + self, + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], + ): self.type = type_api.to_instance(type_) self.clause = coercions.expect( roles.ExpressionElementRole, @@ -2769,8 +2939,8 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): ) self.typeclause = TypeClause(self.type) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @property @@ -2778,7 +2948,7 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): return self.clause -class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): +class TypeCoerce(WrapsColumnExpression[_T]): """Represent a Python-side type-coercion wrapper. :class:`.TypeCoerce` supplies the :func:`_expression.type_coerce` @@ -2798,12 +2968,19 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): __visit_name__ = "type_coerce" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clause", InternalTraversal.dp_clauseelement), ("type", InternalTraversal.dp_type), ] - def __init__(self, expression, type_): + clause: ColumnElement[Any] + type: TypeEngine[_T] + + def __init__( + self, + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], + ): self.type = type_api.to_instance(type_) self.clause = coercions.expect( roles.ExpressionElementRole, @@ -2812,8 +2989,8 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): apply_propagate_attrs=self, ) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @HasMemoized.memoized_attribute @@ -2837,27 +3014,30 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): return self -class Extract(ColumnElement[_T]): +class Extract(ColumnElement[int]): """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``.""" __visit_name__ = "extract" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("expr", InternalTraversal.dp_clauseelement), ("field", InternalTraversal.dp_string), ] - def __init__(self, field, expr): + expr: ColumnElement[Any] + field: str + + def __init__(self, field: str, expr: _ColumnExpression[Any]): self.type = type_api.INTEGERTYPE self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.expr._from_objects -class _label_reference(ColumnElement): +class _label_reference(ColumnElement[_T]): """Wrap a column expression as it appears in a 'reference' context. This expression is any that includes an _order_by_label_element, @@ -2872,26 +3052,30 @@ class _label_reference(ColumnElement): __visit_name__ = "label_reference" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] - def __init__(self, element): + def __init__(self, element: ColumnElement[_T]): self.element = element - @property - def _from_objects(self): - return () + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] -class _textual_label_reference(ColumnElement): +class _textual_label_reference(ColumnElement[Any]): __visit_name__ = "textual_label_reference" - _traverse_internals = [("element", InternalTraversal.dp_string)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_string) + ] - def __init__(self, element): + def __init__(self, element: str): self.element = element @util.memoized_property - def _text_clause(self): + def _text_clause(self) -> TextClause: return TextClause(self.element) @@ -2911,7 +3095,7 @@ class UnaryExpression(ColumnElement[_T]): __visit_name__ = "unary" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("operator", InternalTraversal.dp_operator), ("modifier", InternalTraversal.dp_operator), @@ -2919,11 +3103,11 @@ class UnaryExpression(ColumnElement[_T]): def __init__( self, - element, - operator=None, - modifier=None, - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] = None, - wraps_column_expression=False, + element: ColumnElement[Any], + operator: Optional[OperatorType] = None, + modifier: Optional[OperatorType] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, + wraps_column_expression: bool = False, ): self.operator = operator self.modifier = modifier @@ -2935,7 +3119,10 @@ class UnaryExpression(ColumnElement[_T]): self.wraps_column_expression = wraps_column_expression @classmethod - def _create_nulls_first(cls, column): + def _create_nulls_first( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.nulls_first_op, @@ -2943,7 +3130,10 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_nulls_last(cls, column): + def _create_nulls_last( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.nulls_last_op, @@ -2951,7 +3141,9 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_desc(cls, column): + def _create_desc( + cls, column: _ColumnExpression[_T] + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.desc_op, @@ -2959,7 +3151,10 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_asc(cls, column): + def _create_asc( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.asc_op, @@ -2967,24 +3162,27 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_distinct(cls, expr): - expr = coercions.expect(roles.ExpressionElementRole, expr) + def _create_distinct( + cls, + expr: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: + col_expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( - expr, + col_expr, operator=operators.distinct_op, - type_=expr.type, + type_=col_expr.type, wraps_column_expression=False, ) @property - def _order_by_label_element(self): + def _order_by_label_element(self) -> Optional[Label[Any]]: if self.modifier in (operators.desc_op, operators.asc_op): return self.element._order_by_label_element else: return None - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def _negate(self): @@ -3005,7 +3203,7 @@ class UnaryExpression(ColumnElement[_T]): return self -class CollectionAggregate(UnaryExpression): +class CollectionAggregate(UnaryExpression[_T]): """Forms the basis for right-hand collection operator modifiers ANY and ALL. @@ -3018,7 +3216,9 @@ class CollectionAggregate(UnaryExpression): inherit_cache = True @classmethod - def _create_any(cls, expr): + def _create_any( + cls, expr: _ColumnExpression[_T] + ) -> CollectionAggregate[_T]: expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() @@ -3030,7 +3230,9 @@ class CollectionAggregate(UnaryExpression): ) @classmethod - def _create_all(cls, expr): + def _create_all( + cls, expr: _ColumnExpression[_T] + ) -> CollectionAggregate[_T]: expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() return CollectionAggregate( @@ -3059,7 +3261,7 @@ class CollectionAggregate(UnaryExpression): ) -class AsBoolean(WrapsColumnExpression, UnaryExpression): +class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): inherit_cache = True def __init__(self, element, operator, negate): @@ -3101,7 +3303,7 @@ class BinaryExpression(ColumnElement[_T]): __visit_name__ = "binary" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("left", InternalTraversal.dp_clauseelement), ("right", InternalTraversal.dp_clauseelement), ("operator", InternalTraversal.dp_operator), @@ -3119,16 +3321,16 @@ class BinaryExpression(ColumnElement[_T]): """ + modifiers: Optional[Mapping[str, Any]] + def __init__( self, - left: ColumnElement, - right: Union[ColumnElement, ClauseList], - operator, - type_: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, - negate=None, - modifiers=None, + left: ColumnElement[Any], + right: Union[ColumnElement[Any], ClauseList], + operator: OperatorType, + type_: Optional[_TypeEngineArgument[_T]] = None, + negate: Optional[OperatorType] = None, + modifiers: Optional[Mapping[str, Any]] = None, ): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings @@ -3149,8 +3351,40 @@ class BinaryExpression(ColumnElement[_T]): self.modifiers = modifiers def __bool__(self): - if self.operator in (operator.eq, operator.ne): - return self.operator(*self._orig) + """Implement Python-side "bool" for BinaryExpression as a + simple "identity" check for the left and right attributes, + if the operator is "eq" or "ne". Otherwise the expression + continues to not support "bool" like all other column expressions. + + The rationale here is so that ColumnElement objects can be hashable. + What? Well, suppose you do this:: + + c1, c2 = column('x'), column('y') + s1 = set([c1, c2]) + + We do that **a lot**, columns inside of sets is an extremely basic + thing all over the ORM for example. + + So what happens if we do this? :: + + c1 in s1 + + Hashing means it will normally use ``__hash__()`` of the object, + but in case of hash collision, it's going to also do ``c1 == c1`` + and/or ``c1 == c2`` inside. Those operations need to return a + True/False value. But because we override ``==`` and ``!=``, they're + going to get a BinaryExpression. Hence we implement ``__bool__`` here + so that these comparisons behave in this particular context mostly + like regular object comparisons. Thankfully Python is OK with + that! Otherwise we'd have to use special set classes for columns + (which we used to do, decades ago). + + """ + if self.operator in (operators.eq, operators.ne): + # this is using the eq/ne operator given int hash values, + # rather than Operator, so that "bool" can be based on + # identity + return self.operator(*self._orig) # type: ignore else: raise TypeError("Boolean value of this clause is not defined") @@ -3167,8 +3401,8 @@ class BinaryExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects def self_group(self, against=None): @@ -3192,7 +3426,7 @@ class BinaryExpression(ColumnElement[_T]): return super(BinaryExpression, self)._negate() -class Slice(ColumnElement): +class Slice(ColumnElement[Any]): """Represent SQL for a Python array-slice object. This is not a specific SQL construct at this level, but @@ -3202,7 +3436,7 @@ class Slice(ColumnElement): __visit_name__ = "slice" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("start", InternalTraversal.dp_clauseelement), ("stop", InternalTraversal.dp_clauseelement), ("step", InternalTraversal.dp_clauseelement), @@ -3234,7 +3468,7 @@ class Slice(ColumnElement): return self -class IndexExpression(BinaryExpression): +class IndexExpression(BinaryExpression[Any]): """Represent the class of expressions that are like an "index" operation.""" @@ -3246,6 +3480,8 @@ class GroupedElement(DQLDMLClauseElement): __visit_name__ = "grouping" + element: ClauseElement + def self_group(self, against=None): return self @@ -3253,15 +3489,19 @@ class GroupedElement(DQLDMLClauseElement): return self.element._ungroup() -class Grouping(GroupedElement, ColumnElement): +class Grouping(GroupedElement, ColumnElement[_T]): """Represent a grouping within a column expression""" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("type", InternalTraversal.dp_type), ] - def __init__(self, element): + element: Union[TextClause, ClauseList, ColumnElement[_T]] + + def __init__( + self, element: Union[TextClause, ClauseList, ColumnElement[_T]] + ): self.element = element self.type = getattr(element, "type", type_api.NULLTYPE) @@ -3272,21 +3512,21 @@ class Grouping(GroupedElement, ColumnElement): def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean - @property - def _tq_label(self): + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: return ( getattr(self.element, "_tq_label", None) or self._anon_name_label ) - @property - def _proxies(self): + @util.non_memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: if isinstance(self.element, ColumnElement): return [self.element] else: return [] - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def __getattr__(self, attr): @@ -3300,8 +3540,13 @@ class Grouping(GroupedElement, ColumnElement): self.type = state["type"] -RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED") -RANGE_CURRENT = util.symbol("RANGE_CURRENT") +class _OverRange(IntEnum): + RANGE_UNBOUNDED = 0 + RANGE_CURRENT = 1 + + +RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED +RANGE_CURRENT = _OverRange.RANGE_CURRENT class Over(ColumnElement[_T]): @@ -3316,7 +3561,7 @@ class Over(ColumnElement[_T]): __visit_name__ = "over" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ("partition_by", InternalTraversal.dp_clauseelement), @@ -3324,15 +3569,26 @@ class Over(ColumnElement[_T]): ("rows", InternalTraversal.dp_plain_obj), ] - order_by = None - partition_by = None + order_by: Optional[ClauseList] = None + partition_by: Optional[ClauseList] = None - element = None + element: ColumnElement[_T] """The underlying expression object to which this :class:`.Over` object refers towards.""" + range_: Optional[typing_Tuple[int, int]] + def __init__( - self, element, partition_by=None, order_by=None, range_=None, rows=None + self, + element: ColumnElement[_T], + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ): self.element = element if order_by is not None: @@ -3368,10 +3624,15 @@ class Over(ColumnElement[_T]): self.rows, ) - def _interpret_range(self, range_): + def _interpret_range( + self, range_: typing_Tuple[Optional[int], Optional[int]] + ) -> typing_Tuple[int, int]: if not isinstance(range_, tuple) or len(range_) != 2: raise exc.ArgumentError("2-tuple expected for range/rows") + lower: int + upper: int + if range_[0] is None: lower = RANGE_UNBOUNDED else: @@ -3404,8 +3665,8 @@ class Over(ColumnElement[_T]): def type(self): return self.element.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3436,14 +3697,16 @@ class WithinGroup(ColumnElement[_T]): __visit_name__ = "withingroup" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ] - order_by = None + order_by: Optional[ClauseList] = None - def __init__(self, element, *order_by): + def __init__( + self, element: FunctionElement[_T], *order_by: _ColumnExpression[Any] + ): self.element = element if order_by is not None: self.order_by = ClauseList( @@ -3451,7 +3714,9 @@ class WithinGroup(ColumnElement[_T]): ) def __reduce__(self): - return self.__class__, (self.element,) + tuple(self.order_by) + return self.__class__, (self.element,) + ( + tuple(self.order_by) if self.order_by is not None else () + ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): """Produce an OVER clause against this :class:`.WithinGroup` @@ -3477,8 +3742,8 @@ class WithinGroup(ColumnElement[_T]): else: return self.element.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3490,7 +3755,7 @@ class WithinGroup(ColumnElement[_T]): ) -class FunctionFilter(ColumnElement): +class FunctionFilter(ColumnElement[_T]): """Represent a function FILTER clause. This is a special operator against aggregate and window functions, @@ -3512,14 +3777,16 @@ class FunctionFilter(ColumnElement): __visit_name__ = "funcfilter" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("func", InternalTraversal.dp_clauseelement), ("criterion", InternalTraversal.dp_clauseelement), ] - criterion = None + criterion: Optional[ColumnElement[bool]] = None - def __init__(self, func, *criterion): + def __init__( + self, func: FunctionElement[_T], *criterion: _ColumnExpression[bool] + ): self.func = func self.filter(*criterion) @@ -3535,17 +3802,27 @@ class FunctionFilter(ColumnElement): """ - for criterion in list(criterion): - criterion = coercions.expect(roles.WhereHavingRole, criterion) + for crit in list(criterion): + crit = coercions.expect(roles.WhereHavingRole, crit) if self.criterion is not None: - self.criterion = self.criterion & criterion + self.criterion = self.criterion & crit else: - self.criterion = criterion + self.criterion = crit return self - def over(self, partition_by=None, order_by=None, range_=None, rows=None): + def over( + self, + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this filtered function. Used against aggregate or so-called "window" functions, @@ -3581,8 +3858,8 @@ class FunctionFilter(ColumnElement): def type(self): return self.func.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3594,7 +3871,7 @@ class FunctionFilter(ColumnElement): ) -class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): +class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3604,13 +3881,21 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): __visit_name__ = "label" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), - ("_type", InternalTraversal.dp_type), + ("type", InternalTraversal.dp_type), ("_element", InternalTraversal.dp_clauseelement), ] - def __init__(self, name, element, type_=None): + _element: ColumnElement[_T] + name: str + + def __init__( + self, + name: Optional[str], + element: _ColumnExpression[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, + ): orig_element = element element = coercions.expect( roles.ExpressionElementRole, @@ -3635,11 +3920,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): self.key = self._tq_label = self._tq_key_label = self.name self._element = element - self._type = type_ + # self._type = type_ + self.type = type_api.to_instance( + type_ or getattr(self._element, "type", None) + ) self._proxies = [element] def __reduce__(self): - return self.__class__, (self.name, self._element, self._type) + return self.__class__, (self.name, self._element, self.type) @util.memoized_property def _is_implicitly_boolean(self): @@ -3653,14 +3941,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): def _order_by_label_element(self): return self - @util.memoized_property - def type(self): - return type_api.to_instance( - self._type or getattr(self._element, "type", None) - ) - @HasMemoized.memoized_attribute - def element(self): + def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) def self_group(self, against=None): @@ -3672,7 +3954,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): def _apply_to_inner(self, fn, *arg, **kw): sub_element = fn(*arg, **kw) if sub_element is not self._element: - return Label(self.name, sub_element, type_=self._type) + return Label(self.name, sub_element, type_=self.type) else: return self @@ -3693,8 +3975,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): ) self.key = self._tq_label = self._tq_key_label = self.name - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): @@ -3724,15 +4006,16 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) - if self._type is not None: - e.type = self._type + if self.type is not None: + e.type = self.type return self.key, e class NamedColumn(ColumnElement[_T]): is_literal = False - table = None + table: Optional[FromClause] = None + name: str def _compare_name_for_result(self, other): return (hasattr(other, "name") and self.name == other.name) or ( @@ -3740,7 +4023,7 @@ class NamedColumn(ColumnElement[_T]): ) @util.memoized_property - def description(self): + def description(self) -> str: return self.name @HasMemoized.memoized_attribute @@ -3759,7 +4042,7 @@ class NamedColumn(ColumnElement[_T]): return self._tq_label @HasMemoized.memoized_attribute - def _tq_label(self): + def _tq_label(self) -> Optional[str]: """table qualified label based on column name. for table-bound columns this is <tablename>_<columnname>; all other @@ -3776,7 +4059,9 @@ class NamedColumn(ColumnElement[_T]): def _non_anon_label(self): return self.name - def _gen_tq_label(self, name, dedupe_on_key=True): + def _gen_tq_label( + self, name: str, dedupe_on_key: bool = True + ) -> Optional[str]: return name def _bind_param(self, operator, obj, type_=None, expanding=False): @@ -3817,7 +4102,7 @@ class NamedColumn(ColumnElement[_T]): class ColumnClause( roles.DDLReferredColumnRole, - roles.LabeledColumnExprRole, + roles.LabeledColumnExprRole[_T], roles.StrAsPlainColumnRole, Immutable, NamedColumn[_T], @@ -3859,30 +4144,31 @@ class ColumnClause( """ - table = None - is_literal = False + table: Optional[FromClause] + is_literal: bool __visit_name__ = "column" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), ] - onupdate = default = server_default = server_onupdate = None + onupdate: Optional[DefaultGenerator] = None + default: Optional[DefaultGenerator] = None + server_default: Optional[DefaultGenerator] = None + server_onupdate: Optional[DefaultGenerator] = None _is_multiparam_column = False def __init__( self, text: str, - type_: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, is_literal: bool = False, - _selectable: Optional["FromClause"] = None, + _selectable: Optional[FromClause] = None, ): self.key = self.name = text self.table = _selectable @@ -3916,7 +4202,7 @@ class ColumnClause( return super(ColumnClause, self)._clone(**kw) @HasMemoized.memoized_attribute - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: t = self.table if t is not None: return [t] @@ -3953,7 +4239,9 @@ class ColumnClause( else: return other.proxy_set.intersection(self.proxy_set) - def _gen_tq_label(self, name, dedupe_on_key=True): + def _gen_tq_label( + self, name: str, dedupe_on_key: bool = True + ) -> Optional[str]: """generate table-qualified label for a table-bound column this is <tablename>_<columnname>. @@ -3962,22 +4250,24 @@ class ColumnClause( as well as the .columns collection on a Join object. """ + label: str t = self.table if self.is_literal: return None - elif t is not None and t.named_with_column: - if getattr(t, "schema", None): + elif t is not None and is_named_from_clause(t): + if has_schema_attr(t) and t.schema: label = t.schema.replace(".", "_") + "_" + t.name + "_" + name else: + assert not TYPE_CHECKING or isinstance(t, NamedFromClause) label = t.name + "_" + name # propagate name quoting rules for labels. - if getattr(name, "quote", None) is not None: - if isinstance(label, quoted_name): + if is_quoted_name(name) and name.quote is not None: + if is_quoted_name(label): label.quote = name.quote else: label = quoted_name(label, name.quote) - elif getattr(t.name, "quote", None) is not None: + elif is_quoted_name(t.name) and t.name.quote is not None: # can't get this situation to occur, so let's # assert false on it for now assert not isinstance(label, quoted_name) @@ -4046,16 +4336,16 @@ class ColumnClause( return c.key, c -class TableValuedColumn(NamedColumn): +class TableValuedColumn(NamedColumn[_T]): __visit_name__ = "table_valued_column" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("scalar_alias", InternalTraversal.dp_clauseelement), ] - def __init__(self, scalar_alias, type_): + def __init__(self, scalar_alias: NamedFromClause, type_: TypeEngine[_T]): self.scalar_alias = scalar_alias self.key = self.name = scalar_alias.name self.type = type_ @@ -4064,24 +4354,28 @@ class TableValuedColumn(NamedColumn): self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self.scalar_alias] -class CollationClause(ColumnElement): +class CollationClause(ColumnElement[str]): __visit_name__ = "collation" - _traverse_internals = [("collation", InternalTraversal.dp_string)] + _traverse_internals: _TraverseInternalsType = [ + ("collation", InternalTraversal.dp_string) + ] @classmethod - def _create_collation_expression(cls, expression, collation): + def _create_collation_expression( + cls, expression: _ColumnExpression[str], collation: str + ) -> BinaryExpression[str]: expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( expr, CollationClause(collation), operators.collate, - type_=expression.type, + type_=expr.type, ) def __init__(self, collation): @@ -4163,6 +4457,8 @@ class quoted_name(util.MemoizedSlots, str): __slots__ = "quote", "lower", "upper" + quote: Optional[bool] + def __new__(cls, value, quote): if value is None: return None @@ -4196,10 +4492,10 @@ class quoted_name(util.MemoizedSlots, str): return str(self).upper() -def _find_columns(clause): +def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: """locate Column objects within the given expression.""" - cols = util.column_set() + cols: Set[ColumnClause[Any]] = set() traverse(clause, {}, {"column": cols.add}) return cols @@ -4226,6 +4522,8 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): class AnnotatedColumnElement(Annotated): + _Annotated__element: ColumnElement[Any] + def __init__(self, element, values): Annotated.__init__(self, element, values) for attr in ( @@ -4265,7 +4563,7 @@ class AnnotatedColumnElement(Annotated): return self._Annotated__element.info @util.memoized_property - def _anon_name_label(self): + def _anon_name_label(self) -> str: return self._Annotated__element._anon_name_label @@ -4353,8 +4651,12 @@ class _anonymous_label(_truncated_label): @classmethod def safe_construct( - cls, seed, body, enclosing_label=None, sanitize_key=False - ) -> "_anonymous_label": + cls, + seed: int, + body: str, + enclosing_label: Optional[str] = None, + sanitize_key: bool = False, + ) -> _anonymous_label: if sanitize_key: body = re.sub(r"[%\(\) \$]+", "_", body).strip("_") |