summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py447
1 files changed, 328 insertions, 119 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 423c3d446..f28dceefc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -35,14 +35,19 @@ from time import perf_counter
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import Union
from . import base
@@ -54,19 +59,42 @@ from . import operators
from . import schema
from . import selectable
from . import sqltypes
+from .base import _from_objects
from .base import NO_ARG
-from .base import prefix_anon_map
from .elements import quoted_name
from .schema import Column
+from .sqltypes import TupleType
from .type_api import TypeEngine
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .base import _AmbiguousTableNameMap
+ from .base import CompileState
+ from .cache_key import CacheKey
+ from .elements import BindParameter
+ from .elements import ColumnClause
+ from .elements import Label
+ from .functions import Function
+ from .selectable import Alias
+ from .selectable import AliasedReturnsRows
+ from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectState
+ from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _MutableCoreSingleExecuteParams
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -236,7 +264,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple):
name: str
"""column name, may be labeled"""
- objects: List[Any]
- """list of objects that should be able to locate this column
+ objects: Tuple[Any, ...]
+ """sequence of objects that should be able to locate this column
in a RowMapping. This is typically string names and aliases
as well as Column objects.
@@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple):
"""
+class _ResultMapAppender(Protocol):
+ def __call__(
+ self,
+ keyname: str,
+ name: str,
+ objects: Sequence[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
+ ...
+
+
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0
@@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
+class _BaseCompilerStackEntry(TypedDict):
+ asfrom_froms: Set[FromClause]
+ correlate_froms: Set[FromClause]
+ selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+ compile_state: CompileState
+ need_result_map_for_nested: bool
+ need_result_map_for_compound: bool
+ select_0: ReturnsRows
+ insert_from_select: Select
+
+
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
@@ -427,21 +480,23 @@ class Compiled:
defaults.
"""
- _cached_metadata = None
+ _cached_metadata: Optional[CursorResultMetaData] = None
_result_columns: Optional[List[ResultColumnsEntry]] = None
- schema_translate_map = None
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None
- execution_options = util.EMPTY_DICT
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
- _annotations = util.EMPTY_DICT
+ preparer: IdentifierPreparer
+
+ _annotations: _AnnotationDict = util.EMPTY_DICT
- compile_state = None
+ compile_state: Optional[CompileState] = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
@@ -457,9 +512,21 @@ class Compiled:
"""
- cache_key = None
+ cache_key: Optional[CacheKey] = None
+ """The :class:`.CacheKey` that was generated ahead of creating this
+ :class:`.Compiled` object.
+
+ This is used for routines that need access to the original
+ :class:`.CacheKey` instance generated when the :class:`.Compiled`
+ instance was first cached, typically in order to reconcile
+ the original list of :class:`.BindParameter` objects with a
+ per-statement list that's generated on each call.
+
+ """
_gen_time: float
+ """Generation time of this :class:`.Compiled`, used for reporting
+ cache stats."""
def __init__(
self,
@@ -543,7 +610,11 @@ class Compiled:
return self.string or ""
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
"""Return the bind params for this compiled object.
:param params: a dict of string/object pairs whose values will
@@ -646,6 +717,17 @@ class SQLCompiler(Compiled):
isplaintext: bool = False
+ binds: Dict[str, BindParameter[Any]]
+ """a dictionary of bind parameter keys to BindParameter instances."""
+
+ bind_names: Dict[BindParameter[Any], str]
+ """a dictionary of BindParameter instances to "compiled" names
+ that are actually present in the generated SQL"""
+
+ stack: List[_CompilerStackEntry]
+ """major statements such as SELECT, INSERT, UPDATE, DELETE are
+ tracked in this stack using an entry format."""
+
result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -709,7 +791,7 @@ class SQLCompiler(Compiled):
"""
- insert_single_values_expr = None
+ insert_single_values_expr: Optional[str] = None
"""When an INSERT is compiled with a single set of parameters inside
a VALUES expression, the string is assigned here, where it can be
used for insert batching schemes to rewrite the VALUES expression.
@@ -718,19 +800,19 @@ class SQLCompiler(Compiled):
"""
- literal_execute_params = frozenset()
+ literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
execution time.
"""
- post_compile_params = frozenset()
+ post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as bound parameter placeholders
at statement execution time.
"""
- escaped_bind_names = util.EMPTY_DICT
+ escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
"""Late escaping of bound parameter names that has to be converted
to the original name when looking in the parameter dictionary.
@@ -744,14 +826,25 @@ class SQLCompiler(Compiled):
"""if True, and this in insert, use cursor.lastrowid to populate
result.inserted_primary_key. """
- _cache_key_bind_match = None
+ _cache_key_bind_match: Optional[
+ Tuple[
+ Dict[
+ BindParameter[Any],
+ List[BindParameter[Any]],
+ ],
+ Dict[
+ str,
+ BindParameter[Any],
+ ],
+ ]
+ ] = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
in the cache key, if we were given a cache key.
"""
- positiontup: Optional[Sequence[str]] = None
+ positiontup: Optional[List[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -768,6 +861,19 @@ class SQLCompiler(Compiled):
inline: bool = False
+ ctes: Optional[MutableMapping[CTE, str]]
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+ ctes_recursive: bool
+ cte_positional: Dict[CTE, List[str]]
+
def __init__(
self,
dialect,
@@ -804,10 +910,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
if cache_key:
- self._cache_key_bind_match = ckbm = {
- b.key: b for b in cache_key[1]
- }
- ckbm.update({b: [b] for b in cache_key[1]})
+ cksm = {b.key: b for b in cache_key[1]}
+ ckbm = {b: [b] for b in cache_key[1]}
+ self._cache_key_bind_match = (ckbm, cksm)
# compile INSERT/UPDATE defaults/sequences to expect executemany
# style execution, which may mean no pre-execute of defaults,
@@ -911,14 +1016,14 @@ class SQLCompiler(Compiled):
@property
def prefetch(self):
- return list(self.insert_prefetch + self.update_prefetch)
+ return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
def _global_attributes(self):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self) -> None:
+ def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -926,21 +1031,22 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ self.ctes = ctes
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+ self.ctes_by_level_name = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name, cte_opts)]
- self.level_name_by_cte: Dict[
- CTE, Tuple[int, str, selectable._CTEOpts]
- ] = {}
+ self.level_name_by_cte = {}
- self.ctes_recursive: bool = False
+ self.ctes_recursive = False
if self.positional:
- self.cte_positional: Dict[CTE, List[str]] = {}
+ self.cte_positional = {}
+
+ return ctes
@contextlib.contextmanager
def _nested_result(self):
@@ -985,7 +1091,7 @@ class SQLCompiler(Compiled):
if not bindparam.type._is_tuple_type
else tuple(
elem_type._cached_bind_processor(self.dialect)
- for elem_type in bindparam.type.types
+ for elem_type in cast(TupleType, bindparam.type).types
),
)
for bindparam in self.bind_names
@@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled):
def construct_params(
self,
- params=None,
- _group_number=None,
- _check=True,
- extracted_parameters=None,
- ):
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ _group_number: Optional[int] = None,
+ _check: bool = True,
+ ) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
has_escaped_names = bool(self.escaped_bind_names)
@@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled):
# way. The parameters present in self.bind_names may be clones of
# these original cache key params in the case of DML but the .key
# will be guaranteed to match.
- try:
- orig_extracted = self.cache_key[1]
- except TypeError as err:
+ if self.cache_key is None:
raise exc.CompileError(
"This compiled object has no original cache key; "
"can't pass extracted_parameters to construct_params"
- ) from err
+ )
+ else:
+ orig_extracted = self.cache_key[1]
- ckbm = self._cache_key_bind_match
+ ckbm_tuple = self._cache_key_bind_match
+ assert ckbm_tuple is not None
+ ckbm, _ = ckbm_tuple
resolved_extracted = {
bind: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
@@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ)
+ for typ in cast(TupleType, bindparam.type).types
]
else:
inputsizes[bindparam] = lookup_type(bindparam.type)
@@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled):
def _process_parameters_for_postcompile(
self,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_MutableCoreSingleExecuteParams] = None,
_populate_self: bool = False,
) -> ExpandedState:
"""handle special post compile parameters.
@@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled):
parameters = self.construct_params()
expanded_parameters = {}
+ positiontup: Optional[List[str]]
+
if self.positional:
positiontup = []
else:
positiontup = None
processors = self._bind_processors
+ single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ tuple_processors = cast(
+ "Mapping[str, Sequence[_ProcessorType]]", processors
+ )
- new_processors = {}
+ new_processors: Dict[str, _ProcessorType] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled):
"the 'numeric' paramstyle at this time."
)
- replacement_expressions = {}
- to_update_sets = {}
+ replacement_expressions: Dict[str, Any] = {}
+ to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
@@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled):
# *escaped* parameter names in:
# construct_params(), replacement_expressions
- for name in (
- self.positiontup if self.positional else self.bind_names.values()
- ):
+ if self.positional and self.positiontup is not None:
+ names: Iterable[str] = self.positiontup
+ else:
+ names = self.bind_names.values()
+
+ for name in names:
escaped_name = (
self.escaped_bind_names.get(name, name)
if self.escaped_bind_names
@@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled):
if parameter in self.post_compile_params:
if escaped_name in replacement_expressions:
to_update = to_update_sets[escaped_name]
+ values = None
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled):
if not parameter.literal_execute:
parameters.update(to_update)
if parameter.type._is_tuple_type:
+ assert values is not None
new_processors.update(
(
"%s_%s_%s" % (name, i, j),
- processors[name][j - 1],
+ tuple_processors[name][j - 1],
)
for i, tuple_element in enumerate(values, 1)
- for j, value in enumerate(tuple_element, 1)
- if name in processors
- and processors[name][j - 1] is not None
+ for j, _ in enumerate(tuple_element, 1)
+ if name in tuple_processors
+ and tuple_processors[name][j - 1] is not None
)
else:
new_processors.update(
- (key, processors[name])
- for key, value in to_update
- if name in processors
+ (key, single_processors[name])
+ for key, _ in to_update
+ if name in single_processors
)
- if self.positional:
- positiontup.extend(name for name, value in to_update)
+ if positiontup is not None:
+ positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
- expand_key for expand_key, value in to_update
+ expand_key for expand_key, _ in to_update
]
- elif self.positional:
+ elif positiontup is not None:
positiontup.append(name)
def process_expanding(m):
@@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled):
# special use cases.
self.string = expanded_state.statement
self._bind_processors.update(expanded_state.processors)
- self.positiontup = expanded_state.positiontup
+ self.positiontup = list(expanded_state.positiontup or ())
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
@@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
+ _key_getters_for_crud_column: Tuple[
+ Callable[[Union[str, Column[Any]]], str],
+ Callable[[Column[Any]], str],
+ Callable[[Column[Any]], str],
+ ]
+
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
@@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled):
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
- result = util.preloaded.engine_result
+ if typing.TYPE_CHECKING:
+ from ..engine import result
+ else:
+ result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
- ret = {col: idx for idx, col in enumerate(self.returning)}
+ returning = self.returning
+ assert returning is not None
+ ret = {col: idx for idx, col in enumerate(returning)}
- getters = [
- (operator.itemgetter(ret[col]), True)
- if col in ret
- else (
- operator.methodcaller("get", param_key_getter(col), None),
- False,
- )
- for col in table.primary_key
- ]
+ getters = cast(
+ "List[Tuple[Callable[[Any], Any], bool]]",
+ [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ],
+ )
row_fn = result.result_tuple([col.key for col in table.primary_key])
@@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ raise exc.CompileError(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ) from ke
(
with_cols,
@@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=ke,
+ )
+
with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
@@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled):
def visit_column(
self,
- column,
- add_to_result_map=None,
- include_table=True,
- result_map_targets=(),
- ambiguous_table_name_map=None,
- **kwargs,
- ):
+ column: ColumnClause[Any],
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ include_table: bool = True,
+ result_map_targets: Tuple[Any, ...] = (),
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ **kwargs: Any,
+ ) -> str:
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled):
)
else:
schema_prefix = ""
- tablename = table.name
+
+ tablename = cast("NamedFromClause", table).name
if (
not effective_schema
@@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"correlate_froms": set(),
"asfrom_froms": set(),
"selectable": taf,
@@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled):
compiled_col = self.visit_column(element, **kw)
return "(%s).%s" % (compiled_fn, compiled_col)
- def visit_function(self, func, add_to_result_map=None, **kwargs):
+ def visit_function(
+ self,
+ func: Function,
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ **kwargs: Any,
+ ) -> str:
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+ text: str
+
if disp:
text = disp(func, **kwargs)
else:
@@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled):
if compound_stmt._independent_ctes:
self._dispatch_independent_ctes(compound_stmt, kwargs)
- keyword = self.compound_keywords.get(cs.keyword)
+ keyword = self.compound_keywords[cs.keyword]
text = (" " + keyword + " ").join(
(
@@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled):
# a different set of parameter values. here, we accommodate for
# parameters that may have been cloned both before and after the cache
# key was been generated.
- ckbm = self._cache_key_bind_match
- if ckbm:
+ ckbm_tuple = self._cache_key_bind_match
+
+ if ckbm_tuple:
+ ckbm, cksm = ckbm_tuple
for bp in bindparam._cloned_set:
- if bp.key in ckbm:
- cb = ckbm[bp.key]
+ if bp.key in cksm:
+ cb = cksm[bp.key]
ckbm[cb].append(bindparam)
if bindparam.isoutparam:
@@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled):
if positional_names is not None:
positional_names.append(name)
else:
- self.positiontup.append(name)
+ self.positiontup.append(name) # type: ignore[union-attr]
elif not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled):
name = new_name
if escaped_from:
- if not self.escaped_bind_names:
- self.escaped_bind_names = {}
- self.escaped_bind_names[escaped_from] = name
+ self.escaped_bind_names = self.escaped_bind_names.union(
+ {escaped_from: name}
+ )
if post_compile:
return "__[POSTCOMPILE_%s]" % name
@@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled):
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
- self._init_cte_state()
+ self_ctes = self._init_cte_state()
+ assert self_ctes is self.ctes
kwargs["visiting_cte"] = cte
@@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled):
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte]
+ del self_ctes[existing_cte]
existing_cte_reference_cte = existing_cte._get_reference_cte()
@@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled):
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
- if not cte_pre_alias_name and cte not in self.ctes:
+ if not cte_pre_alias_name and cte not in self_ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
@@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled):
cte, cte._suffixes, **kwargs
)
- self.ctes[cte] = text
+ self_ctes[cte] = text
if asfrom:
if from_linter:
from_linter.froms[cte] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
- return self.preparer.format_alias(cte, cte_name)
+ return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
+ return None
+
def visit_table_valued_alias(self, element, **kw):
if element._is_lateral:
return self.visit_lateral(element, **kw)
@@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled):
self,
keyname: str,
name: str,
- objects: List[Any],
+ objects: Tuple[Any, ...],
type_: TypeEngine[Any],
) -> None:
if keyname is None or keyname == "*":
@@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled):
def get_statement_hint_text(self, hint_texts):
return " ".join(hint_texts)
- _default_stack_entry = util.immutabledict(
- [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
- )
+ _default_stack_entry: _CompilerStackEntry
+
+ if not typing.TYPE_CHECKING:
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(
self, select_stmt, asfrom, lateral=False, **kw
@@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled):
)
return froms
- translate_select_structure = None
+ translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object. this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes
@@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled):
)
self._result_columns = [
- (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ ResultColumnsEntry(
+ key, name, tuple(translate.get(o, o) for o in obj), type_
+ )
for key, name, obj, type_ in self._result_columns
]
@@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled):
implicit_correlate_froms=asfrom_froms,
)
- new_correlate_froms = set(selectable._from_objects(*froms))
+ new_correlate_froms = set(_from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
@@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled):
text += " \nWHERE " + t
if warn_linting:
+ assert from_linter is not None
from_linter.warn()
if select._group_by_clauses:
@@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled):
if not self.ctes:
return ""
+ ctes: MutableMapping[CTE, str]
+
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
@@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled):
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
+ assert self.positiontup is not None
self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
+ list(
+ itertools.chain.from_iterable(
+ self.cte_positional[cte] for cte in ctes
+ )
+ )
+ self.positiontup
)
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled):
if is_multitable:
# main table might be a JOIN
- main_froms = set(selectable._from_objects(update_stmt.table))
+ main_froms = set(_from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
@@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
return None
def visit_ddl(self, ddl, **kwargs):
@@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
return get_col_spec(**kw)
+class _SchemaForObjectCallable(Protocol):
+ def __call__(self, obj: Any) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
@@ -5209,7 +5404,13 @@ class IdentifierPreparer:
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = operator.attrgetter("schema")
+ initial_quote: str
+
+ final_quote: str
+
+ _strings: MutableMapping[str, str]
+
+ schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
"""Return the .schema attribute for an object.
For the default IdentifierPreparer, the schema for an object is always
@@ -5297,7 +5498,7 @@ class IdentifierPreparer:
return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
- def _escape_identifier(self, value):
+ def _escape_identifier(self, value: str) -> str:
"""Escape an identifier.
Subclasses should override this to provide database-dependent
@@ -5309,7 +5510,7 @@ class IdentifierPreparer:
value = value.replace("%", "%%")
return value
- def _unescape_identifier(self, value):
+ def _unescape_identifier(self, value: str) -> str:
"""Canonicalize an escaped identifier.
Subclasses should override this to provide database-dependent
@@ -5336,7 +5537,7 @@ class IdentifierPreparer:
)
return element
- def quote_identifier(self, value):
+ def quote_identifier(self, value: str) -> str:
"""Quote an identifier.
Subclasses should override this to provide database-dependent
@@ -5349,7 +5550,7 @@ class IdentifierPreparer:
+ self.final_quote
)
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
@@ -5364,7 +5565,7 @@ class IdentifierPreparer:
not taking case convention into account."""
return not self.legal_characters.match(str(value))
- def quote_schema(self, schema, force=None):
+ def quote_schema(self, schema: str, force: Any = None) -> str:
"""Conditionally quote a schema name.
@@ -5403,7 +5604,7 @@ class IdentifierPreparer:
return self.quote(schema)
- def quote(self, ident, force=None):
+ def quote(self, ident: str, force: Any = None) -> str:
"""Conditionally quote an identifier.
The identifier is quoted if it is a reserved word, contains
@@ -5474,11 +5675,19 @@ class IdentifierPreparer:
name = self.quote_schema(effective_schema) + "." + name
return name
- def format_label(self, label, name=None):
+ def format_label(
+ self, label: Label[Any], name: Optional[str] = None
+ ) -> str:
return self.quote(name or label.name)
- def format_alias(self, alias, name=None):
- return self.quote(name or alias.name)
+ def format_alias(
+ self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+ ) -> str:
+ if name is None:
+ assert alias is not None
+ return self.quote(alias.name)
+ else:
+ return self.quote(name)
def format_savepoint(self, savepoint, name=None):
# Running the savepoint name through quoting is unnecessary