diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-08 17:14:41 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-13 15:29:20 -0400 |
commit | 769fa67d842035dd852ab8b6a26ea3f110a51131 (patch) | |
tree | 5c121caca336071091c6f5ea4c54743c92d6458a /lib/sqlalchemy/sql/compiler.py | |
parent | 77fc8216a74e6b2d0efc6591c6c735687bd10002 (diff) | |
download | sqlalchemy-769fa67d842035dd852ab8b6a26ea3f110a51131.tar.gz |
pep-484: sqlalchemy.sql pass one
sqlalchemy.sql will require many passes to get all
modules even gradually typed. Will have to pick and
choose what modules can be strictly typed vs. which
can be gradual.
in this patch, emphasis is on visitors.py, cache_key.py,
annotations.py for strict typing, compiler.py is on gradual
typing but has much more structure, in particular where it
connects with the outside world.
The work within compiler.py also reached back out to
engine/cursor.py , default.py quite a bit.
References: #6810
Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 447 |
1 files changed, 328 insertions, 119 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 423c3d446..f28dceefc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -35,14 +35,19 @@ from time import perf_counter import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import Union from . import base @@ -54,19 +59,42 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from .base import _from_objects from .base import NO_ARG -from .base import prefix_anon_map from .elements import quoted_name from .schema import Column +from .sqltypes import TupleType from .type_api import TypeEngine +from .visitors import prefix_anon_map from .. import exc from .. import util from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .elements import BindParameter + from .elements import ColumnClause + from .elements import Label + from .functions import Function + from .selectable import Alias + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState from .selectable import CTE from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -236,7 +264,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS = { +FUNCTIONS: Dict[Type[Function], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple): name: str """column name, may be labeled""" - objects: List[Any] - """list of objects that should be able to locate this column + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column in a RowMapping. This is typically string names and aliases as well as Column objects. @@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple): """ +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: + ... + + # integer indexes into ResultColumnsEntry used by cursor.py. # some profiling showed integer access faster than named tuple RM_RENDERED_NAME: Literal[0] = 0 @@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2 RM_TYPE: Literal[3] = 3 +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select + + class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams @@ -427,21 +480,23 @@ class Compiled: defaults. """ - _cached_metadata = None + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None - schema_translate_map = None + schema_translate_map: Optional[_SchemaTranslateMapType] = None - execution_options = util.EMPTY_DICT + execution_options: _ExecuteOptions = util.EMPTY_DICT """ Execution options propagated from the statement. In some cases, sub-elements of the statement can modify these. """ - _annotations = util.EMPTY_DICT + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT - compile_state = None + compile_state: Optional[CompileState] = None """Optional :class:`.CompileState` object that maintains additional state used by the compiler. @@ -457,9 +512,21 @@ class Compiled: """ - cache_key = None + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" def __init__( self, @@ -543,7 +610,11 @@ class Compiled: return self.string or "" - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: """Return the bind params for this compiled object. :param params: a dict of string/object pairs whose values will @@ -646,6 +717,17 @@ class SQLCompiler(Compiled): isplaintext: bool = False + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -709,7 +791,7 @@ class SQLCompiler(Compiled): """ - insert_single_values_expr = None + insert_single_values_expr: Optional[str] = None """When an INSERT is compiled with a single set of parameters inside a VALUES expression, the string is assigned here, where it can be used for insert batching schemes to rewrite the VALUES expression. @@ -718,19 +800,19 @@ class SQLCompiler(Compiled): """ - literal_execute_params = frozenset() + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as literal values at statement execution time. """ - post_compile_params = frozenset() + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as bound parameter placeholders at statement execution time. """ - escaped_bind_names = util.EMPTY_DICT + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT """Late escaping of bound parameter names that has to be converted to the original name when looking in the parameter dictionary. @@ -744,14 +826,25 @@ class SQLCompiler(Compiled): """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ - _cache_key_bind_match = None + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None """a mapping that will relate the BindParameter object we compile to those that are part of the extracted collection of parameters in the cache key, if we were given a cache key. """ - positiontup: Optional[Sequence[str]] = None + positiontup: Optional[List[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -768,6 +861,19 @@ class SQLCompiler(Compiled): inline: bool = False + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + cte_positional: Dict[CTE, List[str]] + def __init__( self, dialect, @@ -804,10 +910,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key if cache_key: - self._cache_key_bind_match = ckbm = { - b.key: b for b in cache_key[1] - } - ckbm.update({b: [b] for b in cache_key[1]}) + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) # compile INSERT/UPDATE defaults/sequences to expect executemany # style execution, which may mean no pre-execute of defaults, @@ -911,14 +1016,14 @@ class SQLCompiler(Compiled): @property def prefetch(self): - return list(self.insert_prefetch + self.update_prefetch) + return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property def _global_attributes(self): return {} @util.memoized_instancemethod - def _init_cte_state(self) -> None: + def _init_cte_state(self) -> MutableMapping[CTE, str]: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -926,21 +1031,22 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes: MutableMapping[CTE, str] = util.OrderedDict() + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} + self.ctes_by_level_name = {} # To retrieve key/level in ctes_by_level_name - # Dict[cte_reference, (level, cte_name, cte_opts)] - self.level_name_by_cte: Dict[ - CTE, Tuple[int, str, selectable._CTEOpts] - ] = {} + self.level_name_by_cte = {} - self.ctes_recursive: bool = False + self.ctes_recursive = False if self.positional: - self.cte_positional: Dict[CTE, List[str]] = {} + self.cte_positional = {} + + return ctes @contextlib.contextmanager def _nested_result(self): @@ -985,7 +1091,7 @@ class SQLCompiler(Compiled): if not bindparam.type._is_tuple_type else tuple( elem_type._cached_bind_processor(self.dialect) - for elem_type in bindparam.type.types + for elem_type in cast(TupleType, bindparam.type).types ), ) for bindparam in self.bind_names @@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled): def construct_params( self, - params=None, - _group_number=None, - _check=True, - extracted_parameters=None, - ): + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + _group_number: Optional[int] = None, + _check: bool = True, + ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" has_escaped_names = bool(self.escaped_bind_names) @@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled): # way. The parameters present in self.bind_names may be clones of # these original cache key params in the case of DML but the .key # will be guaranteed to match. - try: - orig_extracted = self.cache_key[1] - except TypeError as err: + if self.cache_key is None: raise exc.CompileError( "This compiled object has no original cache key; " "can't pass extracted_parameters to construct_params" - ) from err + ) + else: + orig_extracted = self.cache_key[1] - ckbm = self._cache_key_bind_match + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple resolved_extracted = { bind: extracted for b, extracted in zip(orig_extracted, extracted_parameters) @@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled): if bindparam.type._is_tuple_type: inputsizes[bindparam] = [ - lookup_type(typ) for typ in bindparam.type.types + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types ] else: inputsizes[bindparam] = lookup_type(bindparam.type) @@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled): def _process_parameters_for_postcompile( self, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_MutableCoreSingleExecuteParams] = None, _populate_self: bool = False, ) -> ExpandedState: """handle special post compile parameters. @@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled): parameters = self.construct_params() expanded_parameters = {} + positiontup: Optional[List[str]] + if self.positional: positiontup = [] else: positiontup = None processors = self._bind_processors + single_processors = cast("Mapping[str, _ProcessorType]", processors) + tuple_processors = cast( + "Mapping[str, Sequence[_ProcessorType]]", processors + ) - new_processors = {} + new_processors: Dict[str, _ProcessorType] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled): "the 'numeric' paramstyle at this time." ) - replacement_expressions = {} - to_update_sets = {} + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: @@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled): # *escaped* parameter names in: # construct_params(), replacement_expressions - for name in ( - self.positiontup if self.positional else self.bind_names.values() - ): + if self.positional and self.positiontup is not None: + names: Iterable[str] = self.positiontup + else: + names = self.bind_names.values() + + for name in names: escaped_name = ( self.escaped_bind_names.get(name, name) if self.escaped_bind_names @@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled): if parameter in self.post_compile_params: if escaped_name in replacement_expressions: to_update = to_update_sets[escaped_name] + values = None else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled): if not parameter.literal_execute: parameters.update(to_update) if parameter.type._is_tuple_type: + assert values is not None new_processors.update( ( "%s_%s_%s" % (name, i, j), - processors[name][j - 1], + tuple_processors[name][j - 1], ) for i, tuple_element in enumerate(values, 1) - for j, value in enumerate(tuple_element, 1) - if name in processors - and processors[name][j - 1] is not None + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None ) else: new_processors.update( - (key, processors[name]) - for key, value in to_update - if name in processors + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors ) - if self.positional: - positiontup.extend(name for name, value in to_update) + if positiontup is not None: + positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ - expand_key for expand_key, value in to_update + expand_key for expand_key, _ in to_update ] - elif self.positional: + elif positiontup is not None: positiontup.append(name) def process_expanding(m): @@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled): # special use cases. self.string = expanded_state.statement self._bind_processors.update(expanded_state.processors) - self.positiontup = expanded_state.positiontup + self.positiontup = list(expanded_state.positiontup or ()) self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) @@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled): self._result_columns ) + _key_getters_for_crud_column: Tuple[ + Callable[[Union[str, Column[Any]]], str], + Callable[[Column[Any]], str], + Callable[[Column[Any]], str], + ] + @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] @@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled): @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_returning_getter(self): - result = util.preloaded.engine_result + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter table = self.statement.table - ret = {col: idx for idx, col in enumerate(self.returning)} + returning = self.returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} - getters = [ - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller("get", param_key_getter(col), None), - False, - ) - for col in table.primary_key - ] + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ], + ) row_fn = result.result_tuple([col.key for col in table.primary_key]) @@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke ( with_cols, @@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: @@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled): def visit_column( self, - column, - add_to_result_map=None, - include_table=True, - result_map_targets=(), - ambiguous_table_name_map=None, - **kwargs, - ): + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] = (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled): ) else: schema_prefix = "" - tablename = table.name + + tablename = cast("NamedFromClause", table).name if ( not effective_schema @@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - new_entry = { + new_entry: _CompilerStackEntry = { "correlate_froms": set(), "asfrom_froms": set(), "selectable": taf, @@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled): compiled_col = self.visit_column(element, **kw) return "(%s).%s" % (compiled_fn, compiled_col) - def visit_function(self, func, add_to_result_map=None, **kwargs): + def visit_function( + self, + func: Function, + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: if add_to_result_map is not None: add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + if disp: text = disp(func, **kwargs) else: @@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled): if compound_stmt._independent_ctes: self._dispatch_independent_ctes(compound_stmt, kwargs) - keyword = self.compound_keywords.get(cs.keyword) + keyword = self.compound_keywords[cs.keyword] text = (" " + keyword + " ").join( ( @@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled): # a different set of parameter values. here, we accommodate for # parameters that may have been cloned both before and after the cache # key was been generated. - ckbm = self._cache_key_bind_match - if ckbm: + ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple for bp in bindparam._cloned_set: - if bp.key in ckbm: - cb = ckbm[bp.key] + if bp.key in cksm: + cb = cksm[bp.key] ckbm[cb].append(bindparam) if bindparam.isoutparam: @@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled): if positional_names is not None: positional_names.append(name) else: - self.positiontup.append(name) + self.positiontup.append(name) # type: ignore[union-attr] elif not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled): name = new_name if escaped_from: - if not self.escaped_bind_names: - self.escaped_bind_names = {} - self.escaped_bind_names[escaped_from] = name + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) if post_compile: return "__[POSTCOMPILE_%s]" % name @@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled): cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), **kwargs: Any, ) -> Optional[str]: - self._init_cte_state() + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes kwargs["visiting_cte"] = cte @@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte] + del self_ctes[existing_cte] existing_cte_reference_cte = existing_cte._get_reference_cte() @@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled): if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte not in self.ctes: + if not cte_pre_alias_name and cte not in self_ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte] = text + self_ctes[cte] = text if asfrom: if from_linter: from_linter.froms[cte] = cte_name if not is_new_cte and embedded_in_current_named_cte: - return self.preparer.format_alias(cte, cte_name) + return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501 if cte_pre_alias_name: text = self.preparer.format_alias(cte, cte_pre_alias_name) @@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) + return None + def visit_table_valued_alias(self, element, **kw): if element._is_lateral: return self.visit_lateral(element, **kw) @@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled): self, keyname: str, name: str, - objects: List[Any], + objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: if keyname is None or keyname == "*": @@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled): def get_statement_hint_text(self, hint_texts): return " ".join(hint_texts) - _default_stack_entry = util.immutabledict( - [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] - ) + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select( self, select_stmt, asfrom, lateral=False, **kw @@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled): ) return froms - translate_select_structure = None + translate_select_structure: Any = None """if not ``None``, should be a callable which accepts ``(select_stmt, **kw)`` and returns a select object. this is used for structural changes mostly to accommodate for LIMIT/OFFSET schemes @@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled): ) self._result_columns = [ - (key, name, tuple(translate.get(o, o) for o in obj), type_) + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) for key, name, obj, type_ in self._result_columns ] @@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled): implicit_correlate_froms=asfrom_froms, ) - new_correlate_froms = set(selectable._from_objects(*froms)) + new_correlate_froms = set(_from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) - new_entry = { + new_entry: _CompilerStackEntry = { "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, @@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled): text += " \nWHERE " + t if warn_linting: + assert from_linter is not None from_linter.warn() if select._group_by_clauses: @@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled): if not self.ctes: return "" + ctes: MutableMapping[CTE, str] + if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): @@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled): ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: + assert self.positiontup is not None self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) + list( + itertools.chain.from_iterable( + self.cte_positional[cte] for cte in ctes + ) + ) + self.positiontup ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled): if is_multitable: # main table might be a JOIN - main_froms = set(selectable._from_objects(update_stmt.table)) + main_froms = set(_from_objects(update_stmt.table)) render_extra_froms = [ f for f in extra_froms if f not in main_froms ] @@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: return None def visit_ddl(self, ddl, **kwargs): @@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler): return get_col_spec(**kw) +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: + ... + + class IdentifierPreparer: """Handle quoting and case-folding of identifiers based on options.""" @@ -5209,7 +5404,13 @@ class IdentifierPreparer: illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = operator.attrgetter("schema") + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") """Return the .schema attribute for an object. For the default IdentifierPreparer, the schema for an object is always @@ -5297,7 +5498,7 @@ class IdentifierPreparer: return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) - def _escape_identifier(self, value): + def _escape_identifier(self, value: str) -> str: """Escape an identifier. Subclasses should override this to provide database-dependent @@ -5309,7 +5510,7 @@ class IdentifierPreparer: value = value.replace("%", "%%") return value - def _unescape_identifier(self, value): + def _unescape_identifier(self, value: str) -> str: """Canonicalize an escaped identifier. Subclasses should override this to provide database-dependent @@ -5336,7 +5537,7 @@ class IdentifierPreparer: ) return element - def quote_identifier(self, value): + def quote_identifier(self, value: str) -> str: """Quote an identifier. Subclasses should override this to provide database-dependent @@ -5349,7 +5550,7 @@ class IdentifierPreparer: + self.final_quote ) - def _requires_quotes(self, value): + def _requires_quotes(self, value: str) -> bool: """Return True if the given identifier requires quoting.""" lc_value = value.lower() return ( @@ -5364,7 +5565,7 @@ class IdentifierPreparer: not taking case convention into account.""" return not self.legal_characters.match(str(value)) - def quote_schema(self, schema, force=None): + def quote_schema(self, schema: str, force: Any = None) -> str: """Conditionally quote a schema name. @@ -5403,7 +5604,7 @@ class IdentifierPreparer: return self.quote(schema) - def quote(self, ident, force=None): + def quote(self, ident: str, force: Any = None) -> str: """Conditionally quote an identifier. The identifier is quoted if it is a reserved word, contains @@ -5474,11 +5675,19 @@ class IdentifierPreparer: name = self.quote_schema(effective_schema) + "." + name return name - def format_label(self, label, name=None): + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: return self.quote(name or label.name) - def format_alias(self, alias, name=None): - return self.quote(name or alias.name) + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) def format_savepoint(self, savepoint, name=None): # Running the savepoint name through quoting is unnecessary |