path: root/lib/sqlalchemy/sql/compiler.py
author     Mike Bayer <mike_mp@zzzcomputing.com>  2022-03-08 17:14:41 -0500
committer  Mike Bayer <mike_mp@zzzcomputing.com>  2022-03-13 15:29:20 -0400
commit     769fa67d842035dd852ab8b6a26ea3f110a51131 (patch)
tree       5c121caca336071091c6f5ea4c54743c92d6458a /lib/sqlalchemy/sql/compiler.py
parent     77fc8216a74e6b2d0efc6591c6c735687bd10002 (diff)
download   sqlalchemy-769fa67d842035dd852ab8b6a26ea3f110a51131.tar.gz
pep-484: sqlalchemy.sql pass one
sqlalchemy.sql will require many passes to get all modules even gradually typed. We will have to pick and choose which modules can be strictly typed vs. which can remain gradual. In this patch, the emphasis is on visitors.py, cache_key.py, and annotations.py for strict typing; compiler.py is on gradual typing but has much more structure, in particular where it connects with the outside world. The work within compiler.py also reached back out to engine/cursor.py and default.py quite a bit. References: #6810 Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd
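The two typing devices this pass leans on most in compiler.py are typing.Protocol (to describe callback signatures such as the add_to_result_map appender) and TypedDict with total=False (to type the heterogeneous entries pushed onto the compiler stack). The following is a minimal, self-contained sketch of both patterns; it mirrors the shapes added in this patch (_ResultMapAppender, _CompilerStackEntry) but is a simplified illustration, not SQLAlchemy code.

    # Illustrative sketch only: mirrors the _ResultMapAppender / _CompilerStackEntry
    # patterns added in this patch, but simplified and standalone (not SQLAlchemy code).
    from typing import Any, List, Optional, Protocol, Sequence, Set, TypedDict


    class ResultMapAppender(Protocol):
        # Structural ("duck") type for a callback: any callable with this
        # signature satisfies it, no inheritance or registration required.
        def __call__(
            self, keyname: str, name: str, objects: Sequence[Any], type_: Any
        ) -> None:
            ...


    class StackEntry(TypedDict, total=False):
        # total=False lets each stack entry carry only the keys a given
        # statement needs, while the value type of each key stays checked.
        correlate_froms: Set[Any]
        asfrom_froms: Set[Any]
        need_result_map_for_nested: bool


    def visit_something(add_to_result_map: Optional[ResultMapAppender] = None) -> str:
        # A plain function or lambda is accepted wherever the Protocol is expected.
        if add_to_result_map is not None:
            add_to_result_map("colkey", "colname", (), None)
        return "..."


    stack: List[StackEntry] = [{"correlate_froms": set(), "asfrom_froms": set()}]
    print(visit_something(lambda keyname, name, objects, type_: None))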
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--  lib/sqlalchemy/sql/compiler.py | 447
1 file changed, 328 insertions, 119 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 423c3d446..f28dceefc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -35,14 +35,19 @@ from time import perf_counter
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import Union
from . import base
@@ -54,19 +59,42 @@ from . import operators
from . import schema
from . import selectable
from . import sqltypes
+from .base import _from_objects
from .base import NO_ARG
-from .base import prefix_anon_map
from .elements import quoted_name
from .schema import Column
+from .sqltypes import TupleType
from .type_api import TypeEngine
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .base import _AmbiguousTableNameMap
+ from .base import CompileState
+ from .cache_key import CacheKey
+ from .elements import BindParameter
+ from .elements import ColumnClause
+ from .elements import Label
+ from .functions import Function
+ from .selectable import Alias
+ from .selectable import AliasedReturnsRows
+ from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectState
+ from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _MutableCoreSingleExecuteParams
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -236,7 +264,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple):
name: str
"""column name, may be labeled"""
- objects: List[Any]
- """list of objects that should be able to locate this column
+ objects: Tuple[Any, ...]
+ """sequence of objects that should be able to locate this column
in a RowMapping. This is typically string names and aliases
as well as Column objects.
@@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple):
"""
+class _ResultMapAppender(Protocol):
+ def __call__(
+ self,
+ keyname: str,
+ name: str,
+ objects: Sequence[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
+ ...
+
+
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0
@@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
+class _BaseCompilerStackEntry(TypedDict):
+ asfrom_froms: Set[FromClause]
+ correlate_froms: Set[FromClause]
+ selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+ compile_state: CompileState
+ need_result_map_for_nested: bool
+ need_result_map_for_compound: bool
+ select_0: ReturnsRows
+ insert_from_select: Select
+
+
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
@@ -427,21 +480,23 @@ class Compiled:
defaults.
"""
- _cached_metadata = None
+ _cached_metadata: Optional[CursorResultMetaData] = None
_result_columns: Optional[List[ResultColumnsEntry]] = None
- schema_translate_map = None
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None
- execution_options = util.EMPTY_DICT
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
- _annotations = util.EMPTY_DICT
+ preparer: IdentifierPreparer
+
+ _annotations: _AnnotationDict = util.EMPTY_DICT
- compile_state = None
+ compile_state: Optional[CompileState] = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
@@ -457,9 +512,21 @@ class Compiled:
"""
- cache_key = None
+ cache_key: Optional[CacheKey] = None
+ """The :class:`.CacheKey` that was generated ahead of creating this
+ :class:`.Compiled` object.
+
+ This is used for routines that need access to the original
+ :class:`.CacheKey` instance generated when the :class:`.Compiled`
+ instance was first cached, typically in order to reconcile
+ the original list of :class:`.BindParameter` objects with a
+ per-statement list that's generated on each call.
+
+ """
_gen_time: float
+ """Generation time of this :class:`.Compiled`, used for reporting
+ cache stats."""
def __init__(
self,
@@ -543,7 +610,11 @@ class Compiled:
return self.string or ""
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
"""Return the bind params for this compiled object.
:param params: a dict of string/object pairs whose values will
@@ -646,6 +717,17 @@ class SQLCompiler(Compiled):
isplaintext: bool = False
+ binds: Dict[str, BindParameter[Any]]
+ """a dictionary of bind parameter keys to BindParameter instances."""
+
+ bind_names: Dict[BindParameter[Any], str]
+ """a dictionary of BindParameter instances to "compiled" names
+ that are actually present in the generated SQL"""
+
+ stack: List[_CompilerStackEntry]
+ """major statements such as SELECT, INSERT, UPDATE, DELETE are
+ tracked in this stack using an entry format."""
+
result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -709,7 +791,7 @@ class SQLCompiler(Compiled):
"""
- insert_single_values_expr = None
+ insert_single_values_expr: Optional[str] = None
"""When an INSERT is compiled with a single set of parameters inside
a VALUES expression, the string is assigned here, where it can be
used for insert batching schemes to rewrite the VALUES expression.
@@ -718,19 +800,19 @@ class SQLCompiler(Compiled):
"""
- literal_execute_params = frozenset()
+ literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
execution time.
"""
- post_compile_params = frozenset()
+ post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as bound parameter placeholders
at statement execution time.
"""
- escaped_bind_names = util.EMPTY_DICT
+ escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
"""Late escaping of bound parameter names that has to be converted
to the original name when looking in the parameter dictionary.
@@ -744,14 +826,25 @@ class SQLCompiler(Compiled):
"""if True, and this in insert, use cursor.lastrowid to populate
result.inserted_primary_key. """
- _cache_key_bind_match = None
+ _cache_key_bind_match: Optional[
+ Tuple[
+ Dict[
+ BindParameter[Any],
+ List[BindParameter[Any]],
+ ],
+ Dict[
+ str,
+ BindParameter[Any],
+ ],
+ ]
+ ] = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
in the cache key, if we were given a cache key.
"""
- positiontup: Optional[Sequence[str]] = None
+ positiontup: Optional[List[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -768,6 +861,19 @@ class SQLCompiler(Compiled):
inline: bool = False
+ ctes: Optional[MutableMapping[CTE, str]]
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+ ctes_recursive: bool
+ cte_positional: Dict[CTE, List[str]]
+
def __init__(
self,
dialect,
@@ -804,10 +910,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
if cache_key:
- self._cache_key_bind_match = ckbm = {
- b.key: b for b in cache_key[1]
- }
- ckbm.update({b: [b] for b in cache_key[1]})
+ cksm = {b.key: b for b in cache_key[1]}
+ ckbm = {b: [b] for b in cache_key[1]}
+ self._cache_key_bind_match = (ckbm, cksm)
# compile INSERT/UPDATE defaults/sequences to expect executemany
# style execution, which may mean no pre-execute of defaults,
@@ -911,14 +1016,14 @@ class SQLCompiler(Compiled):
@property
def prefetch(self):
- return list(self.insert_prefetch + self.update_prefetch)
+ return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
def _global_attributes(self):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self) -> None:
+ def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -926,21 +1031,22 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ self.ctes = ctes
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+ self.ctes_by_level_name = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name, cte_opts)]
- self.level_name_by_cte: Dict[
- CTE, Tuple[int, str, selectable._CTEOpts]
- ] = {}
+ self.level_name_by_cte = {}
- self.ctes_recursive: bool = False
+ self.ctes_recursive = False
if self.positional:
- self.cte_positional: Dict[CTE, List[str]] = {}
+ self.cte_positional = {}
+
+ return ctes
@contextlib.contextmanager
def _nested_result(self):
@@ -985,7 +1091,7 @@ class SQLCompiler(Compiled):
if not bindparam.type._is_tuple_type
else tuple(
elem_type._cached_bind_processor(self.dialect)
- for elem_type in bindparam.type.types
+ for elem_type in cast(TupleType, bindparam.type).types
),
)
for bindparam in self.bind_names
@@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled):
def construct_params(
self,
- params=None,
- _group_number=None,
- _check=True,
- extracted_parameters=None,
- ):
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ _group_number: Optional[int] = None,
+ _check: bool = True,
+ ) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
has_escaped_names = bool(self.escaped_bind_names)
@@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled):
# way. The parameters present in self.bind_names may be clones of
# these original cache key params in the case of DML but the .key
# will be guaranteed to match.
- try:
- orig_extracted = self.cache_key[1]
- except TypeError as err:
+ if self.cache_key is None:
raise exc.CompileError(
"This compiled object has no original cache key; "
"can't pass extracted_parameters to construct_params"
- ) from err
+ )
+ else:
+ orig_extracted = self.cache_key[1]
- ckbm = self._cache_key_bind_match
+ ckbm_tuple = self._cache_key_bind_match
+ assert ckbm_tuple is not None
+ ckbm, _ = ckbm_tuple
resolved_extracted = {
bind: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
@@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ)
+ for typ in cast(TupleType, bindparam.type).types
]
else:
inputsizes[bindparam] = lookup_type(bindparam.type)
@@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled):
def _process_parameters_for_postcompile(
self,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_MutableCoreSingleExecuteParams] = None,
_populate_self: bool = False,
) -> ExpandedState:
"""handle special post compile parameters.
@@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled):
parameters = self.construct_params()
expanded_parameters = {}
+ positiontup: Optional[List[str]]
+
if self.positional:
positiontup = []
else:
positiontup = None
processors = self._bind_processors
+ single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ tuple_processors = cast(
+ "Mapping[str, Sequence[_ProcessorType]]", processors
+ )
- new_processors = {}
+ new_processors: Dict[str, _ProcessorType] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled):
"the 'numeric' paramstyle at this time."
)
- replacement_expressions = {}
- to_update_sets = {}
+ replacement_expressions: Dict[str, Any] = {}
+ to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
@@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled):
# *escaped* parameter names in:
# construct_params(), replacement_expressions
- for name in (
- self.positiontup if self.positional else self.bind_names.values()
- ):
+ if self.positional and self.positiontup is not None:
+ names: Iterable[str] = self.positiontup
+ else:
+ names = self.bind_names.values()
+
+ for name in names:
escaped_name = (
self.escaped_bind_names.get(name, name)
if self.escaped_bind_names
@@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled):
if parameter in self.post_compile_params:
if escaped_name in replacement_expressions:
to_update = to_update_sets[escaped_name]
+ values = None
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled):
if not parameter.literal_execute:
parameters.update(to_update)
if parameter.type._is_tuple_type:
+ assert values is not None
new_processors.update(
(
"%s_%s_%s" % (name, i, j),
- processors[name][j - 1],
+ tuple_processors[name][j - 1],
)
for i, tuple_element in enumerate(values, 1)
- for j, value in enumerate(tuple_element, 1)
- if name in processors
- and processors[name][j - 1] is not None
+ for j, _ in enumerate(tuple_element, 1)
+ if name in tuple_processors
+ and tuple_processors[name][j - 1] is not None
)
else:
new_processors.update(
- (key, processors[name])
- for key, value in to_update
- if name in processors
+ (key, single_processors[name])
+ for key, _ in to_update
+ if name in single_processors
)
- if self.positional:
- positiontup.extend(name for name, value in to_update)
+ if positiontup is not None:
+ positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
- expand_key for expand_key, value in to_update
+ expand_key for expand_key, _ in to_update
]
- elif self.positional:
+ elif positiontup is not None:
positiontup.append(name)
def process_expanding(m):
@@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled):
# special use cases.
self.string = expanded_state.statement
self._bind_processors.update(expanded_state.processors)
- self.positiontup = expanded_state.positiontup
+ self.positiontup = list(expanded_state.positiontup or ())
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
@@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
+ _key_getters_for_crud_column: Tuple[
+ Callable[[Union[str, Column[Any]]], str],
+ Callable[[Column[Any]], str],
+ Callable[[Column[Any]], str],
+ ]
+
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
@@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled):
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
- result = util.preloaded.engine_result
+ if typing.TYPE_CHECKING:
+ from ..engine import result
+ else:
+ result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
- ret = {col: idx for idx, col in enumerate(self.returning)}
+ returning = self.returning
+ assert returning is not None
+ ret = {col: idx for idx, col in enumerate(returning)}
- getters = [
- (operator.itemgetter(ret[col]), True)
- if col in ret
- else (
- operator.methodcaller("get", param_key_getter(col), None),
- False,
- )
- for col in table.primary_key
- ]
+ getters = cast(
+ "List[Tuple[Callable[[Any], Any], bool]]",
+ [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ],
+ )
row_fn = result.result_tuple([col.key for col in table.primary_key])
@@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ raise exc.CompileError(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ) from ke
(
with_cols,
@@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=ke,
+ )
+
with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
@@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled):
def visit_column(
self,
- column,
- add_to_result_map=None,
- include_table=True,
- result_map_targets=(),
- ambiguous_table_name_map=None,
- **kwargs,
- ):
+ column: ColumnClause[Any],
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ include_table: bool = True,
+ result_map_targets: Tuple[Any, ...] = (),
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ **kwargs: Any,
+ ) -> str:
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled):
)
else:
schema_prefix = ""
- tablename = table.name
+
+ tablename = cast("NamedFromClause", table).name
if (
not effective_schema
@@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"correlate_froms": set(),
"asfrom_froms": set(),
"selectable": taf,
@@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled):
compiled_col = self.visit_column(element, **kw)
return "(%s).%s" % (compiled_fn, compiled_col)
- def visit_function(self, func, add_to_result_map=None, **kwargs):
+ def visit_function(
+ self,
+ func: Function,
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ **kwargs: Any,
+ ) -> str:
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+ text: str
+
if disp:
text = disp(func, **kwargs)
else:
@@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled):
if compound_stmt._independent_ctes:
self._dispatch_independent_ctes(compound_stmt, kwargs)
- keyword = self.compound_keywords.get(cs.keyword)
+ keyword = self.compound_keywords[cs.keyword]
text = (" " + keyword + " ").join(
(
@@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled):
# a different set of parameter values. here, we accommodate for
# parameters that may have been cloned both before and after the cache
# key was been generated.
- ckbm = self._cache_key_bind_match
- if ckbm:
+ ckbm_tuple = self._cache_key_bind_match
+
+ if ckbm_tuple:
+ ckbm, cksm = ckbm_tuple
for bp in bindparam._cloned_set:
- if bp.key in ckbm:
- cb = ckbm[bp.key]
+ if bp.key in cksm:
+ cb = cksm[bp.key]
ckbm[cb].append(bindparam)
if bindparam.isoutparam:
@@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled):
if positional_names is not None:
positional_names.append(name)
else:
- self.positiontup.append(name)
+ self.positiontup.append(name) # type: ignore[union-attr]
elif not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled):
name = new_name
if escaped_from:
- if not self.escaped_bind_names:
- self.escaped_bind_names = {}
- self.escaped_bind_names[escaped_from] = name
+ self.escaped_bind_names = self.escaped_bind_names.union(
+ {escaped_from: name}
+ )
if post_compile:
return "__[POSTCOMPILE_%s]" % name
@@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled):
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
- self._init_cte_state()
+ self_ctes = self._init_cte_state()
+ assert self_ctes is self.ctes
kwargs["visiting_cte"] = cte
@@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled):
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte]
+ del self_ctes[existing_cte]
existing_cte_reference_cte = existing_cte._get_reference_cte()
@@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled):
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
- if not cte_pre_alias_name and cte not in self.ctes:
+ if not cte_pre_alias_name and cte not in self_ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
@@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled):
cte, cte._suffixes, **kwargs
)
- self.ctes[cte] = text
+ self_ctes[cte] = text
if asfrom:
if from_linter:
from_linter.froms[cte] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
- return self.preparer.format_alias(cte, cte_name)
+ return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
+ return None
+
def visit_table_valued_alias(self, element, **kw):
if element._is_lateral:
return self.visit_lateral(element, **kw)
@@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled):
self,
keyname: str,
name: str,
- objects: List[Any],
+ objects: Tuple[Any, ...],
type_: TypeEngine[Any],
) -> None:
if keyname is None or keyname == "*":
@@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled):
def get_statement_hint_text(self, hint_texts):
return " ".join(hint_texts)
- _default_stack_entry = util.immutabledict(
- [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
- )
+ _default_stack_entry: _CompilerStackEntry
+
+ if not typing.TYPE_CHECKING:
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(
self, select_stmt, asfrom, lateral=False, **kw
@@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled):
)
return froms
- translate_select_structure = None
+ translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object. this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes
@@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled):
)
self._result_columns = [
- (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ ResultColumnsEntry(
+ key, name, tuple(translate.get(o, o) for o in obj), type_
+ )
for key, name, obj, type_ in self._result_columns
]
@@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled):
implicit_correlate_froms=asfrom_froms,
)
- new_correlate_froms = set(selectable._from_objects(*froms))
+ new_correlate_froms = set(_from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
@@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled):
text += " \nWHERE " + t
if warn_linting:
+ assert from_linter is not None
from_linter.warn()
if select._group_by_clauses:
@@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled):
if not self.ctes:
return ""
+ ctes: MutableMapping[CTE, str]
+
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
@@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled):
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
+ assert self.positiontup is not None
self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
+ list(
+ itertools.chain.from_iterable(
+ self.cte_positional[cte] for cte in ctes
+ )
+ )
+ self.positiontup
)
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled):
if is_multitable:
# main table might be a JOIN
- main_froms = set(selectable._from_objects(update_stmt.table))
+ main_froms = set(_from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
@@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
return None
def visit_ddl(self, ddl, **kwargs):
@@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
return get_col_spec(**kw)
+class _SchemaForObjectCallable(Protocol):
+ def __call__(self, obj: Any) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
@@ -5209,7 +5404,13 @@ class IdentifierPreparer:
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = operator.attrgetter("schema")
+ initial_quote: str
+
+ final_quote: str
+
+ _strings: MutableMapping[str, str]
+
+ schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
"""Return the .schema attribute for an object.
For the default IdentifierPreparer, the schema for an object is always
@@ -5297,7 +5498,7 @@ class IdentifierPreparer:
return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
- def _escape_identifier(self, value):
+ def _escape_identifier(self, value: str) -> str:
"""Escape an identifier.
Subclasses should override this to provide database-dependent
@@ -5309,7 +5510,7 @@ class IdentifierPreparer:
value = value.replace("%", "%%")
return value
- def _unescape_identifier(self, value):
+ def _unescape_identifier(self, value: str) -> str:
"""Canonicalize an escaped identifier.
Subclasses should override this to provide database-dependent
@@ -5336,7 +5537,7 @@ class IdentifierPreparer:
)
return element
- def quote_identifier(self, value):
+ def quote_identifier(self, value: str) -> str:
"""Quote an identifier.
Subclasses should override this to provide database-dependent
@@ -5349,7 +5550,7 @@ class IdentifierPreparer:
+ self.final_quote
)
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
@@ -5364,7 +5565,7 @@ class IdentifierPreparer:
not taking case convention into account."""
return not self.legal_characters.match(str(value))
- def quote_schema(self, schema, force=None):
+ def quote_schema(self, schema: str, force: Any = None) -> str:
"""Conditionally quote a schema name.
@@ -5403,7 +5604,7 @@ class IdentifierPreparer:
return self.quote(schema)
- def quote(self, ident, force=None):
+ def quote(self, ident: str, force: Any = None) -> str:
"""Conditionally quote an identifier.
The identifier is quoted if it is a reserved word, contains
@@ -5474,11 +5675,19 @@ class IdentifierPreparer:
name = self.quote_schema(effective_schema) + "." + name
return name
- def format_label(self, label, name=None):
+ def format_label(
+ self, label: Label[Any], name: Optional[str] = None
+ ) -> str:
return self.quote(name or label.name)
- def format_alias(self, alias, name=None):
- return self.quote(name or alias.name)
+ def format_alias(
+ self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+ ) -> str:
+ if name is None:
+ assert alias is not None
+ return self.quote(alias.name)
+ else:
+ return self.quote(name)
def format_savepoint(self, savepoint, name=None):
# Running the savepoint name through quoting is unnecessary