summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py44
1 files changed, 29 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 262689128..390e23952 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import BinaryExpression
from .elements import TextClause
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
@@ -86,8 +87,15 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
-def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
+
+def join_condition(
+ a: FromClause,
+ b: FromClause,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
+) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
e.g.::
@@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
)
-def find_join_source(clauses, join_to):
+def find_join_source(
+ clauses: List[FromClause], join_to: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
@@ -144,7 +154,9 @@ def find_join_source(clauses, join_to):
return idx
-def find_left_clause_that_matches_given(clauses, join_from):
+def find_left_clause_that_matches_given(
+ clauses: Sequence[FromClause], join_from: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the indexes from the list of
clauses which is derived from the selectable.
@@ -243,7 +255,12 @@ def find_left_clause_to_join_from(
return idx
-def visit_binary_product(fn, expr):
+def visit_binary_product(
+ fn: Callable[
+ [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
+ ],
+ expr: ColumnElement[Any],
+) -> None:
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
@@ -278,19 +295,19 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack: List[ClauseElement] = []
+ stack: List[BinaryExpression[Any]] = []
- def visit(element):
+ def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
if isinstance(element, ScalarSelect):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
elif element.__visit_name__ == "binary" and operators.is_comparison(
- element.operator
+ element.operator # type: ignore
):
- stack.insert(0, element)
- for l in visit(element.left):
- for r in visit(element.right):
+ stack.insert(0, element) # type: ignore
+ for l in visit(element.left): # type: ignore
+ for r in visit(element.right): # type: ignore
fn(stack[0], l, r)
stack.pop(0)
for elem in element.get_children():
@@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name):
return None
-def selectables_overlap(left, right):
+def selectables_overlap(left: FromClause, right: FromClause) -> bool:
"""Return True if left/right have some overlapping selectable"""
return bool(
@@ -701,7 +718,7 @@ class _repr_params(_repr_base):
return "[%s]" % (", ".join(trunc(value) for value in params))
-def adapt_criterion_to_null(crit, nulls):
+def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
"""given criterion containing bind params, convert selected elements
to IS NULL.
@@ -922,9 +939,6 @@ def criterion_as_pairs(
return pairs
-_CE = TypeVar("_CE", bound="ClauseElement")
-
-
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.