diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-17 16:18:55 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-19 23:15:15 -0400 |
commit | 6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f (patch) | |
tree | ae142d45de71d1ebd43df1a38e54e1d3cf1063ec /lib/sqlalchemy/sql/compiler.py | |
parent | c2fe4a264003933ff895c51f5d07a8456ac86382 (diff) | |
download | sqlalchemy-6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f.tar.gz |
pep 484 for types
strict types type_api.py, including TypeDecorator,
NativeForEmulated, etc.
Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 111 |
1 files changed, 80 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8f878b66c..f8019b9c6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -43,6 +43,7 @@ from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Set @@ -51,6 +52,7 @@ from typing import Type from typing import TYPE_CHECKING from typing import Union +from sqlalchemy.sql.ddl import DDLElement from . import base from . import coercions from . import crud @@ -61,7 +63,9 @@ from . import schema from . import selectable from . import sqltypes from .base import _from_objects +from .base import Executable from .base import NO_ARG +from .elements import ClauseElement from .elements import quoted_name from .schema import Column from .sqltypes import TupleType @@ -78,6 +82,10 @@ if typing.TYPE_CHECKING: from .base import _AmbiguousTableNameMap from .base import CompileState from .cache_key import CacheKey + from .dml import Insert + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import _truncated_label from .elements import BindParameter from .elements import ColumnClause from .elements import Label @@ -91,12 +99,13 @@ if typing.TYPE_CHECKING: from .selectable import ReturnsRows from .selectable import Select from .selectable import SelectState + from .type_api import _BindProcessorType from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _MutableCoreSingleExecuteParams from ..engine.interfaces import _SchemaTranslateMapType - from ..engine.result import _ProcessorType + from ..engine.interfaces import Dialect _FromHintsType = Dict["FromClause", str] @@ -378,7 +387,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams - processors: Mapping[str, _ProcessorType] + processors: Mapping[str, _BindProcessorType[Any]] positiontup: Optional[Sequence[str]] parameter_expansion: Mapping[str, List[str]] @@ -531,11 +540,11 @@ class Compiled: def __init__( self, - dialect, - statement, - schema_translate_map=None, - render_schema_translate=False, - compile_kwargs=util.immutabledict(), + dialect: Dialect, + statement: Optional[ClauseElement], + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + render_schema_translate: bool = False, + compile_kwargs: Mapping[str, Any] = util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -571,6 +580,8 @@ class Compiled: self.can_execute = statement.supports_execution self._annotations = statement._annotations if self.can_execute: + if TYPE_CHECKING: + assert isinstance(statement, Executable) self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) @@ -636,10 +647,10 @@ class TypeCompiler(util.EnsureKWArg): ensure_kwarg = r"visit_\w+" - def __init__(self, dialect): + def __init__(self, dialect: Dialect): self.dialect = dialect - def process(self, type_, **kw): + def process(self, type_: TypeEngine[Any], **kw: Any) -> str: if ( type_._variant_mapping and self.dialect.name in type_._variant_mapping @@ -647,7 +658,9 @@ class TypeCompiler(util.EnsureKWArg): type_ = type_._variant_mapping[self.dialect.name] return type_._compiler_dispatch(self, **kw) - def visit_unsupported_compilation(self, element, err, **kw): + def visit_unsupported_compilation( + self, element: Any, err: Exception, **kw: Any + ) -> NoReturn: raise exc.UnsupportedCompilationError(self, element) from err @@ -877,13 +890,13 @@ class SQLCompiler(Compiled): def __init__( self, - dialect, - statement, - cache_key=None, - column_keys=None, - for_executemany=False, - linting=NO_LINTING, - **kwargs, + dialect: Dialect, + statement: Optional[ClauseElement], + cache_key: Optional[CacheKey] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + linting: Linting = NO_LINTING, + **kwargs: Any, ): """Construct a new :class:`.SQLCompiler` object. @@ -954,15 +967,21 @@ class SQLCompiler(Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length - self.truncated_names = {} + self.truncated_names: Dict[Tuple[str, str], str] = {} + self._truncated_counters: Dict[str, int] = {} Compiled.__init__(self, dialect, statement, **kwargs) if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(statement, UpdateBase) + if statement._returning: self.returning = statement._returning if self.isinsert or self.isupdate: + if TYPE_CHECKING: + assert isinstance(statement, ValuesBase) if statement._inline: self.inline = True elif self.for_executemany and ( @@ -1082,9 +1101,14 @@ class SQLCompiler(Compiled): @util.memoized_property def _bind_processors( self, - ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]: + ) -> MutableMapping[ + str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] + ]: + + # mypy is not able to see the two value types as the above Union, + # it just sees "object". don't know how to resolve return dict( - (key, value) + (key, value) # type: ignore for key, value in ( ( self.bind_names[bindparam], @@ -1301,12 +1325,14 @@ class SQLCompiler(Compiled): positiontup = None processors = self._bind_processors - single_processors = cast("Mapping[str, _ProcessorType]", processors) + single_processors = cast( + "Mapping[str, _BindProcessorType[Any]]", processors + ) tuple_processors = cast( - "Mapping[str, Sequence[_ProcessorType]]", processors + "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors ) - new_processors: Dict[str, _ProcessorType] = {} + new_processors: Dict[str, _BindProcessorType[Any]] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1484,6 +1510,10 @@ class SQLCompiler(Compiled): result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter + + if TYPE_CHECKING: + assert isinstance(self.statement, Insert) + table = self.statement.table getters = [ @@ -1530,6 +1560,9 @@ class SQLCompiler(Compiled): else: result = util.preloaded.engine_result + if TYPE_CHECKING: + assert isinstance(self.statement, Insert) + param_key_getter = self._within_exec_param_key_getter table = self.statement.table @@ -1796,7 +1829,9 @@ class SQLCompiler(Compiled): def visit_typeclause(self, typeclause, **kw): kw["type_expression"] = typeclause kw["identifier_preparer"] = self.preparer - return self.dialect.type_compiler.process(typeclause.type, **kw) + return self.dialect.type_compiler_instance.process( + typeclause.type, **kw + ) def post_process_text(self, text): if self.preparer._double_percents: @@ -2855,26 +2890,28 @@ class SQLCompiler(Compiled): return bind_name - def _truncated_identifier(self, ident_class, name): + def _truncated_identifier( + self, ident_class: str, name: _truncated_label + ) -> str: if (ident_class, name) in self.truncated_names: return self.truncated_names[(ident_class, name)] anonname = name.apply_map(self.anon_map) if len(anonname) > self.label_length - 6: - counter = self.truncated_names.get(ident_class, 1) + counter = self._truncated_counters.get(ident_class, 1) truncname = ( anonname[0 : max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] ) - self.truncated_names[ident_class] = counter + 1 + self._truncated_counters[ident_class] = counter + 1 else: truncname = anonname self.truncated_names[(ident_class, name)] = truncname return truncname - def _anonymize(self, name): + def _anonymize(self, name: str) -> str: return name % self.anon_map def bindparam_string( @@ -3221,7 +3258,7 @@ class SQLCompiler(Compiled): % ( self.preparer.quote(col.name), " %s" - % self.dialect.type_compiler.process( + % self.dialect.type_compiler_instance.process( col.type, **kwargs ) if alias._render_derived_w_types @@ -4685,6 +4722,18 @@ class StrSQLCompiler(SQLCompiler): class DDLCompiler(Compiled): + if TYPE_CHECKING: + + def __init__( + self, + dialect: Dialect, + statement: DDLElement, + schema_translate_map: Optional[_SchemaTranslateMapType] = ..., + render_schema_translate: bool = ..., + compile_kwargs: Mapping[str, Any] = ..., + ): + ... + @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler( @@ -4693,7 +4742,7 @@ class DDLCompiler(Compiled): @util.memoized_property def type_compiler(self): - return self.dialect.type_compiler + return self.dialect.type_compiler_instance def construct_params( self, @@ -5010,7 +5059,7 @@ class DDLCompiler(Compiled): colspec = ( self.preparer.format_column(column) + " " - + self.dialect.type_compiler.process( + + self.dialect.type_compiler_instance.process( column.type, type_expression=column ) ) |