diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 447 |
1 files changed, 328 insertions, 119 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 423c3d446..f28dceefc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -35,14 +35,19 @@ from time import perf_counter import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import Union from . import base @@ -54,19 +59,42 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from .base import _from_objects from .base import NO_ARG -from .base import prefix_anon_map from .elements import quoted_name from .schema import Column +from .sqltypes import TupleType from .type_api import TypeEngine +from .visitors import prefix_anon_map from .. import exc from .. import util from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .elements import BindParameter + from .elements import ColumnClause + from .elements import Label + from .functions import Function + from .selectable import Alias + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState from .selectable import CTE from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -236,7 +264,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS = { +FUNCTIONS: Dict[Type[Function], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple): name: str """column name, may be labeled""" - objects: List[Any] - """list of objects that should be able to locate this column + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column in a RowMapping. This is typically string names and aliases as well as Column objects. @@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple): """ +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: + ... + + # integer indexes into ResultColumnsEntry used by cursor.py. # some profiling showed integer access faster than named tuple RM_RENDERED_NAME: Literal[0] = 0 @@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2 RM_TYPE: Literal[3] = 3 +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select + + class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams @@ -427,21 +480,23 @@ class Compiled: defaults. """ - _cached_metadata = None + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None - schema_translate_map = None + schema_translate_map: Optional[_SchemaTranslateMapType] = None - execution_options = util.EMPTY_DICT + execution_options: _ExecuteOptions = util.EMPTY_DICT """ Execution options propagated from the statement. In some cases, sub-elements of the statement can modify these. """ - _annotations = util.EMPTY_DICT + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT - compile_state = None + compile_state: Optional[CompileState] = None """Optional :class:`.CompileState` object that maintains additional state used by the compiler. @@ -457,9 +512,21 @@ class Compiled: """ - cache_key = None + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" def __init__( self, @@ -543,7 +610,11 @@ class Compiled: return self.string or "" - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: """Return the bind params for this compiled object. :param params: a dict of string/object pairs whose values will @@ -646,6 +717,17 @@ class SQLCompiler(Compiled): isplaintext: bool = False + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -709,7 +791,7 @@ class SQLCompiler(Compiled): """ - insert_single_values_expr = None + insert_single_values_expr: Optional[str] = None """When an INSERT is compiled with a single set of parameters inside a VALUES expression, the string is assigned here, where it can be used for insert batching schemes to rewrite the VALUES expression. @@ -718,19 +800,19 @@ class SQLCompiler(Compiled): """ - literal_execute_params = frozenset() + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as literal values at statement execution time. """ - post_compile_params = frozenset() + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as bound parameter placeholders at statement execution time. """ - escaped_bind_names = util.EMPTY_DICT + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT """Late escaping of bound parameter names that has to be converted to the original name when looking in the parameter dictionary. @@ -744,14 +826,25 @@ class SQLCompiler(Compiled): """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ - _cache_key_bind_match = None + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None """a mapping that will relate the BindParameter object we compile to those that are part of the extracted collection of parameters in the cache key, if we were given a cache key. """ - positiontup: Optional[Sequence[str]] = None + positiontup: Optional[List[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -768,6 +861,19 @@ class SQLCompiler(Compiled): inline: bool = False + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + cte_positional: Dict[CTE, List[str]] + def __init__( self, dialect, @@ -804,10 +910,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key if cache_key: - self._cache_key_bind_match = ckbm = { - b.key: b for b in cache_key[1] - } - ckbm.update({b: [b] for b in cache_key[1]}) + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) # compile INSERT/UPDATE defaults/sequences to expect executemany # style execution, which may mean no pre-execute of defaults, @@ -911,14 +1016,14 @@ class SQLCompiler(Compiled): @property def prefetch(self): - return list(self.insert_prefetch + self.update_prefetch) + return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property def _global_attributes(self): return {} @util.memoized_instancemethod - def _init_cte_state(self) -> None: + def _init_cte_state(self) -> MutableMapping[CTE, str]: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -926,21 +1031,22 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes: MutableMapping[CTE, str] = util.OrderedDict() + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} + self.ctes_by_level_name = {} # To retrieve key/level in ctes_by_level_name - # Dict[cte_reference, (level, cte_name, cte_opts)] - self.level_name_by_cte: Dict[ - CTE, Tuple[int, str, selectable._CTEOpts] - ] = {} + self.level_name_by_cte = {} - self.ctes_recursive: bool = False + self.ctes_recursive = False if self.positional: - self.cte_positional: Dict[CTE, List[str]] = {} + self.cte_positional = {} + + return ctes @contextlib.contextmanager def _nested_result(self): @@ -985,7 +1091,7 @@ class SQLCompiler(Compiled): if not bindparam.type._is_tuple_type else tuple( elem_type._cached_bind_processor(self.dialect) - for elem_type in bindparam.type.types + for elem_type in cast(TupleType, bindparam.type).types ), ) for bindparam in self.bind_names @@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled): def construct_params( self, - params=None, - _group_number=None, - _check=True, - extracted_parameters=None, - ): + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + _group_number: Optional[int] = None, + _check: bool = True, + ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" has_escaped_names = bool(self.escaped_bind_names) @@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled): # way. The parameters present in self.bind_names may be clones of # these original cache key params in the case of DML but the .key # will be guaranteed to match. - try: - orig_extracted = self.cache_key[1] - except TypeError as err: + if self.cache_key is None: raise exc.CompileError( "This compiled object has no original cache key; " "can't pass extracted_parameters to construct_params" - ) from err + ) + else: + orig_extracted = self.cache_key[1] - ckbm = self._cache_key_bind_match + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple resolved_extracted = { bind: extracted for b, extracted in zip(orig_extracted, extracted_parameters) @@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled): if bindparam.type._is_tuple_type: inputsizes[bindparam] = [ - lookup_type(typ) for typ in bindparam.type.types + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types ] else: inputsizes[bindparam] = lookup_type(bindparam.type) @@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled): def _process_parameters_for_postcompile( self, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_MutableCoreSingleExecuteParams] = None, _populate_self: bool = False, ) -> ExpandedState: """handle special post compile parameters. @@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled): parameters = self.construct_params() expanded_parameters = {} + positiontup: Optional[List[str]] + if self.positional: positiontup = [] else: positiontup = None processors = self._bind_processors + single_processors = cast("Mapping[str, _ProcessorType]", processors) + tuple_processors = cast( + "Mapping[str, Sequence[_ProcessorType]]", processors + ) - new_processors = {} + new_processors: Dict[str, _ProcessorType] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled): "the 'numeric' paramstyle at this time." ) - replacement_expressions = {} - to_update_sets = {} + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: @@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled): # *escaped* parameter names in: # construct_params(), replacement_expressions - for name in ( - self.positiontup if self.positional else self.bind_names.values() - ): + if self.positional and self.positiontup is not None: + names: Iterable[str] = self.positiontup + else: + names = self.bind_names.values() + + for name in names: escaped_name = ( self.escaped_bind_names.get(name, name) if self.escaped_bind_names @@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled): if parameter in self.post_compile_params: if escaped_name in replacement_expressions: to_update = to_update_sets[escaped_name] + values = None else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled): if not parameter.literal_execute: parameters.update(to_update) if parameter.type._is_tuple_type: + assert values is not None new_processors.update( ( "%s_%s_%s" % (name, i, j), - processors[name][j - 1], + tuple_processors[name][j - 1], ) for i, tuple_element in enumerate(values, 1) - for j, value in enumerate(tuple_element, 1) - if name in processors - and processors[name][j - 1] is not None + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None ) else: new_processors.update( - (key, processors[name]) - for key, value in to_update - if name in processors + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors ) - if self.positional: - positiontup.extend(name for name, value in to_update) + if positiontup is not None: + positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ - expand_key for expand_key, value in to_update + expand_key for expand_key, _ in to_update ] - elif self.positional: + elif positiontup is not None: positiontup.append(name) def process_expanding(m): @@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled): # special use cases. self.string = expanded_state.statement self._bind_processors.update(expanded_state.processors) - self.positiontup = expanded_state.positiontup + self.positiontup = list(expanded_state.positiontup or ()) self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) @@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled): self._result_columns ) + _key_getters_for_crud_column: Tuple[ + Callable[[Union[str, Column[Any]]], str], + Callable[[Column[Any]], str], + Callable[[Column[Any]], str], + ] + @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] @@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled): @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_returning_getter(self): - result = util.preloaded.engine_result + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter table = self.statement.table - ret = {col: idx for idx, col in enumerate(self.returning)} + returning = self.returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} - getters = [ - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller("get", param_key_getter(col), None), - False, - ) - for col in table.primary_key - ] + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ], + ) row_fn = result.result_tuple([col.key for col in table.primary_key]) @@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke ( with_cols, @@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: @@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled): def visit_column( self, - column, - add_to_result_map=None, - include_table=True, - result_map_targets=(), - ambiguous_table_name_map=None, - **kwargs, - ): + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] = (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled): ) else: schema_prefix = "" - tablename = table.name + + tablename = cast("NamedFromClause", table).name if ( not effective_schema @@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - new_entry = { + new_entry: _CompilerStackEntry = { "correlate_froms": set(), "asfrom_froms": set(), "selectable": taf, @@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled): compiled_col = self.visit_column(element, **kw) return "(%s).%s" % (compiled_fn, compiled_col) - def visit_function(self, func, add_to_result_map=None, **kwargs): + def visit_function( + self, + func: Function, + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: if add_to_result_map is not None: add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + if disp: text = disp(func, **kwargs) else: @@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled): if compound_stmt._independent_ctes: self._dispatch_independent_ctes(compound_stmt, kwargs) - keyword = self.compound_keywords.get(cs.keyword) + keyword = self.compound_keywords[cs.keyword] text = (" " + keyword + " ").join( ( @@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled): # a different set of parameter values. here, we accommodate for # parameters that may have been cloned both before and after the cache # key was been generated. - ckbm = self._cache_key_bind_match - if ckbm: + ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple for bp in bindparam._cloned_set: - if bp.key in ckbm: - cb = ckbm[bp.key] + if bp.key in cksm: + cb = cksm[bp.key] ckbm[cb].append(bindparam) if bindparam.isoutparam: @@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled): if positional_names is not None: positional_names.append(name) else: - self.positiontup.append(name) + self.positiontup.append(name) # type: ignore[union-attr] elif not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled): name = new_name if escaped_from: - if not self.escaped_bind_names: - self.escaped_bind_names = {} - self.escaped_bind_names[escaped_from] = name + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) if post_compile: return "__[POSTCOMPILE_%s]" % name @@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled): cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), **kwargs: Any, ) -> Optional[str]: - self._init_cte_state() + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes kwargs["visiting_cte"] = cte @@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte] + del self_ctes[existing_cte] existing_cte_reference_cte = existing_cte._get_reference_cte() @@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled): if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte not in self.ctes: + if not cte_pre_alias_name and cte not in self_ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte] = text + self_ctes[cte] = text if asfrom: if from_linter: from_linter.froms[cte] = cte_name if not is_new_cte and embedded_in_current_named_cte: - return self.preparer.format_alias(cte, cte_name) + return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501 if cte_pre_alias_name: text = self.preparer.format_alias(cte, cte_pre_alias_name) @@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) + return None + def visit_table_valued_alias(self, element, **kw): if element._is_lateral: return self.visit_lateral(element, **kw) @@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled): self, keyname: str, name: str, - objects: List[Any], + objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: if keyname is None or keyname == "*": @@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled): def get_statement_hint_text(self, hint_texts): return " ".join(hint_texts) - _default_stack_entry = util.immutabledict( - [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] - ) + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select( self, select_stmt, asfrom, lateral=False, **kw @@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled): ) return froms - translate_select_structure = None + translate_select_structure: Any = None """if not ``None``, should be a callable which accepts ``(select_stmt, **kw)`` and returns a select object. this is used for structural changes mostly to accommodate for LIMIT/OFFSET schemes @@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled): ) self._result_columns = [ - (key, name, tuple(translate.get(o, o) for o in obj), type_) + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) for key, name, obj, type_ in self._result_columns ] @@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled): implicit_correlate_froms=asfrom_froms, ) - new_correlate_froms = set(selectable._from_objects(*froms)) + new_correlate_froms = set(_from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) - new_entry = { + new_entry: _CompilerStackEntry = { "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, @@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled): text += " \nWHERE " + t if warn_linting: + assert from_linter is not None from_linter.warn() if select._group_by_clauses: @@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled): if not self.ctes: return "" + ctes: MutableMapping[CTE, str] + if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): @@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled): ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: + assert self.positiontup is not None self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) + list( + itertools.chain.from_iterable( + self.cte_positional[cte] for cte in ctes + ) + ) + self.positiontup ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled): if is_multitable: # main table might be a JOIN - main_froms = set(selectable._from_objects(update_stmt.table)) + main_froms = set(_from_objects(update_stmt.table)) render_extra_froms = [ f for f in extra_froms if f not in main_froms ] @@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: return None def visit_ddl(self, ddl, **kwargs): @@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler): return get_col_spec(**kw) +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: + ... + + class IdentifierPreparer: """Handle quoting and case-folding of identifiers based on options.""" @@ -5209,7 +5404,13 @@ class IdentifierPreparer: illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = operator.attrgetter("schema") + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") """Return the .schema attribute for an object. For the default IdentifierPreparer, the schema for an object is always @@ -5297,7 +5498,7 @@ class IdentifierPreparer: return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) - def _escape_identifier(self, value): + def _escape_identifier(self, value: str) -> str: """Escape an identifier. Subclasses should override this to provide database-dependent @@ -5309,7 +5510,7 @@ class IdentifierPreparer: value = value.replace("%", "%%") return value - def _unescape_identifier(self, value): + def _unescape_identifier(self, value: str) -> str: """Canonicalize an escaped identifier. Subclasses should override this to provide database-dependent @@ -5336,7 +5537,7 @@ class IdentifierPreparer: ) return element - def quote_identifier(self, value): + def quote_identifier(self, value: str) -> str: """Quote an identifier. Subclasses should override this to provide database-dependent @@ -5349,7 +5550,7 @@ class IdentifierPreparer: + self.final_quote ) - def _requires_quotes(self, value): + def _requires_quotes(self, value: str) -> bool: """Return True if the given identifier requires quoting.""" lc_value = value.lower() return ( @@ -5364,7 +5565,7 @@ class IdentifierPreparer: not taking case convention into account.""" return not self.legal_characters.match(str(value)) - def quote_schema(self, schema, force=None): + def quote_schema(self, schema: str, force: Any = None) -> str: """Conditionally quote a schema name. @@ -5403,7 +5604,7 @@ class IdentifierPreparer: return self.quote(schema) - def quote(self, ident, force=None): + def quote(self, ident: str, force: Any = None) -> str: """Conditionally quote an identifier. The identifier is quoted if it is a reserved word, contains @@ -5474,11 +5675,19 @@ class IdentifierPreparer: name = self.quote_schema(effective_schema) + "." + name return name - def format_label(self, label, name=None): + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: return self.quote(name or label.name) - def format_alias(self, alias, name=None): - return self.quote(name or alias.name) + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) def format_savepoint(self, savepoint, name=None): # Running the savepoint name through quoting is unnecessary |