summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/base.py')
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py118
1 files changed, 98 insertions, 20 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 3f77f5500..7fdd2d7e0 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -1,36 +1,103 @@
+# ext/asyncio/base.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
import abc
import functools
+from typing import Any
+from typing import ClassVar
+from typing import Dict
+from typing import Generic
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TypeVar
import weakref
from . import exc as async_exc
+from ... import util
+from ...util.typing import Literal
+
+_T = TypeVar("_T", bound=Any)
+
+
+_PT = TypeVar("_PT", bound=Any)
-class ReversibleProxy:
- # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
- _proxy_objects = {}
+SelfReversibleProxy = TypeVar(
+ "SelfReversibleProxy", bound="ReversibleProxy[Any]"
+)
+
+
+class ReversibleProxy(Generic[_PT]):
+ _proxy_objects: ClassVar[
+ Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
+ ] = {}
__slots__ = ("__weakref__",)
- def _assign_proxied(self, target):
+ @overload
+ def _assign_proxied(self, target: _PT) -> _PT:
+ ...
+
+ @overload
+ def _assign_proxied(self, target: None) -> None:
+ ...
+
+ def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
- target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ target_ref: weakref.ref[_PT] = weakref.ref(
+ target, ReversibleProxy._target_gced
+ )
proxy_ref = weakref.ref(
self,
- functools.partial(ReversibleProxy._target_gced, target_ref),
+ functools.partial( # type: ignore
+ ReversibleProxy._target_gced, target_ref
+ ),
)
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
return target
@classmethod
- def _target_gced(cls, ref, proxy_ref=None):
+ def _target_gced(
+ cls: Type[SelfReversibleProxy],
+ ref: weakref.ref[_PT],
+ proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None,
+ ) -> None:
cls._proxy_objects.pop(ref, None)
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT
+ ) -> SelfReversibleProxy:
raise NotImplementedError()
+ @overload
@classmethod
- def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy],
+ target: _PT,
+ regenerate: Literal[True] = ...,
+ ) -> SelfReversibleProxy:
+ ...
+
+ @overload
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
+ ...
+
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
except KeyError:
@@ -38,7 +105,7 @@ class ReversibleProxy:
else:
proxy = proxy_ref()
if proxy is not None:
- return proxy
+ return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target)
@@ -46,43 +113,54 @@ class ReversibleProxy:
return None
+SelfStartableContext = TypeVar(
+ "SelfStartableContext", bound="StartableContext"
+)
+
+
class StartableContext(abc.ABC):
__slots__ = ()
@abc.abstractmethod
- async def start(self, is_ctxmanager=False):
- pass
+ async def start(
+ self: SelfStartableContext, is_ctxmanager: bool = False
+ ) -> Any:
+ raise NotImplementedError()
- def __await__(self):
+ def __await__(self) -> Any:
return self.start().__await__()
- async def __aenter__(self):
+ async def __aenter__(self: SelfStartableContext) -> Any:
return await self.start(is_ctxmanager=True)
@abc.abstractmethod
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
pass
- def _raise_for_not_started(self):
+ def _raise_for_not_started(self) -> NoReturn:
raise async_exc.AsyncContextNotStarted(
"%s context has not been started and object has not been awaited."
% (self.__class__.__name__)
)
-class ProxyComparable(ReversibleProxy):
+class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()
- def __hash__(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> _PT:
+ raise NotImplementedError()
+
+ def __hash__(self) -> int:
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self._proxied == other._proxied
)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return (
not isinstance(other, self.__class__)
or self._proxied != other._proxied