diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 231 |
1 files changed, 174 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cdce49f7b..80711c4b5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -13,10 +13,20 @@ from __future__ import annotations from collections import deque from itertools import chain import typing +from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast +from typing import Dict from typing import Iterator +from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from . import coercions from . import operators @@ -49,11 +59,22 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .visitors import _ET from .. import exc from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _TypeEngineArgument from .roles import FromClauseRole + from .selectable import _JoinTargetElement + from .selectable import _OnClauseElement + from .selectable import Selectable + from .visitors import _TraverseCallableType + from .visitors import ExternallyTraversible + from .visitors import ExternalTraversal from ..engine.interfaces import _AnyExecuteParams from ..engine.interfaces import _AnyMultiExecuteParams from ..engine.interfaces import _AnySingleExecuteParams @@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from): return liberal_idx -def find_left_clause_to_join_from(clauses, join_to, onclause): +def find_left_clause_to_join_from( + clauses: Sequence[FromClause], + join_to: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], +) -> List[int]: """Given a list of FROM clauses, a selectable, and optional ON clause, return a list of integer indexes from the clauses list indicating the clauses that can be joined from. @@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): for i, f in enumerate(clauses): for s in selectables.difference([f]): if resolve_ambiguity: + assert cols_in_onclause is not None if set(f.c).union(s.c).issuperset(cols_in_onclause): idx.append(i) break @@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): # onclause was given and none of them resolved, so assume # all indexes can match if not idx and onclause is not None: - return range(len(clauses)) + return list(range(len(clauses))) else: return idx @@ -247,7 +273,7 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack = [] + stack: List[ClauseElement] = [] def visit(element): if isinstance(element, ScalarSelect): @@ -272,21 +298,22 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) - visit = None # remove gc cycles + visit = None # type: ignore # remove gc cycles def find_tables( - clause, - check_columns=False, - include_aliases=False, - include_joins=False, - include_selects=False, - include_crud=False, -): + clause: ClauseElement, + *, + check_columns: bool = False, + include_aliases: bool = False, + include_joins: bool = False, + include_selects: bool = False, + include_crud: bool = False, +) -> List[TableClause]: """locate Table objects within the given expression.""" - tables = [] - _visitors = {} + tables: List[TableClause] = [] + _visitors: Dict[str, _TraverseCallableType[Any]] = {} if include_selects: _visitors["select"] = _visitors["compound_select"] = tables.append @@ -335,7 +362,7 @@ def unwrap_order_by(clause): t = stack.popleft() if isinstance(t, ColumnElement) and ( not isinstance(t, UnaryExpression) - or not operators.is_ordering_modifier(t.modifier) + or not operators.is_ordering_modifier(t.modifier) # type: ignore ): if isinstance(t, Label) and not isinstance( t.element, ScalarSelect @@ -365,9 +392,14 @@ def unwrap_order_by(clause): def unwrap_label_reference(element): - def replace(elem): - if isinstance(elem, (_label_reference, _textual_label_reference)): - return elem.element + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + if isinstance(element, _label_reference): + return element.element + elif isinstance(element, _textual_label_reference): + assert False, "can't unwrap a textual label reference" + return None return visitors.replacement_traverse(element, {}, replace) @@ -407,7 +439,7 @@ def clause_is_present(clause, search): return False -def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]: +def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): for t in tables_from_leftmost(clause.left): yield t @@ -509,6 +541,8 @@ class _repr_base: __slots__ = ("max_chars",) + max_chars: int + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) @@ -612,7 +646,7 @@ class _repr_params(_repr_base): def _repr_multi( self, multi_params: _AnyMultiExecuteParams, - typ, + typ: int, ) -> str: if multi_params: if isinstance(multi_params[0], list): @@ -639,7 +673,7 @@ class _repr_params(_repr_base): def _repr_params( self, - params: Optional[_AnySingleExecuteParams], + params: _AnySingleExecuteParams, typ: int, ) -> str: trunc = self.trunc @@ -653,9 +687,10 @@ class _repr_params(_repr_base): ) ) elif typ is self._TUPLE: + seq_params = cast("Sequence[Any]", params) return "(%s%s)" % ( - ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "", + ", ".join(trunc(value) for value in seq_params), + "," if len(seq_params) == 1 else "", ) else: return "[%s]" % (", ".join(trunc(value) for value in params)) @@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) -def splice_joins(left, right, stop_on=None): +def splice_joins( + left: Optional[FromClause], + right: Optional[FromClause], + stop_on: Optional[FromClause] = None, +) -> Optional[FromClause]: if left is None: return right - stack = [(right, None)] + stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)] adapter = ClauseAdapter(left) ret = None @@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None): else: right = adapter.traverse(right) if prevright is not None: + assert right is not None prevright.left = right if ret is None: ret = right @@ -845,11 +885,14 @@ def criterion_as_pairs( elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) - pairs = [] + pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = [] visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs +_CE = TypeVar("_CE", bound="ClauseElement") + + class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. @@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, - selectable, - equivalents=None, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): self.__traverse_options__ = { "stop_on": [selectable], @@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): self.adapt_on_names = adapt_on_names self.adapt_from_selectables = adapt_from_selectables + if TYPE_CHECKING: + + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + # note this specializes the ReplacingExternalTraversal.traverse() + # method to state + # that we will return the same kind of ExternalTraversal object as + # we were given. This is probably not 100% true, such as it's + # possible for us to swap out Alias for Table at the top level. + # Ideally there could be overloads specific to ColumnElement and + # FromClause but Mypy is not accepting those as compatible with + # the base ReplacingExternalTraversal + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: + ... + def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET ): @@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return newcol @util.preload_module("sqlalchemy.sql.functions") - def replace(self, col, _include_singleton_constants=False): + def replace( + self, col: _ET, _include_singleton_constants: bool = False + ) -> Optional[_ET]: functions = util.preloaded.sql_functions + # TODO: cython candidate + if isinstance(col, FromClause) and not isinstance( col, functions.FunctionElement ): @@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): break else: return None - return self.selectable + return self.selectable # type: ignore elif isinstance(col, Alias) and isinstance( col.element, TableClause ): @@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # we are an alias of a table and we are not derived from an # alias of a table (which nonetheless may be the same table # as ours) so, same thing - return col + return col # type: ignore else: # other cases where we are a selectable and the element # is another join or selectable that contains a table which our @@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): else: return None + if TYPE_CHECKING: + assert isinstance(col, ColumnElement) + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None else: - return self._corresponding_column(col, True) + return self._corresponding_column(col, True) # type: ignore + + +class _ColumnLookup(Protocol): + def __getitem__( + self, key: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + ... class ColumnAdapter(ClauseAdapter): @@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter): """ + columns: _ColumnLookup + def __init__( self, - selectable, - equivalents=None, - adapt_required=False, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + adapt_required: bool = False, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): ClauseAdapter.__init__( self, @@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter): adapt_from_selectables=adapt_from_selectables, ) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore if self.include_fn or self.exclude_fn: self.columns = self._IncludeExcludeMapping(self, self.columns) self.adapt_required = adapt_required @@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter): ac = self.__class__.__new__(self.__class__) ac.__dict__.update(self.__dict__) ac._wrap = adapter - ac.columns = util.WeakPopulateDict(ac._locate_col) + ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore if ac.include_fn or ac.exclude_fn: ac.columns = self._IncludeExcludeMapping(ac, ac.columns) @@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter): def traverse(self, obj): return self.columns[obj] + def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: + assert isinstance(visitor, ColumnAdapter) + + return super().chain(visitor) + + if TYPE_CHECKING: + + @property + def visitor_iterator(self) -> Iterator[ColumnAdapter]: + ... + adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process @@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter): return newcol - def _locate_col(self, col): + def _locate_col( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: # both replace and traverse() are overly complicated for what # we are doing here and we would do better to have an inlined # version that doesn't build up as much overhead. the issue is that @@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore -def _offset_or_limit_clause(element, name=None, type_=None): +def _offset_or_limit_clause( + element: Union[int, _ColumnExpressionArgument[int]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, +) -> ColumnElement[int]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None): ) -def _offset_or_limit_clause_asint_if_possible(clause): +def _offset_or_limit_clause_asint_if_possible( + clause: Optional[Union[int, _ColumnExpressionArgument[int]]] +) -> Optional[Union[int, _ColumnExpressionArgument[int]]]: """Return the offset or limit clause as a simple integer if possible, else return the clause. @@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause): if clause is None: return None if hasattr(clause, "_limit_offset_value"): - value = clause._limit_offset_value + value = clause._limit_offset_value # type: ignore return util.asint(value) else: return clause -def _make_slice(limit_clause, offset_clause, start, stop): +def _make_slice( + limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + start: int, + stop: int, +) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]: """Compute LIMIT/OFFSET in terms of slice start/end""" # for calculated limit/offset, try to do the addition of # values to offset in Python, however if a SQL clause is present # then the addition has to be on the SQL side. + + # TODO: typing is finding a few gaps in here, see if they can be + # closed up + if start is not None and stop is not None: offset_clause = _offset_or_limit_clause_asint_if_possible( offset_clause @@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: + assert offset_clause is not None offset_clause = _offset_or_limit_clause(offset_clause) limit_clause = _offset_or_limit_clause(stop - start) @@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: - offset_clause = _offset_or_limit_clause(offset_clause) + offset_clause = _offset_or_limit_clause( + offset_clause # type: ignore + ) - return limit_clause, offset_clause + return limit_clause, offset_clause # type: ignore |