diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-30 18:01:58 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-04 09:26:43 -0400 |
commit | 3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch) | |
tree | d0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/util.py | |
parent | 4e603e23755f31278f27a45449120a8dea470a45 (diff) | |
download | sqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz |
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting
pressure to come up with a consistency in how data moves
from end-user to instance variable.
current thinking is coming into:
1. there are _typing._XYZArgument objects that represent "what the
user sent"
2. there's the roles, which represent a kind of "filter" for different
kinds of objects. These are mostly important as the argument
we pass to coerce().
3. there's the thing that coerce() returns, which should be what the
construct uses as its internal representation of the thing.
This is _typing._XYZElement.
but there's some controversy over whether or
not we should pass actual ClauseElements around by their role
or not. I think we shouldn't at the moment, but this makes the
"role-ness" of something a little less portable. Like, we have
to set DMLTableRole for TableClause, Join, and Alias, but then
also we have to repeat those three types in order to set up
_DMLTableElement.
Other change introduced here, there was a deannotate=True
for the left/right of a sql.join(). All tests pass without that.
I'd rather not have that there as if we have a join(A, B) where
A, B are mapped classes, we want them inside of the _annotations.
The rationale seems to be performance, but this performance can
be illustrated to be on the compile side which we hope is cached
in the normal case.
CTEs now accommodate for text selects including recursive.
Get typing to accommodate "util.preloaded" cleanly; add "preloaded"
as a real module. This seemed like we would have needed
pep562 `__getattr__()` but we don't, just set names in
globals() as we import them.
References: #6810
Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 231 |
1 files changed, 174 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cdce49f7b..80711c4b5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -13,10 +13,20 @@ from __future__ import annotations from collections import deque from itertools import chain import typing +from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast +from typing import Dict from typing import Iterator +from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from . import coercions from . import operators @@ -49,11 +59,22 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .visitors import _ET from .. import exc from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _TypeEngineArgument from .roles import FromClauseRole + from .selectable import _JoinTargetElement + from .selectable import _OnClauseElement + from .selectable import Selectable + from .visitors import _TraverseCallableType + from .visitors import ExternallyTraversible + from .visitors import ExternalTraversal from ..engine.interfaces import _AnyExecuteParams from ..engine.interfaces import _AnyMultiExecuteParams from ..engine.interfaces import _AnySingleExecuteParams @@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from): return liberal_idx -def find_left_clause_to_join_from(clauses, join_to, onclause): +def find_left_clause_to_join_from( + clauses: Sequence[FromClause], + join_to: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], +) -> List[int]: """Given a list of FROM clauses, a selectable, and optional ON clause, return a list of integer indexes from the clauses list indicating the clauses that can be joined from. @@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): for i, f in enumerate(clauses): for s in selectables.difference([f]): if resolve_ambiguity: + assert cols_in_onclause is not None if set(f.c).union(s.c).issuperset(cols_in_onclause): idx.append(i) break @@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): # onclause was given and none of them resolved, so assume # all indexes can match if not idx and onclause is not None: - return range(len(clauses)) + return list(range(len(clauses))) else: return idx @@ -247,7 +273,7 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack = [] + stack: List[ClauseElement] = [] def visit(element): if isinstance(element, ScalarSelect): @@ -272,21 +298,22 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) - visit = None # remove gc cycles + visit = None # type: ignore # remove gc cycles def find_tables( - clause, - check_columns=False, - include_aliases=False, - include_joins=False, - include_selects=False, - include_crud=False, -): + clause: ClauseElement, + *, + check_columns: bool = False, + include_aliases: bool = False, + include_joins: bool = False, + include_selects: bool = False, + include_crud: bool = False, +) -> List[TableClause]: """locate Table objects within the given expression.""" - tables = [] - _visitors = {} + tables: List[TableClause] = [] + _visitors: Dict[str, _TraverseCallableType[Any]] = {} if include_selects: _visitors["select"] = _visitors["compound_select"] = tables.append @@ -335,7 +362,7 @@ def unwrap_order_by(clause): t = stack.popleft() if isinstance(t, ColumnElement) and ( not isinstance(t, UnaryExpression) - or not operators.is_ordering_modifier(t.modifier) + or not operators.is_ordering_modifier(t.modifier) # type: ignore ): if isinstance(t, Label) and not isinstance( t.element, ScalarSelect @@ -365,9 +392,14 @@ def unwrap_order_by(clause): def unwrap_label_reference(element): - def replace(elem): - if isinstance(elem, (_label_reference, _textual_label_reference)): - return elem.element + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + if isinstance(element, _label_reference): + return element.element + elif isinstance(element, _textual_label_reference): + assert False, "can't unwrap a textual label reference" + return None return visitors.replacement_traverse(element, {}, replace) @@ -407,7 +439,7 @@ def clause_is_present(clause, search): return False -def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]: +def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): for t in tables_from_leftmost(clause.left): yield t @@ -509,6 +541,8 @@ class _repr_base: __slots__ = ("max_chars",) + max_chars: int + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) @@ -612,7 +646,7 @@ class _repr_params(_repr_base): def _repr_multi( self, multi_params: _AnyMultiExecuteParams, - typ, + typ: int, ) -> str: if multi_params: if isinstance(multi_params[0], list): @@ -639,7 +673,7 @@ class _repr_params(_repr_base): def _repr_params( self, - params: Optional[_AnySingleExecuteParams], + params: _AnySingleExecuteParams, typ: int, ) -> str: trunc = self.trunc @@ -653,9 +687,10 @@ class _repr_params(_repr_base): ) ) elif typ is self._TUPLE: + seq_params = cast("Sequence[Any]", params) return "(%s%s)" % ( - ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "", + ", ".join(trunc(value) for value in seq_params), + "," if len(seq_params) == 1 else "", ) else: return "[%s]" % (", ".join(trunc(value) for value in params)) @@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) -def splice_joins(left, right, stop_on=None): +def splice_joins( + left: Optional[FromClause], + right: Optional[FromClause], + stop_on: Optional[FromClause] = None, +) -> Optional[FromClause]: if left is None: return right - stack = [(right, None)] + stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)] adapter = ClauseAdapter(left) ret = None @@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None): else: right = adapter.traverse(right) if prevright is not None: + assert right is not None prevright.left = right if ret is None: ret = right @@ -845,11 +885,14 @@ def criterion_as_pairs( elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) - pairs = [] + pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = [] visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs +_CE = TypeVar("_CE", bound="ClauseElement") + + class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. @@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, - selectable, - equivalents=None, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): self.__traverse_options__ = { "stop_on": [selectable], @@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): self.adapt_on_names = adapt_on_names self.adapt_from_selectables = adapt_from_selectables + if TYPE_CHECKING: + + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + # note this specializes the ReplacingExternalTraversal.traverse() + # method to state + # that we will return the same kind of ExternalTraversal object as + # we were given. This is probably not 100% true, such as it's + # possible for us to swap out Alias for Table at the top level. + # Ideally there could be overloads specific to ColumnElement and + # FromClause but Mypy is not accepting those as compatible with + # the base ReplacingExternalTraversal + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: + ... + def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET ): @@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return newcol @util.preload_module("sqlalchemy.sql.functions") - def replace(self, col, _include_singleton_constants=False): + def replace( + self, col: _ET, _include_singleton_constants: bool = False + ) -> Optional[_ET]: functions = util.preloaded.sql_functions + # TODO: cython candidate + if isinstance(col, FromClause) and not isinstance( col, functions.FunctionElement ): @@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): break else: return None - return self.selectable + return self.selectable # type: ignore elif isinstance(col, Alias) and isinstance( col.element, TableClause ): @@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # we are an alias of a table and we are not derived from an # alias of a table (which nonetheless may be the same table # as ours) so, same thing - return col + return col # type: ignore else: # other cases where we are a selectable and the element # is another join or selectable that contains a table which our @@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): else: return None + if TYPE_CHECKING: + assert isinstance(col, ColumnElement) + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None else: - return self._corresponding_column(col, True) + return self._corresponding_column(col, True) # type: ignore + + +class _ColumnLookup(Protocol): + def __getitem__( + self, key: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + ... class ColumnAdapter(ClauseAdapter): @@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter): """ + columns: _ColumnLookup + def __init__( self, - selectable, - equivalents=None, - adapt_required=False, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + adapt_required: bool = False, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): ClauseAdapter.__init__( self, @@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter): adapt_from_selectables=adapt_from_selectables, ) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore if self.include_fn or self.exclude_fn: self.columns = self._IncludeExcludeMapping(self, self.columns) self.adapt_required = adapt_required @@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter): ac = self.__class__.__new__(self.__class__) ac.__dict__.update(self.__dict__) ac._wrap = adapter - ac.columns = util.WeakPopulateDict(ac._locate_col) + ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore if ac.include_fn or ac.exclude_fn: ac.columns = self._IncludeExcludeMapping(ac, ac.columns) @@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter): def traverse(self, obj): return self.columns[obj] + def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: + assert isinstance(visitor, ColumnAdapter) + + return super().chain(visitor) + + if TYPE_CHECKING: + + @property + def visitor_iterator(self) -> Iterator[ColumnAdapter]: + ... + adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process @@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter): return newcol - def _locate_col(self, col): + def _locate_col( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: # both replace and traverse() are overly complicated for what # we are doing here and we would do better to have an inlined # version that doesn't build up as much overhead. the issue is that @@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore -def _offset_or_limit_clause(element, name=None, type_=None): +def _offset_or_limit_clause( + element: Union[int, _ColumnExpressionArgument[int]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, +) -> ColumnElement[int]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None): ) -def _offset_or_limit_clause_asint_if_possible(clause): +def _offset_or_limit_clause_asint_if_possible( + clause: Optional[Union[int, _ColumnExpressionArgument[int]]] +) -> Optional[Union[int, _ColumnExpressionArgument[int]]]: """Return the offset or limit clause as a simple integer if possible, else return the clause. @@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause): if clause is None: return None if hasattr(clause, "_limit_offset_value"): - value = clause._limit_offset_value + value = clause._limit_offset_value # type: ignore return util.asint(value) else: return clause -def _make_slice(limit_clause, offset_clause, start, stop): +def _make_slice( + limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + start: int, + stop: int, +) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]: """Compute LIMIT/OFFSET in terms of slice start/end""" # for calculated limit/offset, try to do the addition of # values to offset in Python, however if a SQL clause is present # then the addition has to be on the SQL side. + + # TODO: typing is finding a few gaps in here, see if they can be + # closed up + if start is not None and stop is not None: offset_clause = _offset_or_limit_clause_asint_if_possible( offset_clause @@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: + assert offset_clause is not None offset_clause = _offset_or_limit_clause(offset_clause) limit_clause = _offset_or_limit_clause(stop - start) @@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: - offset_clause = _offset_or_limit_clause(offset_clause) + offset_clause = _offset_or_limit_clause( + offset_clause # type: ignore + ) - return limit_clause, offset_clause + return limit_clause, offset_clause # type: ignore |