Diffstat (limited to 'lib/sqlalchemy/sql')
 lib/sqlalchemy/sql/_elements_constructors.py |   4
 lib/sqlalchemy/sql/base.py                   |   6
 lib/sqlalchemy/sql/compiler.py               | 111
 lib/sqlalchemy/sql/dml.py                    |   4
 lib/sqlalchemy/sql/elements.py               | 119
 lib/sqlalchemy/sql/schema.py                 |   6
 lib/sqlalchemy/sql/sqltypes.py               | 293
 lib/sqlalchemy/sql/type_api.py               | 678
 8 files changed, 838 insertions(+), 383 deletions(-)
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index 770fbe40c..aabd3871e 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -58,7 +58,7 @@ if typing.TYPE_CHECKING:
_T = TypeVar("_T")
-def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
+def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
"""Produce an ALL expression.
For dialects such as that of PostgreSQL, this operator applies
@@ -173,7 +173,7 @@ def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
return BooleanClauseList.and_(*clauses)
-def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
+def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
"""Produce an ANY expression.
For dialects such as that of PostgreSQL, this operator applies
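The bool return type above reflects how these constructs are used: the ANY / ALL aggregate only ever appears on one side of a comparison, so the overall expression is boolean no matter what the array's element type is. A minimal sketch (the table and column names here are invented):

from sqlalchemy import ARRAY, Integer, any_, column, select

scores = column("scores", ARRAY(Integer))

# renders roughly as: SELECT id WHERE 5 = ANY (scores)
# the comparison, not the aggregate itself, carries the result type,
# which is why CollectionAggregate[bool] is the right annotation
stmt = select(column("id", Integer)).where(5 == any_(scores))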
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 29f9028c8..6a6b389de 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -1132,10 +1132,12 @@ class SchemaEventTarget:
"""
- def _set_parent(self, parent, **kw):
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
"""Associate with this SchemaEvent's parent object."""
- def _set_parent_with_dispatch(self, parent, **kw):
+ def _set_parent_with_dispatch(
+ self, parent: SchemaEventTarget, **kw: Any
+ ) -> None:
self.dispatch.before_parent_attach(self, parent)
self._set_parent(parent, **kw)
self.dispatch.after_parent_attach(self, parent)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8f878b66c..f8019b9c6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -43,6 +43,7 @@ from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
@@ -51,6 +52,7 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
+from sqlalchemy.sql.ddl import DDLElement
from . import base
from . import coercions
from . import crud
@@ -61,7 +63,9 @@ from . import schema
from . import selectable
from . import sqltypes
from .base import _from_objects
+from .base import Executable
from .base import NO_ARG
+from .elements import ClauseElement
from .elements import quoted_name
from .schema import Column
from .sqltypes import TupleType
@@ -78,6 +82,10 @@ if typing.TYPE_CHECKING:
from .base import _AmbiguousTableNameMap
from .base import CompileState
from .cache_key import CacheKey
+ from .dml import Insert
+ from .dml import UpdateBase
+ from .dml import ValuesBase
+ from .elements import _truncated_label
from .elements import BindParameter
from .elements import ColumnClause
from .elements import Label
@@ -91,12 +99,13 @@ if typing.TYPE_CHECKING:
from .selectable import ReturnsRows
from .selectable import Select
from .selectable import SelectState
+ from .type_api import _BindProcessorType
from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _MutableCoreSingleExecuteParams
from ..engine.interfaces import _SchemaTranslateMapType
- from ..engine.result import _ProcessorType
+ from ..engine.interfaces import Dialect
_FromHintsType = Dict["FromClause", str]
@@ -378,7 +387,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
- processors: Mapping[str, _ProcessorType]
+ processors: Mapping[str, _BindProcessorType[Any]]
positiontup: Optional[Sequence[str]]
parameter_expansion: Mapping[str, List[str]]
@@ -531,11 +540,11 @@ class Compiled:
def __init__(
self,
- dialect,
- statement,
- schema_translate_map=None,
- render_schema_translate=False,
- compile_kwargs=util.immutabledict(),
+ dialect: Dialect,
+ statement: Optional[ClauseElement],
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ render_schema_translate: bool = False,
+ compile_kwargs: Mapping[str, Any] = util.immutabledict(),
):
"""Construct a new :class:`.Compiled` object.
@@ -571,6 +580,8 @@ class Compiled:
self.can_execute = statement.supports_execution
self._annotations = statement._annotations
if self.can_execute:
+ if TYPE_CHECKING:
+ assert isinstance(statement, Executable)
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
@@ -636,10 +647,10 @@ class TypeCompiler(util.EnsureKWArg):
ensure_kwarg = r"visit_\w+"
- def __init__(self, dialect):
+ def __init__(self, dialect: Dialect):
self.dialect = dialect
- def process(self, type_, **kw):
+ def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
if (
type_._variant_mapping
and self.dialect.name in type_._variant_mapping
@@ -647,7 +658,9 @@ class TypeCompiler(util.EnsureKWArg):
type_ = type_._variant_mapping[self.dialect.name]
return type_._compiler_dispatch(self, **kw)
- def visit_unsupported_compilation(self, element, err, **kw):
+ def visit_unsupported_compilation(
+ self, element: Any, err: Exception, **kw: Any
+ ) -> NoReturn:
raise exc.UnsupportedCompilationError(self, element) from err
@@ -877,13 +890,13 @@ class SQLCompiler(Compiled):
def __init__(
self,
- dialect,
- statement,
- cache_key=None,
- column_keys=None,
- for_executemany=False,
- linting=NO_LINTING,
- **kwargs,
+ dialect: Dialect,
+ statement: Optional[ClauseElement],
+ cache_key: Optional[CacheKey] = None,
+ column_keys: Optional[Sequence[str]] = None,
+ for_executemany: bool = False,
+ linting: Linting = NO_LINTING,
+ **kwargs: Any,
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -954,15 +967,21 @@ class SQLCompiler(Compiled):
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
- self.truncated_names = {}
+ self.truncated_names: Dict[Tuple[str, str], str] = {}
+ self._truncated_counters: Dict[str, int] = {}
Compiled.__init__(self, dialect, statement, **kwargs)
if self.isinsert or self.isupdate or self.isdelete:
+ if TYPE_CHECKING:
+ assert isinstance(statement, UpdateBase)
+
if statement._returning:
self.returning = statement._returning
if self.isinsert or self.isupdate:
+ if TYPE_CHECKING:
+ assert isinstance(statement, ValuesBase)
if statement._inline:
self.inline = True
elif self.for_executemany and (
@@ -1082,9 +1101,14 @@ class SQLCompiler(Compiled):
@util.memoized_property
def _bind_processors(
self,
- ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]:
+ ) -> MutableMapping[
+ str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
+ ]:
+
+ # mypy is not able to see the two value types as the above Union,
+ # it just sees "object". don't know how to resolve
return dict(
- (key, value)
+ (key, value) # type: ignore
for key, value in (
(
self.bind_names[bindparam],
@@ -1301,12 +1325,14 @@ class SQLCompiler(Compiled):
positiontup = None
processors = self._bind_processors
- single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ single_processors = cast(
+ "Mapping[str, _BindProcessorType[Any]]", processors
+ )
tuple_processors = cast(
- "Mapping[str, Sequence[_ProcessorType]]", processors
+ "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
)
- new_processors: Dict[str, _ProcessorType] = {}
+ new_processors: Dict[str, _BindProcessorType[Any]] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1484,6 +1510,10 @@ class SQLCompiler(Compiled):
result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
+
+ if TYPE_CHECKING:
+ assert isinstance(self.statement, Insert)
+
table = self.statement.table
getters = [
@@ -1530,6 +1560,9 @@ class SQLCompiler(Compiled):
else:
result = util.preloaded.engine_result
+ if TYPE_CHECKING:
+ assert isinstance(self.statement, Insert)
+
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
@@ -1796,7 +1829,9 @@ class SQLCompiler(Compiled):
def visit_typeclause(self, typeclause, **kw):
kw["type_expression"] = typeclause
kw["identifier_preparer"] = self.preparer
- return self.dialect.type_compiler.process(typeclause.type, **kw)
+ return self.dialect.type_compiler_instance.process(
+ typeclause.type, **kw
+ )
def post_process_text(self, text):
if self.preparer._double_percents:
@@ -2855,26 +2890,28 @@ class SQLCompiler(Compiled):
return bind_name
- def _truncated_identifier(self, ident_class, name):
+ def _truncated_identifier(
+ self, ident_class: str, name: _truncated_label
+ ) -> str:
if (ident_class, name) in self.truncated_names:
return self.truncated_names[(ident_class, name)]
anonname = name.apply_map(self.anon_map)
if len(anonname) > self.label_length - 6:
- counter = self.truncated_names.get(ident_class, 1)
+ counter = self._truncated_counters.get(ident_class, 1)
truncname = (
anonname[0 : max(self.label_length - 6, 0)]
+ "_"
+ hex(counter)[2:]
)
- self.truncated_names[ident_class] = counter + 1
+ self._truncated_counters[ident_class] = counter + 1
else:
truncname = anonname
self.truncated_names[(ident_class, name)] = truncname
return truncname
- def _anonymize(self, name):
+ def _anonymize(self, name: str) -> str:
return name % self.anon_map
def bindparam_string(
@@ -3221,7 +3258,7 @@ class SQLCompiler(Compiled):
% (
self.preparer.quote(col.name),
" %s"
- % self.dialect.type_compiler.process(
+ % self.dialect.type_compiler_instance.process(
col.type, **kwargs
)
if alias._render_derived_w_types
@@ -4685,6 +4722,18 @@ class StrSQLCompiler(SQLCompiler):
class DDLCompiler(Compiled):
+ if TYPE_CHECKING:
+
+ def __init__(
+ self,
+ dialect: Dialect,
+ statement: DDLElement,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = ...,
+ render_schema_translate: bool = ...,
+ compile_kwargs: Mapping[str, Any] = ...,
+ ):
+ ...
+
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(
@@ -4693,7 +4742,7 @@ class DDLCompiler(Compiled):
@util.memoized_property
def type_compiler(self):
- return self.dialect.type_compiler
+ return self.dialect.type_compiler_instance
def construct_params(
self,
@@ -5010,7 +5059,7 @@ class DDLCompiler(Compiled):
colspec = (
self.preparer.format_column(column)
+ " "
- + self.dialect.type_compiler.process(
+ + self.dialect.type_compiler_instance.process(
column.type, type_expression=column
)
)
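The _truncated_identifier() change above splits the per-class counters out of truncated_names, which previously mixed (ident_class, name) keys and bare ident_class counter keys in the same dict. The shortening rule itself is unchanged; a standalone sketch of that rule (not the actual compiler code):

def truncate_identifier(name: str, label_length: int, counter: int) -> str:
    # mirrors SQLCompiler._truncated_identifier: leave room for an
    # "_" plus hex counter suffix when the name exceeds the limit
    if len(name) > label_length - 6:
        return name[: max(label_length - 6, 0)] + "_" + hex(counter)[2:]
    return name

# a long anonymous label against a 30-character identifier limit keeps
# the first 24 characters and appends "_1"
print(truncate_identifier("some_very_long_autogenerated_label_name", 30, 1))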
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 96e90b0ea..1271c5977 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -433,7 +433,7 @@ class ValuesBase(UpdateBase):
_multi_values = ()
_ordered_values = None
_select_names = None
-
+ _inline: bool = False
_returning = ()
def __init__(self, table):
@@ -742,7 +742,6 @@ class Insert(ValuesBase):
select = None
include_insert_from_select_defaults = False
- _inline = False
is_insert = True
@@ -959,7 +958,6 @@ class Update(DMLWhereBase, ValuesBase):
is_update = True
_preserve_parameter_order = False
- _inline = False
_traverse_internals = (
[
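Hoisting the _inline = False default onto ValuesBase above gives Insert and Update a single shared definition of the flag that their .inline() methods toggle. For illustration (the table here is made up):

from sqlalchemy import column, insert, table

t = table("t", column("id"), column("data"))

# .inline() flips the _inline flag, asking the compiler to render
# SQL defaults inline rather than pre-executing or post-fetching them
stmt = insert(t).values(data="x").inline()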
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 696d3c6f2..48c3c3be6 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -258,6 +258,8 @@ class CompilerElement(Visitable):
"""Return a compiler appropriate for this ClauseElement, given a
Dialect."""
+ if TYPE_CHECKING:
+ assert isinstance(self, ClauseElement)
return dialect.statement_compiler(dialect, self, **kw)
def __str__(self) -> str:
@@ -663,6 +665,11 @@ class DQLDMLClauseElement(ClauseElement):
if typing.TYPE_CHECKING:
+ def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler:
+ """Return a compiler appropriate for this ClauseElement, given a
+ Dialect."""
+ ...
+
def compile( # noqa: A001
self,
bind: Optional[Union[Engine, Connection]] = None,
@@ -671,9 +678,6 @@ class DQLDMLClauseElement(ClauseElement):
) -> SQLCompiler:
...
- def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler:
- ...
-
class CompilerColumnElement(
roles.DMLColumnRole,
@@ -1274,14 +1278,20 @@ class ColumnElement(
@overload
def self_group(
+ self: ColumnElement[_T], against: Optional[OperatorType] = None
+ ) -> ColumnElement[_T]:
+ ...
+
+ @overload
+ def self_group(
self: ColumnElement[bool], against: Optional[OperatorType] = None
) -> ColumnElement[bool]:
...
@overload
def self_group(
- self: ColumnElement[_T], against: Optional[OperatorType] = None
- ) -> ColumnElement[_T]:
+ self: ColumnElement[Any], against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
...
def self_group(
@@ -1777,7 +1787,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
value = None
if quote is not None:
- key = quoted_name(key, quote)
+ key = quoted_name.construct(key, quote)
if unique:
self.key = _anonymous_label.safe_construct(
@@ -3121,7 +3131,11 @@ class UnaryExpression(ColumnElement[_T]):
self.element = element.self_group(
against=self.operator or self.modifier
)
- self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
+
self.wraps_column_expression = wraps_column_expression
@classmethod
@@ -3224,27 +3238,32 @@ class CollectionAggregate(UnaryExpression[_T]):
@classmethod
def _create_any(
cls, expr: _ColumnExpression[_T]
- ) -> CollectionAggregate[_T]:
- expr = coercions.expect(roles.ExpressionElementRole, expr)
-
- expr = expr.self_group()
- return CollectionAggregate(
+ ) -> CollectionAggregate[bool]:
+ col_expr = coercions.expect(
+ roles.ExpressionElementRole,
expr,
+ )
+ col_expr = col_expr.self_group()
+ return CollectionAggregate(
+ col_expr,
operator=operators.any_op,
- type_=type_api.NULLTYPE,
+ type_=type_api.BOOLEANTYPE,
wraps_column_expression=False,
)
@classmethod
def _create_all(
cls, expr: _ColumnExpression[_T]
- ) -> CollectionAggregate[_T]:
- expr = coercions.expect(roles.ExpressionElementRole, expr)
- expr = expr.self_group()
- return CollectionAggregate(
+ ) -> CollectionAggregate[bool]:
+ col_expr = coercions.expect(
+ roles.ExpressionElementRole,
expr,
+ )
+ col_expr = col_expr.self_group()
+ return CollectionAggregate(
+ col_expr,
operator=operators.all_op,
- type_=type_api.NULLTYPE,
+ type_=type_api.BOOLEANTYPE,
wraps_column_expression=False,
)
@@ -3347,7 +3366,11 @@ class BinaryExpression(ColumnElement[_T]):
self.left = left.self_group(against=operator)
self.right = right.self_group(against=operator)
self.operator = operator
- self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
+
self.negate = negate
self._is_implicitly_boolean = operators.is_boolean(operator)
@@ -3509,7 +3532,9 @@ class Grouping(GroupedElement, ColumnElement[_T]):
self, element: Union[TextClause, ClauseList, ColumnElement[_T]]
):
self.element = element
- self.type = getattr(element, "type", type_api.NULLTYPE)
+
+ # nulltype assignment issue
+ self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore
def _with_binary_element_type(self, type_):
return self.__class__(self.element._with_binary_element_type(type_))
@@ -3926,10 +3951,13 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
self.key = self._tq_label = self._tq_key_label = self.name
self._element = element
- # self._type = type_
- self.type = type_api.to_instance(
- type_ or getattr(self._element, "type", None)
+
+ self.type = (
+ type_api.to_instance(type_)
+ if type_ is not None
+ else self._element.type
)
+
self._proxies = [element]
def __reduce__(self):
@@ -4178,7 +4206,11 @@ class ColumnClause(
):
self.key = self.name = text
self.table = _selectable
- self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
+
self.is_literal = is_literal
def get_children(self, column_tables=False, **kw):
@@ -4465,19 +4497,32 @@ class quoted_name(util.MemoizedSlots, str):
quote: Optional[bool]
- def __new__(cls, value, quote):
+ @overload
+ @classmethod
+ def construct(cls, value: str, quote: Optional[bool]) -> quoted_name:
+ ...
+
+ @overload
+ @classmethod
+ def construct(cls, value: None, quote: Optional[bool]) -> None:
+ ...
+
+ @classmethod
+ def construct(
+ cls, value: Optional[str], quote: Optional[bool]
+ ) -> Optional[quoted_name]:
if value is None:
return None
- # experimental - don't bother with quoted_name
- # if quote flag is None. doesn't seem to make any dent
- # in performance however
- # elif not sprcls and quote is None:
- # return value
- elif isinstance(value, cls) and (
- quote is None or value.quote == quote
- ):
+ else:
+ return quoted_name(value, quote)
+
+ def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name:
+ assert (
+ value is not None
+ ), "use quoted_name.construct() for None passthrough"
+ if isinstance(value, cls) and (quote is None or value.quote == quote):
return value
- self = super(quoted_name, cls).__new__(cls, value)
+ self = super().__new__(cls, value)
self.quote = quote
return self
@@ -4579,15 +4624,15 @@ class _truncated_label(quoted_name):
__slots__ = ()
- def __new__(cls, value, quote=None):
+ def __new__(cls, value: str, quote: Optional[bool] = None) -> Any:
quote = getattr(value, "quote", quote)
# return super(_truncated_label, cls).__new__(cls, value, quote, True)
return super(_truncated_label, cls).__new__(cls, value, quote)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (str(self), self.quote)
- def apply_map(self, map_):
+ def apply_map(self, map_: Mapping[str, Any]) -> str:
return self
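The new quoted_name.construct() classmethod above exists so the None passthrough can be expressed with typing overloads while __new__ keeps a non-optional signature. A short illustration of the behavior the overloads describe:

from sqlalchemy.sql.elements import quoted_name

# a plain string is wrapped, preserving the explicit quote flag
name = quoted_name.construct("MixedCase", True)
assert isinstance(name, quoted_name) and name.quote is True

# None passes straight through, which is why the second overload
# returns None for a None value
assert quoted_name.construct(None, True) is None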
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 78d524127..5cfb55603 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -3077,7 +3077,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
elif metadata is not None and schema is None and metadata.schema:
self.schema = schema = metadata.schema
else:
- self.schema = quoted_name(schema, quote_schema)
+ self.schema = quoted_name.construct(schema, quote_schema)
self.metadata = metadata
self._key = _get_table_key(name, schema)
if metadata:
@@ -4258,7 +4258,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
"""
self.table = table = None
- self.name = quoted_name(name, kw.pop("quote", None))
+ self.name = quoted_name.construct(name, kw.pop("quote", None))
self.unique = kw.pop("unique", False)
_column_flag = kw.pop("_column_flag", False)
if "info" in kw:
@@ -4493,7 +4493,7 @@ class MetaData(HasSchemaAttr):
"""
self.tables = util.FacadeDict()
- self.schema = quoted_name(schema, quote_schema)
+ self.schema = quoted_name.construct(schema, quote_schema)
self.naming_convention = (
naming_convention
if naming_convention
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 4d0169370..829c1b72e 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -17,9 +17,15 @@ import enum
import json
import pickle
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -40,6 +46,7 @@ from .type_api import NativeForEmulated # noqa
from .type_api import to_instance
from .type_api import TypeDecorator
from .type_api import TypeEngine
+from .type_api import TypeEngineMixin
from .type_api import Variant # noqa
from .visitors import InternalTraversal
from .. import event
@@ -51,11 +58,19 @@ from ..util import langhelpers
from ..util import OrderedDict
from ..util.typing import Literal
+if TYPE_CHECKING:
+ from .operators import OperatorType
+ from .type_api import _BindProcessorType
+ from .type_api import _ComparatorFactory
+ from .type_api import _ResultProcessorType
+ from ..engine.interfaces import Dialect
_T = TypeVar("_T", bound="Any")
+_CT = TypeVar("_CT", bound=Any)
+_TE = TypeVar("_TE", bound="TypeEngine[Any]")
-class _LookupExpressionAdapter:
+class HasExpressionLookup(TypeEngineMixin):
"""Mixin expression adaptations based on lookup tables.
@@ -68,11 +83,18 @@ class _LookupExpressionAdapter:
def _expression_adaptations(self):
raise NotImplementedError()
- class Comparator(TypeEngine.Comparator[_T]):
- _blank_dict = util.immutabledict()
+ class Comparator(TypeEngine.Comparator[_CT]):
+
+ _blank_dict = util.EMPTY_DICT
- def _adapt_expression(self, op, other_comparator):
+ def _adapt_expression(
+ self,
+ op: OperatorType,
+ other_comparator: TypeEngine.Comparator[Any],
+ ) -> Tuple[OperatorType, TypeEngine[Any]]:
othertype = other_comparator.type._type_affinity
+ if TYPE_CHECKING:
+ assert isinstance(self.type, HasExpressionLookup)
lookup = self.type._expression_adaptations.get(
op, self._blank_dict
).get(othertype, self.type)
@@ -83,16 +105,20 @@ class _LookupExpressionAdapter:
else:
return (op, to_instance(lookup))
- comparator_factory = Comparator
+ comparator_factory: _ComparatorFactory[Any] = Comparator
-class Concatenable:
+class Concatenable(TypeEngineMixin):
"""A mixin that marks a type as supporting 'concatenation',
typically strings."""
class Comparator(TypeEngine.Comparator[_T]):
- def _adapt_expression(self, op, other_comparator):
+ def _adapt_expression(
+ self,
+ op: OperatorType,
+ other_comparator: TypeEngine.Comparator[Any],
+ ) -> Tuple[OperatorType, TypeEngine[Any]]:
if op is operators.add and isinstance(
other_comparator,
(Concatenable.Comparator, NullType.Comparator),
@@ -103,10 +129,10 @@ class Concatenable:
op, other_comparator
)
- comparator_factory = Comparator
+ comparator_factory: _ComparatorFactory[Any] = Comparator
-class Indexable:
+class Indexable(TypeEngineMixin):
"""A mixin that marks a type as supporting indexing operations,
such as array or JSON structures.
@@ -151,7 +177,7 @@ class String(Concatenable, TypeEngine[str]):
# note pylance appears to require the "self" type in a constructor
# for the _T type to be correctly recognized when we send the
# class as the argument, e.g. `column("somecol", String)`
- self: "String",
+ self,
length=None,
collation=None,
):
@@ -313,7 +339,7 @@ class UnicodeText(Text):
super(UnicodeText, self).__init__(length=length, **kwargs)
-class Integer(_LookupExpressionAdapter, TypeEngine[int]):
+class Integer(HasExpressionLookup, TypeEngine[int]):
"""A type for ``int`` integers."""
@@ -378,7 +404,7 @@ class BigInteger(Integer):
_N = TypeVar("_N", bound=Union[decimal.Decimal, float])
-class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
+class Numeric(HasExpressionLookup, TypeEngine[_N]):
"""A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
@@ -423,7 +449,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
_default_decimal_return_scale = 10
def __init__(
- self: "Numeric",
+ self,
precision=None,
scale=None,
decimal_return_scale=None,
@@ -573,34 +599,26 @@ class Float(Numeric[_N]):
@overload
def __init__(
self: Float[float],
- precision=...,
- decimal_return_scale=...,
+ precision: Optional[int] = ...,
+ asdecimal: Literal[False] = ...,
+ decimal_return_scale: Optional[int] = ...,
):
...
@overload
def __init__(
self: Float[decimal.Decimal],
- precision=...,
+ precision: Optional[int] = ...,
asdecimal: Literal[True] = ...,
- decimal_return_scale=...,
- ):
- ...
-
- @overload
- def __init__(
- self: Float[float],
- precision=...,
- asdecimal: Literal[False] = ...,
- decimal_return_scale=...,
+ decimal_return_scale: Optional[int] = ...,
):
...
def __init__(
self: Float[_N],
- precision=None,
- asdecimal=False,
- decimal_return_scale=None,
+ precision: Optional[int] = None,
+ asdecimal: bool = False,
+ decimal_return_scale: Optional[int] = None,
):
r"""
Construct a Float.
@@ -662,7 +680,7 @@ class Float(Numeric[_N]):
return None
-class Double(Float):
+class Double(Float[_N]):
"""A type for double ``FLOAT`` floating point types.
Typically generates a ``DOUBLE`` or ``DOUBLE_PRECISION`` in DDL,
@@ -676,7 +694,7 @@ class Double(Float):
__visit_name__ = "double"
-class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]):
+class DateTime(HasExpressionLookup, TypeEngine[dt.datetime]):
"""A type for ``datetime.datetime()`` objects.
@@ -738,7 +756,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]):
}
-class Date(_LookupExpressionAdapter, TypeEngine[dt.date]):
+class Date(HasExpressionLookup, TypeEngine[dt.date]):
"""A type for ``datetime.date()`` objects."""
@@ -776,7 +794,7 @@ class Date(_LookupExpressionAdapter, TypeEngine[dt.date]):
}
-class Time(_LookupExpressionAdapter, TypeEngine[dt.time]):
+class Time(HasExpressionLookup, TypeEngine[dt.time]):
"""A type for ``datetime.time()`` objects."""
@@ -895,9 +913,10 @@ class LargeBinary(_Binary):
_Binary.__init__(self, length=length)
-class SchemaType(SchemaEventTarget):
+class SchemaType(SchemaEventTarget, TypeEngineMixin):
- """Mark a type as possibly requiring schema-level DDL for usage.
+ """Add capabilities to a type which allow for schema-level DDL to be
+ associated with a type.
Supports types that must be explicitly created/dropped (i.e. PG ENUM type)
as well as types that are complemented by table or schema level
@@ -920,6 +939,8 @@ class SchemaType(SchemaEventTarget):
_use_schema_map = True
+ name: Optional[str]
+
def __init__(
self,
name=None,
@@ -1021,33 +1042,37 @@ class SchemaType(SchemaEventTarget):
)
def copy(self, **kw):
- return self.adapt(self.__class__, _create_events=True)
-
- def adapt(self, impltype, **kw):
- schema = kw.pop("schema", self.schema)
- metadata = kw.pop("metadata", self.metadata)
- _create_events = kw.pop("_create_events", False)
- return impltype(
- name=self.name,
- schema=schema,
- inherit_schema=self.inherit_schema,
- metadata=metadata,
- _create_events=_create_events,
- **kw,
+ return self.adapt(
+ cast("Type[TypeEngine[Any]]", self.__class__),
+ _create_events=True,
)
+ @overload
+ def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+ ...
+
+ @overload
+ def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+ ...
+
+ def adapt(
+ self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+ ) -> TypeEngine[Any]:
+ kw.setdefault("_create_events", False)
+ return super().adapt(cls, **kw)
+
def create(self, bind, checkfirst=False):
"""Issue CREATE DDL for this type, if applicable."""
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t.create(bind, checkfirst=checkfirst)
def drop(self, bind, checkfirst=False):
"""Issue DROP DDL for this type, if applicable."""
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t.drop(bind, checkfirst=checkfirst)
def _on_table_create(self, target, bind, **kw):
@@ -1055,7 +1080,7 @@ class SchemaType(SchemaEventTarget):
return
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t._on_table_create(target, bind, **kw)
def _on_table_drop(self, target, bind, **kw):
@@ -1063,7 +1088,7 @@ class SchemaType(SchemaEventTarget):
return
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t._on_table_drop(target, bind, **kw)
def _on_metadata_create(self, target, bind, **kw):
@@ -1071,7 +1096,7 @@ class SchemaType(SchemaEventTarget):
return
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t._on_metadata_create(target, bind, **kw)
def _on_metadata_drop(self, target, bind, **kw):
@@ -1079,7 +1104,7 @@ class SchemaType(SchemaEventTarget):
return
t = self.dialect_impl(bind.dialect)
- if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
t._on_metadata_drop(target, bind, **kw)
def _is_impl_for_variant(self, dialect, kw):
@@ -1112,7 +1137,7 @@ class SchemaType(SchemaEventTarget):
return _we_are_the_impl(variant_mapping["_default"])
-class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType):
+class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
"""Generic Enum Type.
The :class:`.Enum` type provides a set of possible string values
@@ -1464,8 +1489,14 @@ class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType):
)
) from err
- class Comparator(String.Comparator[_T]):
- def _adapt_expression(self, op, other_comparator):
+ class Comparator(String.Comparator[str]):
+ type: String
+
+ def _adapt_expression(
+ self,
+ op: OperatorType,
+ other_comparator: TypeEngine.Comparator[Any],
+ ) -> Tuple[OperatorType, TypeEngine[Any]]:
op, typ = super(Enum.Comparator, self)._adapt_expression(
op, other_comparator
)
@@ -1663,15 +1694,16 @@ class PickleType(TypeDecorator[object]):
return PickleType, (self.protocol, None, self.comparator)
def bind_processor(self, dialect):
- impl_processor = self.impl.bind_processor(dialect)
+ impl_processor = self.impl_instance.bind_processor(dialect)
dumps = self.pickler.dumps
protocol = self.protocol
if impl_processor:
+ fixed_impl_processor = impl_processor
def process(value):
if value is not None:
value = dumps(value, protocol)
- return impl_processor(value)
+ return fixed_impl_processor(value)
else:
@@ -1683,12 +1715,13 @@ class PickleType(TypeDecorator[object]):
return process
def result_processor(self, dialect, coltype):
- impl_processor = self.impl.result_processor(dialect, coltype)
+ impl_processor = self.impl_instance.result_processor(dialect, coltype)
loads = self.pickler.loads
if impl_processor:
+ fixed_impl_processor = impl_processor
def process(value):
- value = impl_processor(value)
+ value = fixed_impl_processor(value)
if value is None:
return None
return loads(value)
@@ -1709,7 +1742,7 @@ class PickleType(TypeDecorator[object]):
return x == y
-class Boolean(Emulated, TypeEngine[bool], SchemaType):
+class Boolean(SchemaType, Emulated, TypeEngine[bool]):
"""A bool datatype.
@@ -1733,7 +1766,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
native = True
def __init__(
- self: "Boolean",
+ self,
create_constraint=False,
name=None,
_create_events=True,
@@ -1818,6 +1851,9 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
def bind_processor(self, dialect):
_strict_as_bool = self._strict_as_bool
+
+ _coerce: Union[Type[bool], Type[int]]
+
if dialect.supports_native_boolean:
_coerce = bool
else:
@@ -1838,7 +1874,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
return processors.int_to_boolean
-class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]):
+class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]):
@util.memoized_property
def _expression_adaptations(self):
# Based on https://www.postgresql.org/docs/current/\
@@ -1856,16 +1892,12 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]):
operators.truediv: {Numeric: self.__class__},
}
- @property
- def _type_affinity(self):
+ @util.non_memoized_property
+ def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
return Interval
- def coerce_compared_value(self, op, value):
- """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
- return self.impl.coerce_compared_value(op, value)
-
-class Interval(Emulated, _AbstractInterval, TypeDecorator):
+class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
"""A type for ``datetime.timedelta()`` objects.
@@ -1909,6 +1941,14 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
self.second_precision = second_precision
self.day_precision = day_precision
+ class Comparator(
+ TypeDecorator.Comparator[_CT],
+ _AbstractInterval.Comparator[_CT],
+ ):
+ pass
+
+ comparator_factory = Comparator
+
@property
def python_type(self):
return dt.timedelta
@@ -1916,42 +1956,63 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
def adapt_to_emulated(self, impltype, **kw):
return _AbstractInterval.adapt(self, impltype, **kw)
- def bind_processor(self, dialect):
- impl_processor = self.impl.bind_processor(dialect)
+ def coerce_compared_value(self, op, value):
+ return self.impl_instance.coerce_compared_value(op, value)
+
+ def bind_processor(
+ self, dialect: Dialect
+ ) -> _BindProcessorType[dt.timedelta]:
+ if TYPE_CHECKING:
+ assert isinstance(self.impl_instance, DateTime)
+ impl_processor = self.impl_instance.bind_processor(dialect)
epoch = self.epoch
if impl_processor:
+ fixed_impl_processor = impl_processor
- def process(value):
+ def process(
+ value: Optional[dt.timedelta],
+ ) -> Any:
if value is not None:
- value = epoch + value
- return impl_processor(value)
+ dt_value = epoch + value
+ else:
+ dt_value = None
+ return fixed_impl_processor(dt_value)
else:
- def process(value):
+ def process(
+ value: Optional[dt.timedelta],
+ ) -> Any:
if value is not None:
- value = epoch + value
- return value
+ dt_value = epoch + value
+ else:
+ dt_value = None
+ return dt_value
return process
- def result_processor(self, dialect, coltype):
- impl_processor = self.impl.result_processor(dialect, coltype)
+ def result_processor(
+ self, dialect: Dialect, coltype: Any
+ ) -> _ResultProcessorType[dt.timedelta]:
+ if TYPE_CHECKING:
+ assert isinstance(self.impl_instance, DateTime)
+ impl_processor = self.impl_instance.result_processor(dialect, coltype)
epoch = self.epoch
if impl_processor:
+ fixed_impl_processor = impl_processor
- def process(value):
- value = impl_processor(value)
- if value is None:
+ def process(value: Any) -> Optional[dt.timedelta]:
+ dt_value = fixed_impl_processor(value)
+ if dt_value is None:
return None
- return value - epoch
+ return dt_value - epoch
else:
- def process(value):
+ def process(value: Any) -> Optional[dt.timedelta]:
if value is None:
return None
- return value - epoch
+ return value - epoch # type: ignore
return process
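The Interval processors above emulate an interval column by shifting the timedelta onto a fixed epoch before handing it to the underlying DateTime implementation, then shifting back on the way out. The arithmetic in isolation, using the epoch value that Interval.epoch is defined with:

import datetime as dt

epoch = dt.datetime.utcfromtimestamp(0)  # 1970-01-01 00:00:00

value = dt.timedelta(days=2, hours=3)

# bind direction: timedelta -> datetime stored via the DateTime impl
stored = epoch + value

# result direction: datetime -> timedelta handed back to the caller
assert stored - epoch == value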
@@ -2233,7 +2294,7 @@ class JSON(Indexable, TypeEngine[Any]):
"""
self.none_as_null = none_as_null
- class JSONElementType(TypeEngine):
+ class JSONElementType(TypeEngine[Any]):
"""Common function for index / path elements in a JSON expression."""
_integer = Integer()
@@ -2457,7 +2518,7 @@ class JSON(Indexable, TypeEngine[Any]):
def python_type(self):
return dict
- @property
+ @property # type: ignore # mypy property bug
def should_evaluate_none(self):
"""Alias of :attr:`_types.JSON.none_as_null`"""
return not self.none_as_null
@@ -2632,23 +2693,26 @@ class ARRAY(
"""
def _setup_getitem(self, index):
+
+ arr_type = cast(ARRAY, self.type)
+
if isinstance(index, slice):
- return_type = self.type
- if self.type.zero_indexes:
+ return_type = arr_type
+ if arr_type.zero_indexes:
index = slice(index.start + 1, index.stop + 1, index.step)
slice_ = Slice(
index.start, index.stop, index.step, _name=self.expr.key
)
return operators.getitem, slice_, return_type
else:
- if self.type.zero_indexes:
+ if arr_type.zero_indexes:
index += 1
- if self.type.dimensions is None or self.type.dimensions == 1:
- return_type = self.type.item_type
+ if arr_type.dimensions is None or arr_type.dimensions == 1:
+ return_type = arr_type.item_type
else:
- adapt_kw = {"dimensions": self.type.dimensions - 1}
- return_type = self.type.adapt(
- self.type.__class__, **adapt_kw
+ adapt_kw = {"dimensions": arr_type.dimensions - 1}
+ return_type = arr_type.adapt(
+ arr_type.__class__, **adapt_kw
)
return operators.getitem, index, return_type
@@ -2853,7 +2917,7 @@ class TupleType(TypeEngine[Tuple[Any, ...]]):
)
-class REAL(Float):
+class REAL(Float[_N]):
"""The SQL REAL type.
@@ -2866,7 +2930,7 @@ class REAL(Float):
__visit_name__ = "REAL"
-class FLOAT(Float):
+class FLOAT(Float[_N]):
"""The SQL FLOAT type.
@@ -2879,7 +2943,7 @@ class FLOAT(Float):
__visit_name__ = "FLOAT"
-class DOUBLE(Double):
+class DOUBLE(Double[_N]):
"""The SQL DOUBLE type.
.. versionadded:: 2.0
@@ -2893,7 +2957,7 @@ class DOUBLE(Double):
__visit_name__ = "DOUBLE"
-class DOUBLE_PRECISION(Double):
+class DOUBLE_PRECISION(Double[_N]):
"""The SQL DOUBLE PRECISION type.
.. versionadded:: 2.0
@@ -2907,7 +2971,7 @@ class DOUBLE_PRECISION(Double):
__visit_name__ = "DOUBLE_PRECISION"
-class NUMERIC(Numeric):
+class NUMERIC(Numeric[_N]):
"""The SQL NUMERIC type.
@@ -2920,7 +2984,7 @@ class NUMERIC(Numeric):
__visit_name__ = "NUMERIC"
-class DECIMAL(Numeric):
+class DECIMAL(Numeric[_N]):
"""The SQL DECIMAL type.
@@ -3099,7 +3163,7 @@ class BOOLEAN(Boolean):
__visit_name__ = "BOOLEAN"
-class NullType(TypeEngine):
+class NullType(TypeEngine[None]):
"""An unknown type.
@@ -3139,7 +3203,11 @@ class NullType(TypeEngine):
return process
class Comparator(TypeEngine.Comparator[_T]):
- def _adapt_expression(self, op, other_comparator):
+ def _adapt_expression(
+ self,
+ op: OperatorType,
+ other_comparator: TypeEngine.Comparator[Any],
+ ) -> Tuple[OperatorType, TypeEngine[Any]]:
if isinstance(
other_comparator, NullType.Comparator
) or not operators.is_commutative(op):
@@ -3150,7 +3218,7 @@ class NullType(TypeEngine):
comparator_factory = Comparator
-class TableValueType(HasCacheKey, TypeEngine):
+class TableValueType(HasCacheKey, TypeEngine[Any]):
"""Refers to a table value type."""
_is_table_value = True
@@ -3195,7 +3263,7 @@ _TIME = Time()
_STRING = String()
_UNICODE = Unicode()
-_type_map = {
+_type_map: Dict[Type[Any], TypeEngine[Any]] = {
int: Integer(),
float: Float(),
bool: BOOLEANTYPE,
@@ -3204,7 +3272,7 @@ _type_map = {
dt.datetime: _DATETIME,
dt.time: _TIME,
dt.timedelta: Interval(),
- util.NoneType: NULLTYPE,
+ type(None): NULLTYPE,
bytes: LargeBinary(),
str: _STRING,
}
@@ -3213,7 +3281,7 @@ _type_map = {
_type_map_get = _type_map.get
-def _resolve_value_to_type(value):
+def _resolve_value_to_type(value: Any) -> TypeEngine[Any]:
_result_type = _type_map_get(type(value), False)
if _result_type is False:
# use inspect() to detect SQLAlchemy built-in
@@ -3231,7 +3299,9 @@ def _resolve_value_to_type(value):
)
return NULLTYPE
else:
- return _result_type._resolve_for_literal(value)
+ return _result_type._resolve_for_literal( # type: ignore [union-attr]
+ value
+ )
# back-assign to type_api
@@ -3240,7 +3310,6 @@ type_api.STRINGTYPE = STRINGTYPE
type_api.INTEGERTYPE = INTEGERTYPE
type_api.NULLTYPE = NULLTYPE
type_api.MATCHTYPE = MATCHTYPE
-type_api.INDEXABLE = Indexable
+type_api.INDEXABLE = INDEXABLE = Indexable
type_api.TABLEVALUE = TABLEVALUE
type_api._resolve_value_to_type = _resolve_value_to_type
-TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
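The typing added to _type_map above documents the table that _resolve_value_to_type() consults when a plain Python value enters the expression system; its practical effect is the inferred type of a literal bind parameter. For example:

import datetime as dt

from sqlalchemy import Integer, Interval, String, literal

# each literal's type is resolved through _type_map / _resolve_value_to_type()
assert isinstance(literal(5).type, Integer)
assert isinstance(literal("x").type, String)
assert isinstance(literal(dt.timedelta(days=1)).type, Interval)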
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 55997556a..5a0aba694 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -11,41 +11,92 @@
from __future__ import annotations
+from types import ModuleType
import typing
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Dict
from typing import Generic
+from typing import Mapping
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .base import SchemaEventTarget
+from .cache_key import CacheConst
from .cache_key import NO_CACHE
from .operators import ColumnOperators
from .visitors import Visitable
from .. import exc
from .. import util
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
+from ..util.typing import TypeGuard
# these are back-assigned by sqltypes.
if typing.TYPE_CHECKING:
+ from .elements import BindParameter
from .elements import ColumnElement
from .operators import OperatorType
+ from .schema import Column
from .sqltypes import _resolve_value_to_type as _resolve_value_to_type
from .sqltypes import BOOLEANTYPE as BOOLEANTYPE
- from .sqltypes import Indexable as INDEXABLE
+ from .sqltypes import INDEXABLE as INDEXABLE
from .sqltypes import INTEGERTYPE as INTEGERTYPE
from .sqltypes import MATCHTYPE as MATCHTYPE
from .sqltypes import NULLTYPE as NULLTYPE
+ from .sqltypes import NullType
+ from .sqltypes import STRINGTYPE as STRINGTYPE
+ from .sqltypes import TABLEVALUE as TABLEVALUE
+ from ..engine.interfaces import Dialect
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
+_O = TypeVar("_O", bound=object)
_TE = TypeVar("_TE", bound="TypeEngine[Any]")
_CT = TypeVar("_CT", bound=Any)
# replace with pep-673 when applicable
-SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine")
+SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]")
+
+
+class _LiteralProcessorType(Protocol[_T_co]):
+ def __call__(self, value: Any) -> str:
+ ...
+
+
+class _BindProcessorType(Protocol[_T_con]):
+ def __call__(self, value: Optional[_T_con]) -> Any:
+ ...
+
+
+class _ResultProcessorType(Protocol[_T_co]):
+ def __call__(self, value: Any) -> Optional[_T_co]:
+ ...
+
+
+class _BaseTypeMemoDict(TypedDict):
+ impl: TypeEngine[Any]
+ result: Dict[Any, Optional[_ResultProcessorType[Any]]]
+
+
+class _TypeMemoDict(_BaseTypeMemoDict, total=False):
+ literal: Optional[_LiteralProcessorType[Any]]
+ bind: Optional[_BindProcessorType[Any]]
+ custom: Dict[Any, object]
+
+
+class _ComparatorFactory(Protocol[_T]):
+ def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]:
+ ...
class TypeEngine(Visitable, Generic[_T]):
@@ -70,8 +121,6 @@ class TypeEngine(Visitable, Generic[_T]):
_is_array = False
_is_type_decorator = False
- _block_from_type_affinity = False
-
render_bind_cast = False
"""Render bind casts for :attr:`.BindTyping.RENDER_CASTS` mode.
@@ -99,38 +148,41 @@ class TypeEngine(Visitable, Generic[_T]):
__slots__ = "expr", "type"
- default_comparator = None
+ expr: ColumnElement[_CT]
+ type: TypeEngine[_CT]
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[_CT]:
return self.expr
- def __init__(self, expr: "ColumnElement[_CT]"):
+ def __init__(self, expr: ColumnElement[_CT]):
self.expr = expr
- self.type: TypeEngine[_CT] = expr.type
+ self.type = expr.type
@util.preload_module("sqlalchemy.sql.default_comparator")
def operate(
- self, op: "OperatorType", *other, **kwargs
- ) -> "ColumnElement":
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[_CT]:
default_comparator = util.preloaded.sql_default_comparator
op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
if kwargs:
addtl_kw = addtl_kw.union(kwargs)
- return op_fn(self.expr, op, *other, **addtl_kw)
+ return op_fn(self.expr, op, *other, **addtl_kw) # type: ignore
@util.preload_module("sqlalchemy.sql.default_comparator")
def reverse_operate(
- self, op: "OperatorType", other, **kwargs
- ) -> "ColumnElement":
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[_CT]:
default_comparator = util.preloaded.sql_default_comparator
op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
if kwargs:
addtl_kw = addtl_kw.union(kwargs)
- return op_fn(self.expr, op, other, reverse=True, **addtl_kw)
+ return op_fn(self.expr, op, other, reverse=True, **addtl_kw) # type: ignore # noqa E501
def _adapt_expression(
- self, op: "OperatorType", other_comparator
- ) -> Tuple["OperatorType", "TypeEngine[_CT]"]:
+ self,
+ op: OperatorType,
+ other_comparator: TypeEngine.Comparator[Any],
+ ) -> Tuple[OperatorType, TypeEngine[Any]]:
"""evaluate the return type of <self> <op> <othertype>,
and apply any adaptations to the given operator.
@@ -159,7 +211,7 @@ class TypeEngine(Visitable, Generic[_T]):
return op, self.type
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return _reconstitute_comparator, (self.expr,)
hashable = True
@@ -169,7 +221,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
- comparator_factory = Comparator
+ comparator_factory: _ComparatorFactory[Any] = Comparator
"""A :class:`.TypeEngine.Comparator` class which will apply
to operations performed by owning :class:`_expression.ColumnElement`
objects.
@@ -193,7 +245,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
- sort_key_function = None
+ sort_key_function: Optional[Callable[[Any], Any]] = None
"""A sorting function that can be passed as the key to sorted.
The default value of ``None`` indicates that the values stored by
@@ -203,7 +255,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
- should_evaluate_none = False
+ should_evaluate_none: bool = False
"""If True, the Python constant ``None`` is considered to be handled
explicitly by this type.
@@ -226,9 +278,11 @@ class TypeEngine(Visitable, Generic[_T]):
"""
- _variant_mapping = util.EMPTY_DICT
+ _variant_mapping: util.immutabledict[
+ str, TypeEngine[Any]
+ ] = util.EMPTY_DICT
- def evaluates_none(self):
+ def evaluates_none(self: SelfTypeEngine) -> SelfTypeEngine:
"""Return a copy of this type which has the :attr:`.should_evaluate_none`
flag set to True.
@@ -280,10 +334,12 @@ class TypeEngine(Visitable, Generic[_T]):
typ.should_evaluate_none = True
return typ
- def copy(self, **kw):
+ def copy(self: SelfTypeEngine, **kw: Any) -> SelfTypeEngine:
return self.adapt(self.__class__)
- def compare_against_backend(self, dialect, conn_type):
+ def compare_against_backend(
+ self, dialect: Dialect, conn_type: TypeEngine[Any]
+ ) -> Optional[bool]:
"""Compare this type against the given backend type.
This function is currently not implemented for SQLAlchemy
@@ -310,10 +366,12 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return None
- def copy_value(self, value):
+ def copy_value(self, value: Any) -> Any:
return value
- def literal_processor(self, dialect):
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> Optional[_LiteralProcessorType[_T]]:
"""Return a conversion function for processing literal values that are
to be rendered directly without using binds.
@@ -348,7 +406,9 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return None
- def bind_processor(self, dialect):
+ def bind_processor(
+ self, dialect: Dialect
+ ) -> Optional[_BindProcessorType[_T]]:
"""Return a conversion function for processing bind values.
Returns a callable which will receive a bind parameter value
@@ -382,7 +442,9 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return None
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: object
+ ) -> Optional[_ResultProcessorType[_T]]:
"""Return a conversion function for processing result row values.
Returns a callable which will receive a result row column
@@ -417,7 +479,9 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return None
- def column_expression(self, colexpr):
+ def column_expression(
+ self, colexpr: ColumnElement[_T]
+ ) -> Optional[ColumnElement[_T]]:
"""Given a SELECT column expression, return a wrapping SQL expression.
This is typically a SQL function that wraps a column expression
@@ -461,7 +525,7 @@ class TypeEngine(Visitable, Generic[_T]):
return None
@util.memoized_property
- def _has_column_expression(self):
+ def _has_column_expression(self) -> bool:
"""memoized boolean, check if column_expression is implemented.
Allows the method to be skipped for the vast majority of expression
@@ -474,7 +538,9 @@ class TypeEngine(Visitable, Generic[_T]):
is not TypeEngine.column_expression.__code__
)
- def bind_expression(self, bindvalue):
+ def bind_expression(
+ self, bindvalue: BindParameter[_T]
+ ) -> Optional[ColumnElement[_T]]:
"""Given a bind value (i.e. a :class:`.BindParameter` instance),
return a SQL expression in its place.
@@ -521,7 +587,7 @@ class TypeEngine(Visitable, Generic[_T]):
return None
@util.memoized_property
- def _has_bind_expression(self):
+ def _has_bind_expression(self) -> bool:
"""memoized boolean, check if bind_expression is implemented.
Allows the method to be skipped for the vast majority of expression
@@ -535,12 +601,12 @@ class TypeEngine(Visitable, Generic[_T]):
def _to_instance(cls_or_self: Union[Type[_TE], _TE]) -> _TE:
return to_instance(cls_or_self)
- def compare_values(self, x, y):
+ def compare_values(self, x: Any, y: Any) -> bool:
"""Compare two values for equality."""
- return x == y
+ return x == y # type: ignore[no-any-return]
- def get_dbapi_type(self, dbapi):
+ def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
"""Return the corresponding type object from the underlying DB-API, if
any.
@@ -550,7 +616,7 @@ class TypeEngine(Visitable, Generic[_T]):
return None
@property
- def python_type(self):
+ def python_type(self) -> Type[Any]:
"""Return the Python type object expected to be returned
by instances of this type, if known.
@@ -569,7 +635,7 @@ class TypeEngine(Visitable, Generic[_T]):
raise NotImplementedError()
def with_variant(
- self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str
+ self: SelfTypeEngine, type_: TypeEngine[Any], *dialect_names: str
) -> SelfTypeEngine:
r"""Produce a copy of this type object that will utilize the given
type when applied to the dialect of the given name.
@@ -626,7 +692,9 @@ class TypeEngine(Visitable, Generic[_T]):
)
return new_type
- def _resolve_for_literal(self, value):
+ def _resolve_for_literal(
+ self: SelfTypeEngine, value: Any
+ ) -> SelfTypeEngine:
"""adjust this type given a literal Python value that will be
stored in a bound parameter.
@@ -638,28 +706,28 @@ class TypeEngine(Visitable, Generic[_T]):
return self
@util.memoized_property
- def _type_affinity(self):
+ def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
"""Return a rudimental 'affinity' value expressing the general class
of type."""
typ = None
for t in self.__class__.__mro__:
- if t in (TypeEngine, UserDefinedType):
+ if t is TypeEngine or TypeEngineMixin in t.__bases__:
return typ
- elif issubclass(
- t, (TypeEngine, UserDefinedType)
- ) and not t.__dict__.get("_block_from_type_affinity", False):
+ elif issubclass(t, TypeEngine):
typ = t
else:
return self.__class__
@util.memoized_property
- def _generic_type_affinity(self):
+ def _generic_type_affinity(
+ self,
+ ) -> Type[TypeEngine[_T]]:
best_camelcase = None
best_uppercase = None
- if not isinstance(self, (TypeEngine, UserDefinedType)):
- return self.__class__
+ if not isinstance(self, TypeEngine):
+ return self.__class__ # type: ignore # mypy bug?
for t in self.__class__.__mro__:
if (
@@ -669,7 +737,8 @@ class TypeEngine(Visitable, Generic[_T]):
"sqlalchemy.sql.type_api",
)
and issubclass(t, TypeEngine)
- and t is not TypeEngine
+ and TypeEngineMixin not in t.__bases__
+ and t not in (TypeEngine, TypeEngineMixin)
and t.__name__[0] != "_"
):
if t.__name__.isupper() and not best_uppercase:
@@ -677,9 +746,13 @@ class TypeEngine(Visitable, Generic[_T]):
elif not t.__name__.isupper() and not best_camelcase:
best_camelcase = t
- return best_camelcase or best_uppercase or NULLTYPE.__class__
+ return (
+ best_camelcase
+ or best_uppercase
+ or cast("Type[TypeEngine[_T]]", NULLTYPE.__class__)
+ )
- def as_generic(self, allow_nulltype=False):
+ def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T]:
"""
Return an instance of the generic type corresponding to this type
using heuristic rule. The method may be overridden if this
@@ -719,18 +792,20 @@ class TypeEngine(Visitable, Generic[_T]):
return util.constructor_copy(self, self._generic_type_affinity)
- def dialect_impl(self, dialect):
+ def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
"""Return a dialect-specific implementation for this
:class:`.TypeEngine`.
"""
try:
- return dialect._type_memos[self]["impl"]
+ tm = dialect._type_memos[self]
except KeyError:
pass
+ else:
+ return tm["impl"]
return self._dialect_info(dialect)["impl"]
- def _unwrapped_dialect_impl(self, dialect):
+ def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
"""Return the 'unwrapped' dialect impl for this type.
For a type that applies wrapping logic (e.g. TypeDecorator), give
@@ -744,60 +819,80 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return self.dialect_impl(dialect)
- def _cached_literal_processor(self, dialect):
+ def _cached_literal_processor(
+ self, dialect: Dialect
+ ) -> Optional[_LiteralProcessorType[_T]]:
"""Return a dialect-specific literal processor for this type."""
+
try:
return dialect._type_memos[self]["literal"]
except KeyError:
pass
+
# avoid KeyError context coming into literal_processor() function
# raises
d = self._dialect_info(dialect)
d["literal"] = lp = d["impl"].literal_processor(dialect)
return lp
- def _cached_bind_processor(self, dialect):
+ def _cached_bind_processor(
+ self, dialect: Dialect
+ ) -> Optional[_BindProcessorType[_T]]:
"""Return a dialect-specific bind processor for this type."""
try:
return dialect._type_memos[self]["bind"]
except KeyError:
pass
+
# avoid KeyError context coming into bind_processor() function
# raises
d = self._dialect_info(dialect)
d["bind"] = bp = d["impl"].bind_processor(dialect)
return bp
- def _cached_result_processor(self, dialect, coltype):
+ def _cached_result_processor(
+ self, dialect: Dialect, coltype: Any
+ ) -> Optional[_ResultProcessorType[_T]]:
"""Return a dialect-specific result processor for this type."""
try:
- return dialect._type_memos[self][coltype]
+ return dialect._type_memos[self]["result"][coltype]
except KeyError:
pass
+
# avoid KeyError context coming into result_processor() function
# raises
d = self._dialect_info(dialect)
# key assumption: DBAPI type codes are
# constants. Else this dictionary would
# grow unbounded.
- d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+ rp = d["impl"].result_processor(dialect, coltype)
+ d["result"][coltype] = rp
return rp
- def _cached_custom_processor(self, dialect, key, fn):
+ def _cached_custom_processor(
+ self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O]
+ ) -> _O:
+ """return a dialect-specific processing object for
+ custom purposes.
+
+ The cx_Oracle dialect uses this at the moment.
+
+ """
try:
- return dialect._type_memos[self][key]
+ return cast(_O, dialect._type_memos[self]["custom"][key])
except KeyError:
pass
# avoid KeyError context coming into fn() function
# raises
d = self._dialect_info(dialect)
impl = d["impl"]
- d[key] = result = fn(impl)
+ custom_dict = d.setdefault("custom", {})
+ custom_dict[key] = result = fn(impl)
return result
- def _dialect_info(self, dialect):
+ def _dialect_info(self, dialect: Dialect) -> _TypeMemoDict:
"""Return a dialect-specific registry which
caches a dialect-specific implementation, bind processing
function, and one or more result processing functions."""
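The _TypeMemoDict TypedDict introduced near the top of this file is the shape of that registry: "impl" and "result" are always present, the bind, literal and custom entries are filled in lazily by the _cached_*_processor() helpers above, and result processors now live in their own sub-dict keyed by DBAPI type code rather than being mixed into the top level. A rough illustrative mirror of the layout (not the actual cache code):

from typing import Any, Dict, Optional, TypedDict

class TypeMemo(TypedDict, total=False):
    impl: Any                         # dialect-specific TypeEngine impl
    result: Dict[Any, Optional[Any]]  # DBAPI type code -> result processor
    bind: Optional[Any]               # cached bind processor
    literal: Optional[Any]            # cached literal processor
    custom: Dict[Any, object]         # dialect-private extras (cx_Oracle)

memo: TypeMemo = {"impl": object(), "result": {}}
memo["result"][42] = None  # "no result processor needed for this coltype"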
@@ -810,10 +905,11 @@ class TypeEngine(Visitable, Generic[_T]):
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
- dialect._type_memos[self] = d = {"impl": impl}
+ d: _TypeMemoDict = {"impl": impl, "result": {}}
+ dialect._type_memos[self] = d
return d
- def _gen_dialect_impl(self, dialect):
+ def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name in self._variant_mapping:
return self._variant_mapping[dialect.name]._gen_dialect_impl(
dialect
@@ -822,7 +918,9 @@ class TypeEngine(Visitable, Generic[_T]):
return dialect.type_descriptor(self)
@util.memoized_property
- def _static_cache_key(self):
+ def _static_cache_key(
+ self,
+ ) -> Union[CacheConst, Tuple[Any, ...]]:
names = util.get_cls_kwargs(self.__class__)
return (self.__class__,) + tuple(
(
@@ -835,7 +933,17 @@ class TypeEngine(Visitable, Generic[_T]):
if k in self.__dict__ and not k.startswith("_")
)
- def adapt(self, cls, **kw):
+ @overload
+ def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+ ...
+
+ @overload
+ def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+ ...
+
+ def adapt(
+ self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+ ) -> TypeEngine[Any]:
"""Produce an "adapted" form of this type, given an "impl" class
to work with.
@@ -843,9 +951,13 @@ class TypeEngine(Visitable, Generic[_T]):
types with "implementation" types that are specific to a particular
dialect.
"""
- return util.constructor_copy(self, cls, **kw)
+ return util.constructor_copy(
+ self, cast(Type[TypeEngine[Any]], cls), **kw
+ )
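A hedged usage sketch of adapt(): constructor arguments present on both classes are copied onto the target implementation class (the MySQL VARCHAR is chosen here only for illustration).

from sqlalchemy import String
from sqlalchemy.dialects.mysql import VARCHAR

v = String(30).adapt(VARCHAR)        # copies length=30 onto the new class
print(type(v).__name__, v.length)    # VARCHAR 30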
- def coerce_compared_value(self, op, value):
+ def coerce_compared_value(
+ self, op: Optional[OperatorType], value: Any
+ ) -> TypeEngine[Any]:
"""Suggest a type for a 'coerced' Python value in an expression.
Given an operator and value, gives the type a chance
@@ -873,10 +985,10 @@ class TypeEngine(Visitable, Generic[_T]):
else:
return _coerced_type
- def _compare_type_affinity(self, other):
+ def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
return self._type_affinity is other._type_affinity
- def compile(self, dialect=None):
+ def compile(self, dialect: Optional[Dialect] = None) -> str:
"""Produce a string-compiled form of this :class:`.TypeEngine`.
When called with no arguments, uses a "default" dialect
@@ -888,24 +1000,65 @@ class TypeEngine(Visitable, Generic[_T]):
# arg, return value is inconsistent with
# ClauseElement.compile()....this is a mistake.
- if not dialect:
+ if dialect is None:
dialect = self._default_dialect()
- return dialect.type_compiler.process(self)
+ return dialect.type_compiler_instance.process(self)
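Usage is unchanged by the rename to type_compiler_instance; a type can still be stringified against a specific dialect, for example:

from sqlalchemy import String
from sqlalchemy.dialects import postgresql

print(String(50).compile(dialect=postgresql.dialect()))  # VARCHAR(50)
print(String(50).compile())  # no dialect: uses the default StrCompileDialect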
@util.preload_module("sqlalchemy.engine.default")
- def _default_dialect(self):
- default = util.preloaded.engine_default
- return default.StrCompileDialect()
+ def _default_dialect(self) -> Dialect:
+ if TYPE_CHECKING:
+ from ..engine import default
+ else:
+ default = util.preloaded.engine_default
+
+ # dmypy / mypy seems to sporadically keep thinking this line is
+ # returning Any, which seems to be caused by the @deprecated_params
+ # decorator on the DefaultDialect constructor
+ return default.StrCompileDialect() # type: ignore
+
- def __str__(self):
+ def __str__(self) -> str:
return str(self.compile())
- def __repr__(self):
+ def __repr__(self) -> str:
return util.generic_repr(self)
-class ExternalType:
+class TypeEngineMixin:
+ """classes which subclass this can act as "mixin" classes for
+ TypeEngine."""
+
+ __slots__ = ()
+
+ if TYPE_CHECKING:
+
+ @util.memoized_property
+ def _static_cache_key(
+ self,
+ ) -> Union[CacheConst, Tuple[Any, ...]]:
+ ...
+
+ @overload
+ def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+ ...
+
+ @overload
+ def adapt(
+ self, cls: Type[TypeEngineMixin], **kw: Any
+ ) -> TypeEngine[Any]:
+ ...
+
+ def adapt(
+ self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+ ) -> TypeEngine[Any]:
+ ...
+
+ def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
+ ...
+
+
+class ExternalType(TypeEngineMixin):
"""mixin that defines attributes and behaviors specific to third-party
datatypes.
@@ -1057,13 +1210,18 @@ class ExternalType:
""" # noqa: E501
- @property
- def _static_cache_key(self):
+ @util.non_memoized_property
+ def _static_cache_key(
+ self,
+ ) -> Union[CacheConst, Tuple[Any, ...]]:
cache_ok = self.__class__.__dict__.get("cache_ok", None)
if cache_ok is None:
- subtype_idx = self.__class__.__mro__.index(ExternalType)
- subtype = self.__class__.__mro__[max(subtype_idx - 1, 0)]
+ for subtype in self.__class__.__mro__:
+ if ExternalType in subtype.__bases__:
+ break
+ else:
+ subtype = self.__class__.__mro__[1]
util.warn(
"%s %r will not produce a cache key because "
@@ -1076,12 +1234,14 @@ class ExternalType:
code="cprf",
)
elif cache_ok is True:
- return super(ExternalType, self)._static_cache_key
+ return super()._static_cache_key
return NO_CACHE
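The practical upshot for third-party types: declare cache_ok explicitly, otherwise the warning above (code "cprf") is emitted and NO_CACHE disables statement caching for expressions that use the type. A minimal sketch with a hypothetical decorator:

from sqlalchemy.types import String, TypeDecorator

class CaseInsensitiveText(TypeDecorator):
    impl = String

    # all constructor state of this type is safe to use in a cache key
    cache_ok = True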
-class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
+class UserDefinedType(
+ ExternalType, TypeEngineMixin, TypeEngine[_T], util.EnsureKWArg
+):
"""Base for user defined types.
This should be the base of new types. Note that
@@ -1148,7 +1308,9 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
ensure_kwarg = "get_col_spec"
- def coerce_compared_value(self, op, value):
+ def coerce_compared_value(
+ self, op: Optional[OperatorType], value: Any
+ ) -> TypeEngine[Any]:
"""Suggest a type for a 'coerced' Python value in an expression.
Default behavior for :class:`.UserDefinedType` is the
@@ -1162,7 +1324,7 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
return self
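For reference, the canonical shape of a UserDefinedType subclass; the MYTYPE DDL name is hypothetical.

import sqlalchemy.types as types

class MyType(types.UserDefinedType):
    cache_ok = True

    def __init__(self, precision=8):
        self.precision = precision

    def get_col_spec(self, **kw):
        return "MYTYPE(%s)" % self.precision

    def bind_processor(self, dialect):
        def process(value):
            return value  # convert Python value -> DBAPI value here
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            return value  # convert DBAPI value -> Python value here
        return process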
-class Emulated:
+class Emulated(TypeEngineMixin):
"""Mixin for base types that emulate the behavior of a DB-native type.
An :class:`.Emulated` type will use an available database type
@@ -1180,7 +1342,13 @@ class Emulated:
"""
- def adapt_to_emulated(self, impltype, **kw):
+ native: bool
+
+ def adapt_to_emulated(
+ self,
+ impltype: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+ **kw: Any,
+ ) -> TypeEngine[Any]:
"""Given an impl class, adapt this type to the impl assuming "emulated".
The impl should also be an "emulated" version of this type,
@@ -1189,27 +1357,43 @@ class Emulated:
e.g.: sqltypes.Enum adapts to the Enum class.
"""
- return super(Emulated, self).adapt(impltype, **kw)
+ return super().adapt(impltype, **kw)
- def adapt(self, impltype, **kw):
- if hasattr(impltype, "adapt_emulated_to_native"):
+ @overload
+ def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+ ...
+
+ @overload
+ def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+ ...
+
+ def adapt(
+ self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+ ) -> TypeEngine[Any]:
+ if _is_native_for_emulated(cls):
if self.native:
# native support requested, dialect gave us a native
# implementor, pass control over to it
- return impltype.adapt_emulated_to_native(self, **kw)
+ return cls.adapt_emulated_to_native(self, **kw)
else:
# non-native support, let the native implementor
# decide also, at the moment this is just to help debugging
# as only the default logic is implemented.
- return impltype.adapt_native_to_emulated(self, **kw)
+ return cls.adapt_native_to_emulated(self, **kw)
else:
- if issubclass(impltype, self.__class__):
- return self.adapt_to_emulated(impltype, **kw)
+ if issubclass(cls, self.__class__):
+ return self.adapt_to_emulated(cls, **kw)
else:
- return super(Emulated, self).adapt(impltype, **kw)
+ return super().adapt(cls, **kw)
+
+def _is_native_for_emulated(
+ typ: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+) -> TypeGuard["Type[NativeForEmulated]"]:
+ return hasattr(typ, "adapt_emulated_to_native")
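A sketch of how the dispatch above plays out for the stock Enum type, which is Emulated and whose PostgreSQL counterpart postgresql.ENUM is NativeForEmulated; printed class names are approximate.

from sqlalchemy import Enum
from sqlalchemy.dialects import postgresql, sqlite

status = Enum("new", "open", "closed", name="status")

# PostgreSQL has a native enum type, so adapt() hands off to
# ENUM.adapt_emulated_to_native()
print(type(status.dialect_impl(postgresql.dialect())))  # postgresql ENUM

# SQLite has no native enum, so the emulated Enum is kept as-is
print(type(status.dialect_impl(sqlite.dialect())))      # sqlalchemy Enum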
-class NativeForEmulated:
+
+class NativeForEmulated(TypeEngineMixin):
"""Indicates DB-native types supported by an :class:`.Emulated` type.
.. versionadded:: 1.2.0b3
@@ -1217,7 +1401,11 @@ class NativeForEmulated:
"""
@classmethod
- def adapt_native_to_emulated(cls, impl, **kw):
+ def adapt_native_to_emulated(
+ cls,
+ impl: Union[TypeEngine[Any], TypeEngineMixin],
+ **kw: Any,
+ ) -> TypeEngine[Any]:
"""Given an impl, adapt this type's class to the impl assuming
"emulated".
@@ -1227,7 +1415,12 @@ class NativeForEmulated:
return impl.adapt(impltype, **kw)
@classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
+ def adapt_emulated_to_native(
+ cls,
+ impl: Union[TypeEngine[Any], TypeEngineMixin],
+ **kw: Any,
+ ) -> TypeEngine[Any]:
+
"""Given an impl, adapt this type's class to the impl assuming "native".
The impl will be an :class:`.Emulated` class but not a
@@ -1236,10 +1429,20 @@ class NativeForEmulated:
e.g.: postgresql.ENUM produces a type given an Enum instance.
"""
- return cls(**kw)
+ # dmypy seems to crash on this
+ return cls(**kw) # type: ignore
-class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
+ # dmypy seems to crash with this, on repeated runs with changes
+ # if TYPE_CHECKING:
+ # def __init__(self, **kw: Any):
+ # ...
+
+
+SelfTypeDecorator = TypeVar("SelfTypeDecorator", bound="TypeDecorator[Any]")
+
+
+class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
"""Allows the creation of types which add additional functionality
to an existing type.
@@ -1358,9 +1561,24 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
_is_type_decorator = True
+ # this is that pattern I've used in a few places (Dialect.dbapi,
+ # Dialect.type_compiler) where the "cls.attr" is a class to make something,
+ # and "instance.attr" is an instance of that thing. It's such a nifty,
+ # great pattern, and there is zero chance Python typing tools will ever be
+ # OK with it. For TypeDecorator.impl, this is a highly public attribute so
+ # we really can't change its behavior without a major deprecation routine.
impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]]
- def __init__(self, *args, **kwargs):
+ # we are changing its behavior *slightly*, which is that we now consume
+ # the instance level version from this memoized property instead, so you
+ # can't reassign "impl" on an existing TypeDecorator that's already been
+ # used (something one shouldn't do anyway) without also updating
+ # impl_instance.
+ @util.memoized_property
+ def impl_instance(self) -> TypeEngine[Any]:
+ return self.impl # type: ignore
+
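A hypothetical decorator illustrating the class-level "impl" versus the instance consumed through impl_instance:

from sqlalchemy.types import String, TypeDecorator

class LowerCaseString(TypeDecorator):
    impl = String          # on the class: the impl *class*
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return value.lower() if value is not None else None

lcs = LowerCaseString(30)
# after __init__, "impl" on the instance is a String(30) instance,
# and impl_instance is the memoized accessor for it
print(lcs.impl_instance.length)  # 30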
+ def __init__(self, *args: Any, **kwargs: Any):
"""Construct a :class:`.TypeDecorator`.
Arguments sent here are passed to the constructor
@@ -1385,9 +1603,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"'impl' which refers to the class of "
"type being decorated"
)
+
self.impl = to_instance(self.__class__.impl, *args, **kwargs)
- coerce_to_is_types = (util.NoneType,)
+ coerce_to_is_types: Sequence[Type[Any]] = (type(None),)
"""Specify those Python types which should be coerced at the expression
level to "IS <constant>" when compared using ``==`` (and same for
``IS NOT`` in conjunction with ``!=``).
@@ -1416,33 +1635,42 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
__slots__ = ()
- def operate(self, op, *other, **kwargs):
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[_CT]:
+ if TYPE_CHECKING:
+ assert isinstance(self.expr.type, TypeDecorator)
kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).operate(
op, *other, **kwargs
)
- def reverse_operate(self, op, other, **kwargs):
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[_CT]:
+ if TYPE_CHECKING:
+ assert isinstance(self.expr.type, TypeDecorator)
kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).reverse_operate(
op, other, **kwargs
)
@property
- def comparator_factory(self) -> Callable[..., TypeEngine.Comparator[_T]]:
- if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
+ def comparator_factory( # type: ignore # mypy properties bug
+ self,
+ ) -> _ComparatorFactory[Any]:
+ if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: # type: ignore # noqa E501
return self.impl.comparator_factory
else:
+ # reconcile the Comparator class on the impl with that
+ # of TypeDecorator
return type(
"TDComparator",
- (TypeDecorator.Comparator, self.impl.comparator_factory),
+ (TypeDecorator.Comparator, self.impl.comparator_factory), # type: ignore # noqa E501
{},
)
- def _gen_dialect_impl(self, dialect):
- """
- #todo
- """
+ def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
if dialect.name in self._variant_mapping:
adapted = dialect.type_descriptor(
self._variant_mapping[dialect.name]
@@ -1463,35 +1691,34 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"implement the copy() method, it must "
"return an object of type %s" % (self, self.__class__)
)
- tt.impl = typedesc
+ tt.impl = tt.impl_instance = typedesc
return tt
- @property
- def _type_affinity(self):
- """
- #todo
- """
- return self.impl._type_affinity
+ @util.non_memoized_property
+ def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
+ return self.impl_instance._type_affinity
- def _set_parent(self, column, outer=False, **kw):
+ def _set_parent(
+ self, parent: SchemaEventTarget, outer: bool = False, **kw: Any
+ ) -> None:
"""Support SchemaEventTarget"""
- super(TypeDecorator, self)._set_parent(column)
+ super()._set_parent(parent)
- if not outer and isinstance(self.impl, SchemaEventTarget):
- self.impl._set_parent(column, outer=False, **kw)
+ if not outer and isinstance(self.impl_instance, SchemaEventTarget):
+ self.impl_instance._set_parent(parent, outer=False, **kw)
- def _set_parent_with_dispatch(self, parent):
+ def _set_parent_with_dispatch(
+ self, parent: SchemaEventTarget, **kw: Any
+ ) -> None:
"""Support SchemaEventTarget"""
- super(TypeDecorator, self)._set_parent_with_dispatch(
- parent, outer=True
- )
+ super()._set_parent_with_dispatch(parent, outer=True, **kw)
- if isinstance(self.impl, SchemaEventTarget):
- self.impl._set_parent_with_dispatch(parent)
+ if isinstance(self.impl_instance, SchemaEventTarget):
+ self.impl_instance._set_parent_with_dispatch(parent)
- def type_engine(self, dialect):
+ def type_engine(self, dialect: Dialect) -> TypeEngine[Any]:
"""Return a dialect-specific :class:`.TypeEngine` instance
for this :class:`.TypeDecorator`.
@@ -1508,7 +1735,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
else:
return self.load_dialect_impl(dialect)
- def load_dialect_impl(self, dialect):
+ def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
"""Return a :class:`.TypeEngine` object corresponding to a dialect.
This is an end-user override hook that can be used to provide
@@ -1520,9 +1747,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
By default returns ``self.impl``.
"""
- return self.impl
+ return self.impl_instance
- def _unwrapped_dialect_impl(self, dialect):
+ def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
"""Return the 'unwrapped' dialect impl for this type.
This is used by the :meth:`.DefaultDialect.set_input_sizes`
@@ -1540,12 +1767,14 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
else:
return typ
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
"""Proxy all other undefined accessors to the underlying
implementation."""
- return getattr(self.impl, key)
+ return getattr(self.impl_instance, key)
- def process_literal_param(self, value, dialect):
+ def process_literal_param(
+ self, value: Optional[_T], dialect: Dialect
+ ) -> str:
"""Receive a literal parameter value to be rendered inline within
a statement.
@@ -1568,7 +1797,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"""
raise NotImplementedError()
- def process_bind_param(self, value, dialect):
+ def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
"""Receive a bound parameter value to be converted.
Custom subclasses of :class:`_types.TypeDecorator` should override
@@ -1595,7 +1824,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
raise NotImplementedError()
- def process_result_value(self, value, dialect):
+ def process_result_value(
+ self, value: Optional[Any], dialect: Dialect
+ ) -> Optional[_T]:
"""Receive a result-row column value to be converted.
Custom subclasses of :class:`_types.TypeDecorator` should override
@@ -1624,7 +1855,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
raise NotImplementedError()
@util.memoized_property
- def _has_bind_processor(self):
+ def _has_bind_processor(self) -> bool:
"""memoized boolean, check if process_bind_param is implemented.
Allows the base process_bind_param to raise
@@ -1638,14 +1869,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
)
@util.memoized_property
- def _has_literal_processor(self):
+ def _has_literal_processor(self) -> bool:
"""memoized boolean, check if process_literal_param is implemented."""
return util.method_is_overridden(
self, TypeDecorator.process_literal_param
)
- def literal_processor(self, dialect):
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> Optional[_LiteralProcessorType[_T]]:
"""Provide a literal processing function for the given
:class:`.Dialect`.
@@ -1661,34 +1894,59 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"inner" processing provided by the implementing type is maintained.
"""
+
if self._has_literal_processor:
- process_param = self.process_literal_param
+ process_literal_param = self.process_literal_param
+ process_bind_param = None
elif self._has_bind_processor:
- # the bind processor should normally be OK
- # for TypeDecorator since it isn't doing DB-level
- # handling, the handling here won't be different for bound vs.
- # literals.
- process_param = self.process_bind_param
+ # use the bind processor if we don't have a literal processor,
+ # but we have an impl literal processor
+ process_literal_param = None
+ process_bind_param = self.process_bind_param
else:
- process_param = None
+ process_literal_param = None
+ process_bind_param = None
- if process_param:
- impl_processor = self.impl.literal_processor(dialect)
+ if process_literal_param is not None:
+ impl_processor = self.impl_instance.literal_processor(dialect)
if impl_processor:
- def process(value):
- return impl_processor(process_param(value, dialect))
+ fixed_impl_processor = impl_processor
+ fixed_process_literal_param = process_literal_param
+
+ def process(value: Any) -> str:
+ return fixed_impl_processor(
+ fixed_process_literal_param(value, dialect)
+ )
else:
+ fixed_process_literal_param = process_literal_param
- def process(value):
- return process_param(value, dialect)
+ def process(value: Any) -> str:
+ return fixed_process_literal_param(value, dialect)
return process
+
+ elif process_bind_param is not None:
+ impl_processor = self.impl_instance.literal_processor(dialect)
+ if not impl_processor:
+ return None
+ else:
+ fixed_impl_processor = impl_processor
+ fixed_process_bind_param = process_bind_param
+
+ def process(value: Any) -> str:
+ return fixed_impl_processor(
+ fixed_process_bind_param(value, dialect)
+ )
+
+ return process
else:
- return self.impl.literal_processor(dialect)
+ return self.impl_instance.literal_processor(dialect)
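A sketch of how that fallback surfaces to users: under literal_binds compilation, a TypeDecorator that implements only process_bind_param still has that conversion applied to inline-rendered literals. The Offset type is hypothetical and the rendered SQL is approximate.

from sqlalchemy import Integer, column, select
from sqlalchemy.types import TypeDecorator

class Offset(TypeDecorator):
    impl = Integer
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return value + 1 if value is not None else None

stmt = select(column("x", Integer)).where(column("x", Offset()) == 4)
print(stmt.compile(compile_kwargs={"literal_binds": True}))
# ... WHERE x = 5  (process_bind_param ran, then Integer's literal processor)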
- def bind_processor(self, dialect):
+ def bind_processor(
+ self, dialect: Dialect
+ ) -> Optional[_BindProcessorType[_T]]:
"""Provide a bound value processing function for the
given :class:`.Dialect`.
@@ -1708,23 +1966,28 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"""
if self._has_bind_processor:
process_param = self.process_bind_param
- impl_processor = self.impl.bind_processor(dialect)
+ impl_processor = self.impl_instance.bind_processor(dialect)
if impl_processor:
+ fixed_impl_processor = impl_processor
+ fixed_process_param = process_param
- def process(value):
- return impl_processor(process_param(value, dialect))
+ def process(value: Optional[_T]) -> Any:
+ return fixed_impl_processor(
+ fixed_process_param(value, dialect)
+ )
else:
+ fixed_process_param = process_param
- def process(value):
- return process_param(value, dialect)
+ def process(value: Optional[_T]) -> Any:
+ return fixed_process_param(value, dialect)
return process
else:
- return self.impl.bind_processor(dialect)
+ return self.impl_instance.bind_processor(dialect)
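Minimal illustration of the bind-side composition, using a hypothetical JSONText type: process_bind_param runs first, and the impl's own bind processor, if any, runs on its return value.

import json

from sqlalchemy import String
from sqlalchemy.dialects import sqlite
from sqlalchemy.types import TypeDecorator

class JSONText(TypeDecorator):
    impl = String
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return json.dumps(value) if value is not None else None

proc = JSONText().bind_processor(sqlite.dialect())
print(proc({"a": 1}))  # '{"a": 1}'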
@util.memoized_property
- def _has_result_processor(self):
+ def _has_result_processor(self) -> bool:
"""memoized boolean, check if process_result_value is implemented.
Allows the base process_result_value to raise
@@ -1737,7 +2000,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
self, TypeDecorator.process_result_value
)
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: Any
+ ) -> Optional[_ResultProcessorType[_T]]:
"""Provide a result value processing function for the given
:class:`.Dialect`.
@@ -1758,30 +2023,39 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"""
if self._has_result_processor:
process_value = self.process_result_value
- impl_processor = self.impl.result_processor(dialect, coltype)
+ impl_processor = self.impl_instance.result_processor(
+ dialect, coltype
+ )
if impl_processor:
+ fixed_process_value = process_value
+ fixed_impl_processor = impl_processor
- def process(value):
- return process_value(impl_processor(value), dialect)
+ def process(value: Any) -> Optional[_T]:
+ return fixed_process_value(
+ fixed_impl_processor(value), dialect
+ )
else:
+ fixed_process_value = process_value
- def process(value):
- return process_value(value, dialect)
+ def process(value: Any) -> Optional[_T]:
+ return fixed_process_value(value, dialect)
return process
else:
- return self.impl.result_processor(dialect, coltype)
+ return self.impl_instance.result_processor(dialect, coltype)
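And the mirror image on the result side, again with a hypothetical type: the impl's result processor, if any, runs on the raw DBAPI value before process_result_value.

from sqlalchemy import Integer
from sqlalchemy.dialects import sqlite
from sqlalchemy.types import TypeDecorator

class BoolAsInt(TypeDecorator):
    impl = Integer
    cache_ok = True

    def process_result_value(self, value, dialect):
        return bool(value) if value is not None else None

proc = BoolAsInt().result_processor(sqlite.dialect(), coltype=None)
print(proc(1))  # True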
@util.memoized_property
- def _has_bind_expression(self):
+ def _has_bind_expression(self) -> bool:
return (
util.method_is_overridden(self, TypeDecorator.bind_expression)
- or self.impl._has_bind_expression
+ or self.impl_instance._has_bind_expression
)
- def bind_expression(self, bindparam):
+ def bind_expression(
+ self, bindparam: BindParameter[_T]
+ ) -> Optional[ColumnElement[_T]]:
"""Given a bind value (i.e. a :class:`.BindParameter` instance),
return a SQL expression which will typically wrap the given parameter.
@@ -1800,10 +2074,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
type.
"""
- return self.impl.bind_expression(bindparam)
+ return self.impl_instance.bind_expression(bindparam)
@util.memoized_property
- def _has_column_expression(self):
+ def _has_column_expression(self) -> bool:
"""memoized boolean, check if column_expression is implemented.
Allows the method to be skipped for the vast majority of expression
@@ -1813,10 +2087,12 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
return (
util.method_is_overridden(self, TypeDecorator.column_expression)
- or self.impl._has_column_expression
+ or self.impl_instance._has_column_expression
)
- def column_expression(self, column):
+ def column_expression(
+ self, column: ColumnElement[_T]
+ ) -> Optional[ColumnElement[_T]]:
"""Given a SELECT column expression, return a wrapping SQL expression.
.. note::
@@ -1838,9 +2114,11 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"""
- return self.impl.column_expression(column)
+ return self.impl_instance.column_expression(column)
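Hedged sketch of column_expression on a TypeDecorator: wrap every SELECTed column of this type in a SQL function; the printed output is approximate.

from sqlalchemy import String, column, func, select
from sqlalchemy.types import TypeDecorator

class AlwaysLower(TypeDecorator):
    impl = String
    cache_ok = True

    def column_expression(self, colexpr):
        return func.lower(colexpr)

print(select(column("name", AlwaysLower())))
# SELECT lower(name) AS name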
- def coerce_compared_value(self, op, value):
+ def coerce_compared_value(
+ self, op: Optional[OperatorType], value: Any
+ ) -> Any:
"""Suggest a type for a 'coerced' Python value in an expression.
By default, returns self. This method is called by
@@ -1858,7 +2136,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
"""
return self
- def copy(self, **kw):
+ def copy(self: SelfTypeDecorator, **kw: Any) -> SelfTypeDecorator:
"""Produce a copy of this :class:`.TypeDecorator` instance.
This is a shallow copy and is provided to fulfill part of
@@ -1872,16 +2150,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
instance.__dict__.update(self.__dict__)
return instance
- def get_dbapi_type(self, dbapi):
+ def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
"""Return the DBAPI type object represented by this
:class:`.TypeDecorator`.
By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the
underlying "impl".
"""
- return self.impl.get_dbapi_type(dbapi)
+ return self.impl_instance.get_dbapi_type(dbapi)
- def compare_values(self, x, y):
+ def compare_values(self, x: Any, y: Any) -> bool:
"""Given two values, compare them for equality.
By default this calls upon :meth:`.TypeEngine.compare_values`
@@ -1894,46 +2172,60 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
has occurred.
"""
- return self.impl.compare_values(x, y)
+ return self.impl_instance.compare_values(x, y)
+ # mypy property bug
@property
- def sort_key_function(self):
- return self.impl.sort_key_function
+ def sort_key_function(self) -> Optional[Callable[[Any], Any]]: # type: ignore # noqa E501
+ return self.impl_instance.sort_key_function
- def __repr__(self):
- return util.generic_repr(self, to_inspect=self.impl)
+ def __repr__(self) -> str:
+ return util.generic_repr(self, to_inspect=self.impl_instance)
-class Variant(TypeDecorator):
+class Variant(TypeDecorator[_T]):
"""deprecated. symbol is present for backwards-compatibility with
workaround recipes, however this actual type should not be used.
"""
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg: Any, **kw: Any):
raise NotImplementedError(
"Variant is no longer used in SQLAlchemy; this is a "
"placeholder symbol for backwards compatibility."
)
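The supported spelling today is TypeEngine.with_variant(); a minimal example (the charset value is illustrative):

from sqlalchemy import String
from sqlalchemy.dialects import mysql

t = String(50).with_variant(mysql.VARCHAR(50, charset="utf8mb4"), "mysql")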
-def _reconstitute_comparator(expression):
+def _reconstitute_comparator(expression: Any) -> Any:
return expression.comparator
+@overload
+def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE:
+ ...
+
+
+@overload
+def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]:
+ ...
+
+
def to_instance(
- typeobj: Union[Type[TypeEngine[_T]], TypeEngine[_T], None], *arg, **kw
-) -> TypeEngine[_T]:
+ typeobj: Union[Type[_TE], _TE, None], *arg: Any, **kw: Any
+) -> Union[_TE, TypeEngine[None]]:
if typeobj is None:
return NULLTYPE
if callable(typeobj):
- return typeobj(*arg, **kw)
+ return typeobj(*arg, **kw) # type: ignore # for pyright
else:
return typeobj
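Behavior implied by the overloads above; to_instance is an internal helper of this module, shown here only for illustration.

from sqlalchemy import Integer
from sqlalchemy.sql.type_api import to_instance

assert isinstance(to_instance(Integer), Integer)  # class -> new instance
i = Integer()
assert to_instance(i) is i                        # instance passes through
print(to_instance(None))                          # the NULLTYPE singleton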
-def adapt_type(typeobj, colspecs):
+def adapt_type(
+ typeobj: TypeEngine[Any],
+ colspecs: Mapping[Type[Any], Type[TypeEngine[Any]]],
+) -> TypeEngine[Any]:
if isinstance(typeobj, type):
typeobj = typeobj()
for t in typeobj.__class__.__mro__[0:-1]: