summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-17 16:18:55 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-19 23:15:15 -0400
commit6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f (patch)
treeae142d45de71d1ebd43df1a38e54e1d3cf1063ec /lib/sqlalchemy/sql/compiler.py
parentc2fe4a264003933ff895c51f5d07a8456ac86382 (diff)
downloadsqlalchemy-6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f.tar.gz
pep 484 for types
strict types type_api.py, including TypeDecorator, NativeForEmulated, etc. Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py111
1 files changed, 80 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8f878b66c..f8019b9c6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -43,6 +43,7 @@ from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
@@ -51,6 +52,7 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
+from sqlalchemy.sql.ddl import DDLElement
from . import base
from . import coercions
from . import crud
@@ -61,7 +63,9 @@ from . import schema
from . import selectable
from . import sqltypes
from .base import _from_objects
+from .base import Executable
from .base import NO_ARG
+from .elements import ClauseElement
from .elements import quoted_name
from .schema import Column
from .sqltypes import TupleType
@@ -78,6 +82,10 @@ if typing.TYPE_CHECKING:
from .base import _AmbiguousTableNameMap
from .base import CompileState
from .cache_key import CacheKey
+ from .dml import Insert
+ from .dml import UpdateBase
+ from .dml import ValuesBase
+ from .elements import _truncated_label
from .elements import BindParameter
from .elements import ColumnClause
from .elements import Label
@@ -91,12 +99,13 @@ if typing.TYPE_CHECKING:
from .selectable import ReturnsRows
from .selectable import Select
from .selectable import SelectState
+ from .type_api import _BindProcessorType
from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _MutableCoreSingleExecuteParams
from ..engine.interfaces import _SchemaTranslateMapType
- from ..engine.result import _ProcessorType
+ from ..engine.interfaces import Dialect
_FromHintsType = Dict["FromClause", str]
@@ -378,7 +387,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
- processors: Mapping[str, _ProcessorType]
+ processors: Mapping[str, _BindProcessorType[Any]]
positiontup: Optional[Sequence[str]]
parameter_expansion: Mapping[str, List[str]]
@@ -531,11 +540,11 @@ class Compiled:
def __init__(
self,
- dialect,
- statement,
- schema_translate_map=None,
- render_schema_translate=False,
- compile_kwargs=util.immutabledict(),
+ dialect: Dialect,
+ statement: Optional[ClauseElement],
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ render_schema_translate: bool = False,
+ compile_kwargs: Mapping[str, Any] = util.immutabledict(),
):
"""Construct a new :class:`.Compiled` object.
@@ -571,6 +580,8 @@ class Compiled:
self.can_execute = statement.supports_execution
self._annotations = statement._annotations
if self.can_execute:
+ if TYPE_CHECKING:
+ assert isinstance(statement, Executable)
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
@@ -636,10 +647,10 @@ class TypeCompiler(util.EnsureKWArg):
ensure_kwarg = r"visit_\w+"
- def __init__(self, dialect):
+ def __init__(self, dialect: Dialect):
self.dialect = dialect
- def process(self, type_, **kw):
+ def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
if (
type_._variant_mapping
and self.dialect.name in type_._variant_mapping
@@ -647,7 +658,9 @@ class TypeCompiler(util.EnsureKWArg):
type_ = type_._variant_mapping[self.dialect.name]
return type_._compiler_dispatch(self, **kw)
- def visit_unsupported_compilation(self, element, err, **kw):
+ def visit_unsupported_compilation(
+ self, element: Any, err: Exception, **kw: Any
+ ) -> NoReturn:
raise exc.UnsupportedCompilationError(self, element) from err
@@ -877,13 +890,13 @@ class SQLCompiler(Compiled):
def __init__(
self,
- dialect,
- statement,
- cache_key=None,
- column_keys=None,
- for_executemany=False,
- linting=NO_LINTING,
- **kwargs,
+ dialect: Dialect,
+ statement: Optional[ClauseElement],
+ cache_key: Optional[CacheKey] = None,
+ column_keys: Optional[Sequence[str]] = None,
+ for_executemany: bool = False,
+ linting: Linting = NO_LINTING,
+ **kwargs: Any,
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -954,15 +967,21 @@ class SQLCompiler(Compiled):
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
- self.truncated_names = {}
+ self.truncated_names: Dict[Tuple[str, str], str] = {}
+ self._truncated_counters: Dict[str, int] = {}
Compiled.__init__(self, dialect, statement, **kwargs)
if self.isinsert or self.isupdate or self.isdelete:
+ if TYPE_CHECKING:
+ assert isinstance(statement, UpdateBase)
+
if statement._returning:
self.returning = statement._returning
if self.isinsert or self.isupdate:
+ if TYPE_CHECKING:
+ assert isinstance(statement, ValuesBase)
if statement._inline:
self.inline = True
elif self.for_executemany and (
@@ -1082,9 +1101,14 @@ class SQLCompiler(Compiled):
@util.memoized_property
def _bind_processors(
self,
- ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]:
+ ) -> MutableMapping[
+ str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
+ ]:
+
+ # mypy is not able to see the two value types as the above Union,
+ # it just sees "object". don't know how to resolve
return dict(
- (key, value)
+ (key, value) # type: ignore
for key, value in (
(
self.bind_names[bindparam],
@@ -1301,12 +1325,14 @@ class SQLCompiler(Compiled):
positiontup = None
processors = self._bind_processors
- single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ single_processors = cast(
+ "Mapping[str, _BindProcessorType[Any]]", processors
+ )
tuple_processors = cast(
- "Mapping[str, Sequence[_ProcessorType]]", processors
+ "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
)
- new_processors: Dict[str, _ProcessorType] = {}
+ new_processors: Dict[str, _BindProcessorType[Any]] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1484,6 +1510,10 @@ class SQLCompiler(Compiled):
result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
+
+ if TYPE_CHECKING:
+ assert isinstance(self.statement, Insert)
+
table = self.statement.table
getters = [
@@ -1530,6 +1560,9 @@ class SQLCompiler(Compiled):
else:
result = util.preloaded.engine_result
+ if TYPE_CHECKING:
+ assert isinstance(self.statement, Insert)
+
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
@@ -1796,7 +1829,9 @@ class SQLCompiler(Compiled):
def visit_typeclause(self, typeclause, **kw):
kw["type_expression"] = typeclause
kw["identifier_preparer"] = self.preparer
- return self.dialect.type_compiler.process(typeclause.type, **kw)
+ return self.dialect.type_compiler_instance.process(
+ typeclause.type, **kw
+ )
def post_process_text(self, text):
if self.preparer._double_percents:
@@ -2855,26 +2890,28 @@ class SQLCompiler(Compiled):
return bind_name
- def _truncated_identifier(self, ident_class, name):
+ def _truncated_identifier(
+ self, ident_class: str, name: _truncated_label
+ ) -> str:
if (ident_class, name) in self.truncated_names:
return self.truncated_names[(ident_class, name)]
anonname = name.apply_map(self.anon_map)
if len(anonname) > self.label_length - 6:
- counter = self.truncated_names.get(ident_class, 1)
+ counter = self._truncated_counters.get(ident_class, 1)
truncname = (
anonname[0 : max(self.label_length - 6, 0)]
+ "_"
+ hex(counter)[2:]
)
- self.truncated_names[ident_class] = counter + 1
+ self._truncated_counters[ident_class] = counter + 1
else:
truncname = anonname
self.truncated_names[(ident_class, name)] = truncname
return truncname
- def _anonymize(self, name):
+ def _anonymize(self, name: str) -> str:
return name % self.anon_map
def bindparam_string(
@@ -3221,7 +3258,7 @@ class SQLCompiler(Compiled):
% (
self.preparer.quote(col.name),
" %s"
- % self.dialect.type_compiler.process(
+ % self.dialect.type_compiler_instance.process(
col.type, **kwargs
)
if alias._render_derived_w_types
@@ -4685,6 +4722,18 @@ class StrSQLCompiler(SQLCompiler):
class DDLCompiler(Compiled):
+ if TYPE_CHECKING:
+
+ def __init__(
+ self,
+ dialect: Dialect,
+ statement: DDLElement,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = ...,
+ render_schema_translate: bool = ...,
+ compile_kwargs: Mapping[str, Any] = ...,
+ ):
+ ...
+
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(
@@ -4693,7 +4742,7 @@ class DDLCompiler(Compiled):
@util.memoized_property
def type_compiler(self):
- return self.dialect.type_compiler
+ return self.dialect.type_compiler_instance
def construct_params(
self,
@@ -5010,7 +5059,7 @@ class DDLCompiler(Compiled):
colspec = (
self.preparer.format_column(column)
+ " "
- + self.dialect.type_compiler.process(
+ + self.dialect.type_compiler_instance.process(
column.type, type_expression=column
)
)