summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py231
1 files changed, 174 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cdce49f7b..80711c4b5 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -13,10 +13,20 @@ from __future__ import annotations
from collections import deque
from itertools import chain
import typing
+from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
+from typing import Dict
from typing import Iterator
+from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from . import coercions
from . import operators
@@ -49,11 +59,22 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .visitors import _ET
from .. import exc
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _TypeEngineArgument
from .roles import FromClauseRole
+ from .selectable import _JoinTargetElement
+ from .selectable import _OnClauseElement
+ from .selectable import Selectable
+ from .visitors import _TraverseCallableType
+ from .visitors import ExternallyTraversible
+ from .visitors import ExternalTraversal
from ..engine.interfaces import _AnyExecuteParams
from ..engine.interfaces import _AnyMultiExecuteParams
from ..engine.interfaces import _AnySingleExecuteParams
@@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from):
return liberal_idx
-def find_left_clause_to_join_from(clauses, join_to, onclause):
+def find_left_clause_to_join_from(
+ clauses: Sequence[FromClause],
+ join_to: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+) -> List[int]:
"""Given a list of FROM clauses, a selectable,
and optional ON clause, return a list of integer indexes from the
clauses list indicating the clauses that can be joined from.
@@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
for i, f in enumerate(clauses):
for s in selectables.difference([f]):
if resolve_ambiguity:
+ assert cols_in_onclause is not None
if set(f.c).union(s.c).issuperset(cols_in_onclause):
idx.append(i)
break
@@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
# onclause was given and none of them resolved, so assume
# all indexes can match
if not idx and onclause is not None:
- return range(len(clauses))
+ return list(range(len(clauses)))
else:
return idx
@@ -247,7 +273,7 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack = []
+ stack: List[ClauseElement] = []
def visit(element):
if isinstance(element, ScalarSelect):
@@ -272,21 +298,22 @@ def visit_binary_product(fn, expr):
yield e
list(visit(expr))
- visit = None # remove gc cycles
+ visit = None # type: ignore # remove gc cycles
def find_tables(
- clause,
- check_columns=False,
- include_aliases=False,
- include_joins=False,
- include_selects=False,
- include_crud=False,
-):
+ clause: ClauseElement,
+ *,
+ check_columns: bool = False,
+ include_aliases: bool = False,
+ include_joins: bool = False,
+ include_selects: bool = False,
+ include_crud: bool = False,
+) -> List[TableClause]:
"""locate Table objects within the given expression."""
- tables = []
- _visitors = {}
+ tables: List[TableClause] = []
+ _visitors: Dict[str, _TraverseCallableType[Any]] = {}
if include_selects:
_visitors["select"] = _visitors["compound_select"] = tables.append
@@ -335,7 +362,7 @@ def unwrap_order_by(clause):
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
- or not operators.is_ordering_modifier(t.modifier)
+ or not operators.is_ordering_modifier(t.modifier) # type: ignore
):
if isinstance(t, Label) and not isinstance(
t.element, ScalarSelect
@@ -365,9 +392,14 @@ def unwrap_order_by(clause):
def unwrap_label_reference(element):
- def replace(elem):
- if isinstance(elem, (_label_reference, _textual_label_reference)):
- return elem.element
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ if isinstance(element, _label_reference):
+ return element.element
+ elif isinstance(element, _textual_label_reference):
+ assert False, "can't unwrap a textual label reference"
+ return None
return visitors.replacement_traverse(element, {}, replace)
@@ -407,7 +439,7 @@ def clause_is_present(clause, search):
return False
-def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]:
+def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
if isinstance(clause, Join):
for t in tables_from_leftmost(clause.left):
yield t
@@ -509,6 +541,8 @@ class _repr_base:
__slots__ = ("max_chars",)
+ max_chars: int
+
def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
@@ -612,7 +646,7 @@ class _repr_params(_repr_base):
def _repr_multi(
self,
multi_params: _AnyMultiExecuteParams,
- typ,
+ typ: int,
) -> str:
if multi_params:
if isinstance(multi_params[0], list):
@@ -639,7 +673,7 @@ class _repr_params(_repr_base):
def _repr_params(
self,
- params: Optional[_AnySingleExecuteParams],
+ params: _AnySingleExecuteParams,
typ: int,
) -> str:
trunc = self.trunc
@@ -653,9 +687,10 @@ class _repr_params(_repr_base):
)
)
elif typ is self._TUPLE:
+ seq_params = cast("Sequence[Any]", params)
return "(%s%s)" % (
- ", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else "",
+ ", ".join(trunc(value) for value in seq_params),
+ "," if len(seq_params) == 1 else "",
)
else:
return "[%s]" % (", ".join(trunc(value) for value in params))
@@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
-def splice_joins(left, right, stop_on=None):
+def splice_joins(
+ left: Optional[FromClause],
+ right: Optional[FromClause],
+ stop_on: Optional[FromClause] = None,
+) -> Optional[FromClause]:
if left is None:
return right
- stack = [(right, None)]
+ stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]
adapter = ClauseAdapter(left)
ret = None
@@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None):
else:
right = adapter.traverse(right)
if prevright is not None:
+ assert right is not None
prevright.left = right
if ret is None:
ret = right
@@ -845,11 +885,14 @@ def criterion_as_pairs(
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
- pairs = []
+ pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []
visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
+_CE = TypeVar("_CE", bound="ClauseElement")
+
+
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
@@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
- selectable,
- equivalents=None,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
self.__traverse_options__ = {
"stop_on": [selectable],
@@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
self.adapt_on_names = adapt_on_names
self.adapt_from_selectables = adapt_from_selectables
+ if TYPE_CHECKING:
+
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ # note this specializes the ReplacingExternalTraversal.traverse()
+ # method to state
+ # that we will return the same kind of ExternalTraversal object as
+ # we were given. This is probably not 100% true, such as it's
+ # possible for us to swap out Alias for Table at the top level.
+ # Ideally there could be overloads specific to ColumnElement and
+ # FromClause but Mypy is not accepting those as compatible with
+ # the base ReplacingExternalTraversal
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
):
@@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return newcol
@util.preload_module("sqlalchemy.sql.functions")
- def replace(self, col, _include_singleton_constants=False):
+ def replace(
+ self, col: _ET, _include_singleton_constants: bool = False
+ ) -> Optional[_ET]:
functions = util.preloaded.sql_functions
+ # TODO: cython candidate
+
if isinstance(col, FromClause) and not isinstance(
col, functions.FunctionElement
):
@@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
break
else:
return None
- return self.selectable
+ return self.selectable # type: ignore
elif isinstance(col, Alias) and isinstance(
col.element, TableClause
):
@@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# we are an alias of a table and we are not derived from an
# alias of a table (which nonetheless may be the same table
# as ours) so, same thing
- return col
+ return col # type: ignore
else:
# other cases where we are a selectable and the element
# is another join or selectable that contains a table which our
@@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
else:
return None
+ if TYPE_CHECKING:
+ assert isinstance(col, ColumnElement)
+
if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
else:
- return self._corresponding_column(col, True)
+ return self._corresponding_column(col, True) # type: ignore
+
+
+class _ColumnLookup(Protocol):
+ def __getitem__(
+ self, key: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
+ ...
class ColumnAdapter(ClauseAdapter):
@@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter):
"""
+ columns: _ColumnLookup
+
def __init__(
self,
- selectable,
- equivalents=None,
- adapt_required=False,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ adapt_required: bool = False,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ allow_label_resolve: bool = True,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
ClauseAdapter.__init__(
self,
@@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter):
adapt_from_selectables=adapt_from_selectables,
)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
if self.include_fn or self.exclude_fn:
self.columns = self._IncludeExcludeMapping(self, self.columns)
self.adapt_required = adapt_required
@@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter):
ac = self.__class__.__new__(self.__class__)
ac.__dict__.update(self.__dict__)
ac._wrap = adapter
- ac.columns = util.WeakPopulateDict(ac._locate_col)
+ ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore
if ac.include_fn or ac.exclude_fn:
ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
@@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter):
def traverse(self, obj):
return self.columns[obj]
+ def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
+ assert isinstance(visitor, ColumnAdapter)
+
+ return super().chain(visitor)
+
+ if TYPE_CHECKING:
+
+ @property
+ def visitor_iterator(self) -> Iterator[ColumnAdapter]:
+ ...
+
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
@@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter):
return newcol
- def _locate_col(self, col):
+ def _locate_col(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
# both replace and traverse() are overly complicated for what
# we are doing here and we would do better to have an inlined
# version that doesn't build up as much overhead. the issue is that
@@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
-def _offset_or_limit_clause(element, name=None, type_=None):
+def _offset_or_limit_clause(
+ element: Union[int, _ColumnExpressionArgument[int]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+) -> ColumnElement[int]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None):
)
-def _offset_or_limit_clause_asint_if_possible(clause):
+def _offset_or_limit_clause_asint_if_possible(
+ clause: Optional[Union[int, _ColumnExpressionArgument[int]]]
+) -> Optional[Union[int, _ColumnExpressionArgument[int]]]:
"""Return the offset or limit clause as a simple integer if possible,
else return the clause.
@@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause):
if clause is None:
return None
if hasattr(clause, "_limit_offset_value"):
- value = clause._limit_offset_value
+ value = clause._limit_offset_value # type: ignore
return util.asint(value)
else:
return clause
-def _make_slice(limit_clause, offset_clause, start, stop):
+def _make_slice(
+ limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ start: int,
+ stop: int,
+) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
"""Compute LIMIT/OFFSET in terms of slice start/end"""
# for calculated limit/offset, try to do the addition of
# values to offset in Python, however if a SQL clause is present
# then the addition has to be on the SQL side.
+
+ # TODO: typing is finding a few gaps in here, see if they can be
+ # closed up
+
if start is not None and stop is not None:
offset_clause = _offset_or_limit_clause_asint_if_possible(
offset_clause
@@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
+ assert offset_clause is not None
offset_clause = _offset_or_limit_clause(offset_clause)
limit_clause = _offset_or_limit_clause(stop - start)
@@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
- offset_clause = _offset_or_limit_clause(offset_clause)
+ offset_clause = _offset_or_limit_clause(
+ offset_clause # type: ignore
+ )
- return limit_clause, offset_clause
+ return limit_clause, offset_clause # type: ignore