diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 44 |
1 files changed, 29 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 262689128..390e23952 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -73,6 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import BinaryExpression from .elements import TextClause from .selectable import _JoinTargetElement from .selectable import _SelectIterable @@ -86,8 +87,15 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.row import Row +_CE = TypeVar("_CE", bound="ColumnElement[Any]") -def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): + +def join_condition( + a: FromClause, + b: FromClause, + a_subset: Optional[FromClause] = None, + consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None, +) -> ColumnElement[bool]: """Create a join condition between two tables or selectables. e.g.:: @@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): ) -def find_join_source(clauses, join_to): +def find_join_source( + clauses: List[FromClause], join_to: FromClause +) -> List[int]: """Given a list of FROM clauses and a selectable, return the first index and element from the list of clauses which can be joined against the selectable. returns @@ -144,7 +154,9 @@ def find_join_source(clauses, join_to): return idx -def find_left_clause_that_matches_given(clauses, join_from): +def find_left_clause_that_matches_given( + clauses: Sequence[FromClause], join_from: FromClause +) -> List[int]: """Given a list of FROM clauses and a selectable, return the indexes from the list of clauses which is derived from the selectable. @@ -243,7 +255,12 @@ def find_left_clause_to_join_from( return idx -def visit_binary_product(fn, expr): +def visit_binary_product( + fn: Callable[ + [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None + ], + expr: ColumnElement[Any], +) -> None: """Produce a traversal of the given expression, delivering column comparisons to the given function. @@ -278,19 +295,19 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack: List[ClauseElement] = [] + stack: List[BinaryExpression[Any]] = [] - def visit(element): + def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]: if isinstance(element, ScalarSelect): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element elif element.__visit_name__ == "binary" and operators.is_comparison( - element.operator + element.operator # type: ignore ): - stack.insert(0, element) - for l in visit(element.left): - for r in visit(element.right): + stack.insert(0, element) # type: ignore + for l in visit(element.left): # type: ignore + for r in visit(element.right): # type: ignore fn(stack[0], l, r) stack.pop(0) for elem in element.get_children(): @@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name): return None -def selectables_overlap(left, right): +def selectables_overlap(left: FromClause, right: FromClause) -> bool: """Return True if left/right have some overlapping selectable""" return bool( @@ -701,7 +718,7 @@ class _repr_params(_repr_base): return "[%s]" % (", ".join(trunc(value) for value in params)) -def adapt_criterion_to_null(crit, nulls): +def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE: """given criterion containing bind params, convert selected elements to IS NULL. @@ -922,9 +939,6 @@ def criterion_as_pairs( return pairs -_CE = TypeVar("_CE", bound="ClauseElement") - - class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. |