summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/elements.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 18:01:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-04 09:26:43 -0400
commit3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch)
treed0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/elements.py
parent4e603e23755f31278f27a45449120a8dea470a45 (diff)
downloadsqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
-rw-r--r--lib/sqlalchemy/sql/elements.py85
1 files changed, 53 insertions, 32 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index c735085f8..aec29d1b2 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,6 +26,7 @@ from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
@@ -77,8 +78,8 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from ._typing import _TypeEngineArgument
+ from .cache_key import _CacheKeyTraversalType
from .cache_key import CacheKey
from .compiler import Compiled
from .compiler import SQLCompiler
@@ -88,6 +89,7 @@ if typing.TYPE_CHECKING:
from .schema import DefaultGenerator
from .schema import FetchedValue
from .schema import ForeignKey
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
from .selectable import ReturnsRows
@@ -96,6 +98,7 @@ if typing.TYPE_CHECKING:
from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
+ from .visitors import _CloneCallableType
from .visitors import _TraverseInternalsType
from ..engine import Connection
from ..engine import Dialect
@@ -310,6 +313,7 @@ class ClauseElement(
_is_text_clause = False
_is_from_container = False
_is_select_container = False
+ _is_select_base = False
_is_select_statement = False
_is_bind_parameter = False
_is_clause_list = False
@@ -321,7 +325,7 @@ class ClauseElement(
def _order_by_label_element(self) -> Optional[Label[Any]]:
return None
- _cache_key_traversal = None
+ _cache_key_traversal: _CacheKeyTraversalType = None
negation_clause: ColumnElement[bool]
@@ -528,7 +532,7 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def self_group(self, against=None):
+ def self_group(self, against: Optional[OperatorType] = None) -> Any:
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
This method is overridden by subclasses to return a "grouping"
@@ -637,9 +641,9 @@ class ClauseElement(
return self._negate()
def _negate(self) -> ClauseElement:
- return UnaryExpression(
- self.self_group(against=operators.inv), operator=operators.inv
- )
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(grouped, operator=operators.inv)
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -1290,12 +1294,6 @@ class ColumnElement(
@overload
def self_group(
- self: ColumnElement[bool], against: Optional[OperatorType] = None
- ) -> ColumnElement[bool]:
- ...
-
- @overload
- def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
...
@@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
key: str
type: TypeEngine[_T]
+ value: Optional[_T]
_is_crud = False
_is_bind_parameter = True
@@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
return cloned
@property
- def effective_value(self):
+ def effective_value(self) -> Optional[_T]:
"""Return the value of this bound parameter,
taking into account if the ``callable`` parameter
was set.
@@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
"""
if self.callable:
- return self.callable()
+ # TODO: set up protocol for bind parameter callable
+ return self.callable() # type: ignore
else:
return self.value
- def render_literal_execute(self):
+ def render_literal_execute(self) -> BindParameter[_T]:
"""Produce a copy of this bound parameter that will enable the
:paramref:`_sql.BindParameter.literal_execute` flag.
@@ -2513,8 +2513,10 @@ class ClauseList(
self.operator = operator
self.group = group
self.group_contents = group_contents
+ clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
if _flatten_sub_clauses:
- clauses = util.flatten_iterator(clauses)
+ clauses_iterator = util.flatten_iterator(clauses_iterator)
+
self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
@@ -2523,31 +2525,35 @@ class ClauseList(
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
).self_group(against=self.operator)
- for clause in clauses
+ for clause in clauses_iterator
]
else:
self.clauses = [
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
)
- for clause in clauses
+ for clause in clauses_iterator
]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
@classmethod
- def _construct_raw(cls, operator, clauses=None):
+ def _construct_raw(
+ cls,
+ operator: OperatorType,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
+ ) -> ClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
self._is_implicitly_boolean = False
return self
- def __iter__(self):
+ def __iter__(self) -> Iterator[ColumnElement[Any]]:
return iter(self.clauses)
- def __len__(self):
+ def __len__(self) -> int:
return len(self.clauses)
@property
@@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
def _construct_raw(
cls,
operator: OperatorType,
- clauses: Optional[List[ColumnElement[Any]]] = None,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
) -> BooleanClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
@@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
sqltypes = util.preloaded.sql_sqltypes
if types is None:
- init_clauses = [
+ init_clauses: List[ColumnElement[Any]] = [
coercions.expect(roles.ExpressionElementRole, c)
for c in clauses
]
@@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]):
]
if whenlist:
- type_ = list(whenlist[-1])[-1].type
+ type_ = whenlist[-1][-1].type
else:
type_ = None
@@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]):
("element", InternalTraversal.dp_clauseelement)
]
+ element: ColumnElement[_T]
+
def __init__(self, element: ColumnElement[_T]):
self.element = element
@@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]):
cls,
expr: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
- col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ col_expr: ColumnElement[_T] = coercions.expect(
+ roles.ExpressionElementRole, expr
+ )
return UnaryExpression(
col_expr,
operator=operators.distinct_op,
@@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_any(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_all(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]):
modifiers: Optional[Mapping[str, Any]]
+ left: ColumnElement[Any]
+ right: Union[ColumnElement[Any], ClauseList]
+
def __init__(
self,
left: ColumnElement[Any],
@@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
def foreign_keys(self):
return self.element.foreign_keys
- def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ def _copy_internals(
+ self,
+ *,
+ clone: _CloneCallableType = _clone,
+ anonymize_labels: bool = False,
+ **kw: Any,
+ ) -> None:
self._reset_memoizations()
self._element = clone(self._element, **kw)
if anonymize_labels:
@@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]):
self.key = self.name = scalar_alias.name
self.type = type_
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
@@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]):
def _create_collation_expression(
cls, expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
- expr = coercions.expect(roles.ExpressionElementRole, expression)
+ expr = coercions.expect(roles.ExpressionElementRole[str], expression)
return BinaryExpression(
expr,
CollationClause(collation),