summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/mapped_collection.py
diff options
context:
space:
mode:
authorMaksim Latysh <m.latysh@godeltech.com>2023-01-24 11:03:44 -0500
committersqla-tester <sqla-tester@sqlalchemy.org>2023-01-24 11:03:44 -0500
commit173e4164834e7ac5c77184a425a32f9afd088af4 (patch)
tree88c1477ae4b277870b98808c3e70caacd8f73065 /lib/sqlalchemy/orm/mapped_collection.py
parent4d2f24e524c99d8255f451476679f5fa93647ad4 (diff)
downloadsqlalchemy-173e4164834e7ac5c77184a425a32f9afd088af4.tar.gz
Type annotations for sqlalchemy.orm.mapped_collection
<!-- Provide a general summary of your proposed changes in the Title field above --> ### Description <!-- Describe your changes in detail --> An attempt to annotate lib/sqlalchemy/orm/mapped_collection.py with type hints (issue https://github.com/sqlalchemy/sqlalchemy/issues/6810) ### Checklist <!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once) --> This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. Closes: #9140 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9140 Pull-request-sha: facb4717134943dd651905f7c72618eb66a9eca5 Change-Id: I0fb80e2ea7ed2247c494487fb6c8d72efb4e9802
Diffstat (limited to 'lib/sqlalchemy/orm/mapped_collection.py')
-rw-r--r--lib/sqlalchemy/orm/mapped_collection.py136
1 files changed, 97 insertions, 39 deletions
diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py
index 8a65f847a..a2b085c76 100644
--- a/lib/sqlalchemy/orm/mapped_collection.py
+++ b/lib/sqlalchemy/orm/mapped_collection.py
@@ -4,14 +4,15 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import base
@@ -22,11 +23,24 @@ from ..sql import coercions
from ..sql import expression
from ..sql import roles
+if TYPE_CHECKING:
+ from typing import List
+ from typing import Optional
+ from typing import Sequence
+ from typing import Tuple
+ from typing import Union
+
+ from . import AttributeEventToken
+ from . import Mapper
+ from ..sql.elements import ColumnElement
+
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
+_F = TypeVar("_F", bound=Callable[[Any], Any])
-class _PlainColumnGetter:
+
+class _PlainColumnGetter(Generic[_KT]):
"""Plain column getter, stores collection of Column objects
directly.
@@ -38,21 +52,26 @@ class _PlainColumnGetter:
__slots__ = ("cols", "composite")
- def __init__(self, cols):
+ def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None:
self.cols = cols
self.composite = len(cols) > 1
- def __reduce__(self):
+ def __reduce__(
+ self,
+ ) -> Tuple[
+ Type[_SerializableColumnGetterV2[_KT]],
+ Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+ ]:
return _SerializableColumnGetterV2._reduce_from_cols(self.cols)
- def _cols(self, mapper):
+ def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
return self.cols
- def __call__(self, value):
+ def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]:
state = base.instance_state(value)
m = base._state_mapper(state)
- key = [
+ key: List[_KT] = [
m._get_state_attr_by_column(state, state.dict, col)
for col in self._cols(m)
]
@@ -62,7 +81,7 @@ class _PlainColumnGetter:
return key[0]
-class _SerializableColumnGetterV2(_PlainColumnGetter):
+class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]):
"""Updated serializable getter which deals with
multi-table mapped classes.
@@ -76,38 +95,52 @@ class _SerializableColumnGetterV2(_PlainColumnGetter):
__slots__ = ("colkeys",)
- def __init__(self, colkeys):
+ def __init__(
+ self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]]
+ ) -> None:
self.colkeys = colkeys
self.composite = len(colkeys) > 1
- def __reduce__(self):
+ def __reduce__(
+ self,
+ ) -> Tuple[
+ Type[_SerializableColumnGetterV2[_KT]],
+ Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+ ]:
return self.__class__, (self.colkeys,)
@classmethod
- def _reduce_from_cols(cls, cols):
- def _table_key(c):
+ def _reduce_from_cols(
+ cls, cols: Sequence[ColumnElement[_KT]]
+ ) -> Tuple[
+ Type[_SerializableColumnGetterV2[_KT]],
+ Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+ ]:
+ def _table_key(c: ColumnElement[_KT]) -> Optional[str]:
if not isinstance(c.table, expression.TableClause):
return None
else:
- return c.table.key
+ return c.table.key # type: ignore
colkeys = [(c.key, _table_key(c)) for c in cols]
return _SerializableColumnGetterV2, (colkeys,)
- def _cols(self, mapper):
- cols = []
+ def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
+ cols: List[ColumnElement[_KT]] = []
metadata = getattr(mapper.local_table, "metadata", None)
for (ckey, tkey) in self.colkeys:
if tkey is None or metadata is None or tkey not in metadata:
- cols.append(mapper.local_table.c[ckey])
+ cols.append(mapper.local_table.c[ckey]) # type: ignore
else:
cols.append(metadata.tables[tkey].c[ckey])
return cols
def column_keyed_dict(
- mapping_spec, *, ignore_unpopulated_attribute: bool = False
-):
+ mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]],
+ *,
+ ignore_unpopulated_attribute: bool = False,
+) -> Type[KeyFuncDict[_KT, _KT]]:
"""A dictionary-based collection type with column-based keying.
.. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to
@@ -155,7 +188,8 @@ def column_keyed_dict(
]
keyfunc = _PlainColumnGetter(cols)
return _mapped_collection_cls(
- keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute
+ keyfunc,
+ ignore_unpopulated_attribute=ignore_unpopulated_attribute,
)
@@ -169,13 +203,13 @@ class _AttrGetter:
dict_ = base.instance_dict(mapped_object)
return dict_.get(self.attr_name, base.NO_VALUE)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]:
return _AttrGetter, (self.attr_name,)
def attribute_keyed_dict(
attr_name: str, *, ignore_unpopulated_attribute: bool = False
-) -> Type[KeyFuncDict]:
+) -> Type[KeyFuncDict[_KT, _KT]]:
"""A dictionary-based collection type with attribute-based keying.
.. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to
@@ -223,7 +257,7 @@ def attribute_keyed_dict(
def keyfunc_mapping(
- keyfunc: Callable[[Any], _KT],
+ keyfunc: _F,
*,
ignore_unpopulated_attribute: bool = False,
) -> Type[KeyFuncDict[_KT, Any]]:
@@ -297,7 +331,12 @@ class KeyFuncDict(Dict[_KT, _VT]):
"""
- def __init__(self, keyfunc, *, ignore_unpopulated_attribute=False):
+ def __init__(
+ self,
+ keyfunc: _F,
+ *,
+ ignore_unpopulated_attribute: bool = False,
+ ) -> None:
"""Create a new collection with keying provided by keyfunc.
keyfunc may be any callable that takes an object and returns an object
@@ -315,21 +354,30 @@ class KeyFuncDict(Dict[_KT, _VT]):
self.ignore_unpopulated_attribute = ignore_unpopulated_attribute
@classmethod
- def _unreduce(cls, keyfunc, values):
- mp = KeyFuncDict(keyfunc)
+ def _unreduce(
+ cls, keyfunc: _F, values: Dict[_KT, _KT]
+ ) -> "KeyFuncDict[_KT, _KT]":
+ mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc)
mp.update(values)
return mp
- def __reduce__(self):
+ def __reduce__(
+ self,
+ ) -> Tuple[
+ Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]],
+ Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]]],
+ ]:
return (KeyFuncDict._unreduce, (self.keyfunc, dict(self)))
- def _raise_for_unpopulated(self, value, initiator):
+ def _raise_for_unpopulated(
+ self, value: _KT, initiator: Optional[AttributeEventToken]
+ ) -> None:
mapper = base.instance_state(value).mapper
if initiator is None:
relationship = "unknown relationship"
else:
- relationship = mapper.attrs[initiator.key]
+ relationship = f"{mapper.attrs[initiator.key]}"
raise sa_exc.InvalidRequestError(
f"In event triggered from population of attribute {relationship} "
@@ -345,9 +393,13 @@ class KeyFuncDict(Dict[_KT, _VT]):
f"parameter on the mapped collection factory."
)
- @collection.appender
- @collection.internally_instrumented
- def set(self, value, _sa_initiator=None):
+ @collection.appender # type: ignore[misc]
+ @collection.internally_instrumented # type: ignore[misc]
+ def set(
+ self,
+ value: _KT,
+ _sa_initiator: Optional[AttributeEventToken] = None,
+ ) -> None:
"""Add an item by value, consulting the keyfunc for the key."""
key = self.keyfunc(value)
@@ -358,11 +410,15 @@ class KeyFuncDict(Dict[_KT, _VT]):
else:
return
- self.__setitem__(key, value, _sa_initiator)
+ self.__setitem__(key, value, _sa_initiator) # type: ignore[call-arg]
- @collection.remover
- @collection.internally_instrumented
- def remove(self, value, _sa_initiator=None):
+ @collection.remover # type: ignore[misc]
+ @collection.internally_instrumented # type: ignore[misc]
+ def remove(
+ self,
+ value: _KT,
+ _sa_initiator: Optional[AttributeEventToken] = None,
+ ) -> None:
"""Remove an item by value, consulting the keyfunc for the key."""
key = self.keyfunc(value)
@@ -381,12 +437,14 @@ class KeyFuncDict(Dict[_KT, _VT]):
"based on mutable properties or properties that only obtain "
"values after flush?" % (value, self[key], key)
)
- self.__delitem__(key, _sa_initiator)
+ self.__delitem__(key, _sa_initiator) # type: ignore[call-arg]
-def _mapped_collection_cls(keyfunc, ignore_unpopulated_attribute):
- class _MKeyfuncMapped(KeyFuncDict):
- def __init__(self):
+def _mapped_collection_cls(
+ keyfunc: _F, ignore_unpopulated_attribute: bool
+) -> Type[KeyFuncDict[_KT, _KT]]:
+ class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]):
+ def __init__(self) -> None:
super().__init__(
keyfunc,
ignore_unpopulated_attribute=ignore_unpopulated_attribute,