summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 18:01:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-04 09:26:43 -0400
commit3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch)
treed0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/util.py
parent4e603e23755f31278f27a45449120a8dea470a45 (diff)
downloadsqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py231
1 files changed, 174 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cdce49f7b..80711c4b5 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -13,10 +13,20 @@ from __future__ import annotations
from collections import deque
from itertools import chain
import typing
+from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
+from typing import Dict
from typing import Iterator
+from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from . import coercions
from . import operators
@@ -49,11 +59,22 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .visitors import _ET
from .. import exc
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _TypeEngineArgument
from .roles import FromClauseRole
+ from .selectable import _JoinTargetElement
+ from .selectable import _OnClauseElement
+ from .selectable import Selectable
+ from .visitors import _TraverseCallableType
+ from .visitors import ExternallyTraversible
+ from .visitors import ExternalTraversal
from ..engine.interfaces import _AnyExecuteParams
from ..engine.interfaces import _AnyMultiExecuteParams
from ..engine.interfaces import _AnySingleExecuteParams
@@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from):
return liberal_idx
-def find_left_clause_to_join_from(clauses, join_to, onclause):
+def find_left_clause_to_join_from(
+ clauses: Sequence[FromClause],
+ join_to: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+) -> List[int]:
"""Given a list of FROM clauses, a selectable,
and optional ON clause, return a list of integer indexes from the
clauses list indicating the clauses that can be joined from.
@@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
for i, f in enumerate(clauses):
for s in selectables.difference([f]):
if resolve_ambiguity:
+ assert cols_in_onclause is not None
if set(f.c).union(s.c).issuperset(cols_in_onclause):
idx.append(i)
break
@@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
# onclause was given and none of them resolved, so assume
# all indexes can match
if not idx and onclause is not None:
- return range(len(clauses))
+ return list(range(len(clauses)))
else:
return idx
@@ -247,7 +273,7 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack = []
+ stack: List[ClauseElement] = []
def visit(element):
if isinstance(element, ScalarSelect):
@@ -272,21 +298,22 @@ def visit_binary_product(fn, expr):
yield e
list(visit(expr))
- visit = None # remove gc cycles
+ visit = None # type: ignore # remove gc cycles
def find_tables(
- clause,
- check_columns=False,
- include_aliases=False,
- include_joins=False,
- include_selects=False,
- include_crud=False,
-):
+ clause: ClauseElement,
+ *,
+ check_columns: bool = False,
+ include_aliases: bool = False,
+ include_joins: bool = False,
+ include_selects: bool = False,
+ include_crud: bool = False,
+) -> List[TableClause]:
"""locate Table objects within the given expression."""
- tables = []
- _visitors = {}
+ tables: List[TableClause] = []
+ _visitors: Dict[str, _TraverseCallableType[Any]] = {}
if include_selects:
_visitors["select"] = _visitors["compound_select"] = tables.append
@@ -335,7 +362,7 @@ def unwrap_order_by(clause):
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
- or not operators.is_ordering_modifier(t.modifier)
+ or not operators.is_ordering_modifier(t.modifier) # type: ignore
):
if isinstance(t, Label) and not isinstance(
t.element, ScalarSelect
@@ -365,9 +392,14 @@ def unwrap_order_by(clause):
def unwrap_label_reference(element):
- def replace(elem):
- if isinstance(elem, (_label_reference, _textual_label_reference)):
- return elem.element
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ if isinstance(element, _label_reference):
+ return element.element
+ elif isinstance(element, _textual_label_reference):
+ assert False, "can't unwrap a textual label reference"
+ return None
return visitors.replacement_traverse(element, {}, replace)
@@ -407,7 +439,7 @@ def clause_is_present(clause, search):
return False
-def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]:
+def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
if isinstance(clause, Join):
for t in tables_from_leftmost(clause.left):
yield t
@@ -509,6 +541,8 @@ class _repr_base:
__slots__ = ("max_chars",)
+ max_chars: int
+
def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
@@ -612,7 +646,7 @@ class _repr_params(_repr_base):
def _repr_multi(
self,
multi_params: _AnyMultiExecuteParams,
- typ,
+ typ: int,
) -> str:
if multi_params:
if isinstance(multi_params[0], list):
@@ -639,7 +673,7 @@ class _repr_params(_repr_base):
def _repr_params(
self,
- params: Optional[_AnySingleExecuteParams],
+ params: _AnySingleExecuteParams,
typ: int,
) -> str:
trunc = self.trunc
@@ -653,9 +687,10 @@ class _repr_params(_repr_base):
)
)
elif typ is self._TUPLE:
+ seq_params = cast("Sequence[Any]", params)
return "(%s%s)" % (
- ", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else "",
+ ", ".join(trunc(value) for value in seq_params),
+ "," if len(seq_params) == 1 else "",
)
else:
return "[%s]" % (", ".join(trunc(value) for value in params))
@@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
-def splice_joins(left, right, stop_on=None):
+def splice_joins(
+ left: Optional[FromClause],
+ right: Optional[FromClause],
+ stop_on: Optional[FromClause] = None,
+) -> Optional[FromClause]:
if left is None:
return right
- stack = [(right, None)]
+ stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]
adapter = ClauseAdapter(left)
ret = None
@@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None):
else:
right = adapter.traverse(right)
if prevright is not None:
+ assert right is not None
prevright.left = right
if ret is None:
ret = right
@@ -845,11 +885,14 @@ def criterion_as_pairs(
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
- pairs = []
+ pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []
visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
+_CE = TypeVar("_CE", bound="ClauseElement")
+
+
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
@@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
- selectable,
- equivalents=None,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
self.__traverse_options__ = {
"stop_on": [selectable],
@@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
self.adapt_on_names = adapt_on_names
self.adapt_from_selectables = adapt_from_selectables
+ if TYPE_CHECKING:
+
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ # note this specializes the ReplacingExternalTraversal.traverse()
+ # method to state
+ # that we will return the same kind of ExternalTraversal object as
+ # we were given. This is probably not 100% true, such as it's
+ # possible for us to swap out Alias for Table at the top level.
+ # Ideally there could be overloads specific to ColumnElement and
+ # FromClause but Mypy is not accepting those as compatible with
+ # the base ReplacingExternalTraversal
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
):
@@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return newcol
@util.preload_module("sqlalchemy.sql.functions")
- def replace(self, col, _include_singleton_constants=False):
+ def replace(
+ self, col: _ET, _include_singleton_constants: bool = False
+ ) -> Optional[_ET]:
functions = util.preloaded.sql_functions
+ # TODO: cython candidate
+
if isinstance(col, FromClause) and not isinstance(
col, functions.FunctionElement
):
@@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
break
else:
return None
- return self.selectable
+ return self.selectable # type: ignore
elif isinstance(col, Alias) and isinstance(
col.element, TableClause
):
@@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# we are an alias of a table and we are not derived from an
# alias of a table (which nonetheless may be the same table
# as ours) so, same thing
- return col
+ return col # type: ignore
else:
# other cases where we are a selectable and the element
# is another join or selectable that contains a table which our
@@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
else:
return None
+ if TYPE_CHECKING:
+ assert isinstance(col, ColumnElement)
+
if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
else:
- return self._corresponding_column(col, True)
+ return self._corresponding_column(col, True) # type: ignore
+
+
+class _ColumnLookup(Protocol):
+ def __getitem__(
+ self, key: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
+ ...
class ColumnAdapter(ClauseAdapter):
@@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter):
"""
+ columns: _ColumnLookup
+
def __init__(
self,
- selectable,
- equivalents=None,
- adapt_required=False,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ adapt_required: bool = False,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ allow_label_resolve: bool = True,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
ClauseAdapter.__init__(
self,
@@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter):
adapt_from_selectables=adapt_from_selectables,
)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
if self.include_fn or self.exclude_fn:
self.columns = self._IncludeExcludeMapping(self, self.columns)
self.adapt_required = adapt_required
@@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter):
ac = self.__class__.__new__(self.__class__)
ac.__dict__.update(self.__dict__)
ac._wrap = adapter
- ac.columns = util.WeakPopulateDict(ac._locate_col)
+ ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore
if ac.include_fn or ac.exclude_fn:
ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
@@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter):
def traverse(self, obj):
return self.columns[obj]
+ def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
+ assert isinstance(visitor, ColumnAdapter)
+
+ return super().chain(visitor)
+
+ if TYPE_CHECKING:
+
+ @property
+ def visitor_iterator(self) -> Iterator[ColumnAdapter]:
+ ...
+
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
@@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter):
return newcol
- def _locate_col(self, col):
+ def _locate_col(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
# both replace and traverse() are overly complicated for what
# we are doing here and we would do better to have an inlined
# version that doesn't build up as much overhead. the issue is that
@@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
-def _offset_or_limit_clause(element, name=None, type_=None):
+def _offset_or_limit_clause(
+ element: Union[int, _ColumnExpressionArgument[int]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+) -> ColumnElement[int]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None):
)
-def _offset_or_limit_clause_asint_if_possible(clause):
+def _offset_or_limit_clause_asint_if_possible(
+ clause: Optional[Union[int, _ColumnExpressionArgument[int]]]
+) -> Optional[Union[int, _ColumnExpressionArgument[int]]]:
"""Return the offset or limit clause as a simple integer if possible,
else return the clause.
@@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause):
if clause is None:
return None
if hasattr(clause, "_limit_offset_value"):
- value = clause._limit_offset_value
+ value = clause._limit_offset_value # type: ignore
return util.asint(value)
else:
return clause
-def _make_slice(limit_clause, offset_clause, start, stop):
+def _make_slice(
+ limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ start: int,
+ stop: int,
+) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
"""Compute LIMIT/OFFSET in terms of slice start/end"""
# for calculated limit/offset, try to do the addition of
# values to offset in Python, however if a SQL clause is present
# then the addition has to be on the SQL side.
+
+ # TODO: typing is finding a few gaps in here, see if they can be
+ # closed up
+
if start is not None and stop is not None:
offset_clause = _offset_or_limit_clause_asint_if_possible(
offset_clause
@@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
+ assert offset_clause is not None
offset_clause = _offset_or_limit_clause(offset_clause)
limit_clause = _offset_or_limit_clause(stop - start)
@@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
- offset_clause = _offset_or_limit_clause(offset_clause)
+ offset_clause = _offset_or_limit_clause(
+ offset_clause # type: ignore
+ )
- return limit_clause, offset_clause
+ return limit_clause, offset_clause # type: ignore