diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-17 16:18:55 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-19 23:15:15 -0400 |
commit | 6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f (patch) | |
tree | ae142d45de71d1ebd43df1a38e54e1d3cf1063ec /lib/sqlalchemy/sql/elements.py | |
parent | c2fe4a264003933ff895c51f5d07a8456ac86382 (diff) | |
download | sqlalchemy-6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f.tar.gz |
pep 484 for types
strict types type_api.py, including TypeDecorator,
NativeForEmulated, etc.
Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 119 |
1 files changed, 82 insertions, 37 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 696d3c6f2..48c3c3be6 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -258,6 +258,8 @@ class CompilerElement(Visitable): """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + if TYPE_CHECKING: + assert isinstance(self, ClauseElement) return dialect.statement_compiler(dialect, self, **kw) def __str__(self) -> str: @@ -663,6 +665,11 @@ class DQLDMLClauseElement(ClauseElement): if typing.TYPE_CHECKING: + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + ... + def compile( # noqa: A001 self, bind: Optional[Union[Engine, Connection]] = None, @@ -671,9 +678,6 @@ class DQLDMLClauseElement(ClauseElement): ) -> SQLCompiler: ... - def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: - ... - class CompilerColumnElement( roles.DMLColumnRole, @@ -1274,14 +1278,20 @@ class ColumnElement( @overload def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + ... + + @overload + def self_group( self: ColumnElement[bool], against: Optional[OperatorType] = None ) -> ColumnElement[bool]: ... @overload def self_group( - self: ColumnElement[_T], against: Optional[OperatorType] = None - ) -> ColumnElement[_T]: + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: ... def self_group( @@ -1777,7 +1787,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): value = None if quote is not None: - key = quoted_name(key, quote) + key = quoted_name.construct(key, quote) if unique: self.key = _anonymous_label.safe_construct( @@ -3121,7 +3131,11 @@ class UnaryExpression(ColumnElement[_T]): self.element = element.self_group( against=self.operator or self.modifier ) - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.wraps_column_expression = wraps_column_expression @classmethod @@ -3224,27 +3238,32 @@ class CollectionAggregate(UnaryExpression[_T]): @classmethod def _create_any( cls, expr: _ColumnExpression[_T] - ) -> CollectionAggregate[_T]: - expr = coercions.expect(roles.ExpressionElementRole, expr) - - expr = expr.self_group() - return CollectionAggregate( + ) -> CollectionAggregate[bool]: + col_expr = coercions.expect( + roles.ExpressionElementRole, expr, + ) + col_expr = col_expr.self_group() + return CollectionAggregate( + col_expr, operator=operators.any_op, - type_=type_api.NULLTYPE, + type_=type_api.BOOLEANTYPE, wraps_column_expression=False, ) @classmethod def _create_all( cls, expr: _ColumnExpression[_T] - ) -> CollectionAggregate[_T]: - expr = coercions.expect(roles.ExpressionElementRole, expr) - expr = expr.self_group() - return CollectionAggregate( + ) -> CollectionAggregate[bool]: + col_expr = coercions.expect( + roles.ExpressionElementRole, expr, + ) + col_expr = col_expr.self_group() + return CollectionAggregate( + col_expr, operator=operators.all_op, - type_=type_api.NULLTYPE, + type_=type_api.BOOLEANTYPE, wraps_column_expression=False, ) @@ -3347,7 +3366,11 @@ class BinaryExpression(ColumnElement[_T]): self.left = left.self_group(against=operator) self.right = right.self_group(against=operator) self.operator = operator - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.negate = negate self._is_implicitly_boolean = operators.is_boolean(operator) @@ -3509,7 +3532,9 @@ class Grouping(GroupedElement, ColumnElement[_T]): self, element: Union[TextClause, ClauseList, ColumnElement[_T]] ): self.element = element - self.type = getattr(element, "type", type_api.NULLTYPE) + + # nulltype assignment issue + self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore def _with_binary_element_type(self, type_): return self.__class__(self.element._with_binary_element_type(type_)) @@ -3926,10 +3951,13 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): self.key = self._tq_label = self._tq_key_label = self.name self._element = element - # self._type = type_ - self.type = type_api.to_instance( - type_ or getattr(self._element, "type", None) + + self.type = ( + type_api.to_instance(type_) + if type_ is not None + else self._element.type ) + self._proxies = [element] def __reduce__(self): @@ -4178,7 +4206,11 @@ class ColumnClause( ): self.key = self.name = text self.table = _selectable - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.is_literal = is_literal def get_children(self, column_tables=False, **kw): @@ -4465,19 +4497,32 @@ class quoted_name(util.MemoizedSlots, str): quote: Optional[bool] - def __new__(cls, value, quote): + @overload + @classmethod + def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: + ... + + @overload + @classmethod + def construct(cls, value: None, quote: Optional[bool]) -> None: + ... + + @classmethod + def construct( + cls, value: Optional[str], quote: Optional[bool] + ) -> Optional[quoted_name]: if value is None: return None - # experimental - don't bother with quoted_name - # if quote flag is None. doesn't seem to make any dent - # in performance however - # elif not sprcls and quote is None: - # return value - elif isinstance(value, cls) and ( - quote is None or value.quote == quote - ): + else: + return quoted_name(value, quote) + + def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name: + assert ( + value is not None + ), "use quoted_name.construct() for None passthrough" + if isinstance(value, cls) and (quote is None or value.quote == quote): return value - self = super(quoted_name, cls).__new__(cls, value) + self = super().__new__(cls, value) self.quote = quote return self @@ -4579,15 +4624,15 @@ class _truncated_label(quoted_name): __slots__ = () - def __new__(cls, value, quote=None): + def __new__(cls, value: str, quote: Optional[bool] = None) -> Any: quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) return super(_truncated_label, cls).__new__(cls, value, quote) - def __reduce__(self): + def __reduce__(self) -> Any: return self.__class__, (str(self), self.quote) - def apply_map(self, map_): + def apply_map(self, map_: Mapping[str, Any]) -> str: return self |