diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-10 15:42:35 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-11 22:11:07 -0400 |
commit | a45e2284dad17fbbba3bea9d5e5304aab21c8c94 (patch) | |
tree | ac31614f2d53059570e2edffe731baf384baea23 /lib/sqlalchemy/ext/asyncio/base.py | |
parent | aa9cd878e8249a4a758c7f968e929e92fede42a5 (diff) | |
download | sqlalchemy-a45e2284dad17fbbba3bea9d5e5304aab21c8c94.tar.gz |
pep-484: asyncio
in this patch the asyncio/events.py module, which
existed only to raise errors when trying to attach event
listeners, is removed, as we were already coding an asyncio-specific
workaround in upstream Pool / Session to raise this error,
just moved the error out to the target and did the same thing
for Engine.
We also add an async_sessionmaker class. The initial rationale
here is because sessionmaker() is hardcoded to Session subclasses,
and there's not a way to get the use case of
sessionmaker(class_=AsyncSession) to type correctly without changing
the sessionmaker() symbol itself to be a function and not a class,
which gets too complicated for what this is. Additionally,
_SessionClassMethods has only three methods on it, one of which
is not usable with asyncio (close_all()), the others
not generally used from the session class.
Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/base.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/base.py | 118 |
1 files changed, 98 insertions, 20 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 3f77f5500..7fdd2d7e0 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,36 +1,103 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + import abc import functools +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TypeVar import weakref from . import exc as async_exc +from ... import util +from ...util.typing import Literal + +_T = TypeVar("_T", bound=Any) + + +_PT = TypeVar("_PT", bound=Any) -class ReversibleProxy: - # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) - _proxy_objects = {} +SelfReversibleProxy = TypeVar( + "SelfReversibleProxy", bound="ReversibleProxy[Any]" +) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} __slots__ = ("__weakref__",) - def _assign_proxied(self, target): + @overload + def _assign_proxied(self, target: _PT) -> _PT: + ... + + @overload + def _assign_proxied(self, target: None) -> None: + ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: - target_ref = weakref.ref(target, ReversibleProxy._target_gced) + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) proxy_ref = weakref.ref( self, - functools.partial(ReversibleProxy._target_gced, target_ref), + functools.partial( # type: ignore + ReversibleProxy._target_gced, target_ref + ), ) ReversibleProxy._proxy_objects[target_ref] = proxy_ref return target @classmethod - def _target_gced(cls, ref, proxy_ref=None): + def _target_gced( + cls: Type[SelfReversibleProxy], + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None, + ) -> None: cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT + ) -> SelfReversibleProxy: raise NotImplementedError() + @overload @classmethod - def _retrieve_proxy_for_target(cls, target, regenerate=True): + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], + target: _PT, + regenerate: Literal[True] = ..., + ) -> SelfReversibleProxy: + ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: + ... + + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] except KeyError: @@ -38,7 +105,7 @@ class ReversibleProxy: else: proxy = proxy_ref() if proxy is not None: - return proxy + return proxy # type: ignore if regenerate: return cls._regenerate_proxy_for_target(target) @@ -46,43 +113,54 @@ class ReversibleProxy: return None +SelfStartableContext = TypeVar( + "SelfStartableContext", bound="StartableContext" +) + + class StartableContext(abc.ABC): __slots__ = () @abc.abstractmethod - async def start(self, is_ctxmanager=False): - pass + async def start( + self: SelfStartableContext, is_ctxmanager: bool = False + ) -> Any: + raise NotImplementedError() - def __await__(self): + def __await__(self) -> Any: return self.start().__await__() - async def __aenter__(self): + async def __aenter__(self: SelfStartableContext) -> Any: return await self.start(is_ctxmanager=True) @abc.abstractmethod - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: pass - def _raise_for_not_started(self): + def _raise_for_not_started(self) -> NoReturn: raise async_exc.AsyncContextNotStarted( "%s context has not been started and object has not been awaited." % (self.__class__.__name__) ) -class ProxyComparable(ReversibleProxy): +class ProxyComparable(ReversibleProxy[_PT]): __slots__ = () - def __hash__(self): + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, self.__class__) and self._proxied == other._proxied ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return ( not isinstance(other, self.__class__) or self._proxied != other._proxied |