summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/base.py
blob: 7fdd2d7e064314f1b3b5b55b2304ea28cc1c8d3a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# ext/asyncio/base.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from __future__ import annotations

import abc
import functools
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Type
from typing import TypeVar
import weakref

from . import exc as async_exc
from ... import util
from ...util.typing import Literal

# General-purpose type variable (not referenced elsewhere in this section).
_T = TypeVar("_T", bound=Any)


# Type of the object wrapped by a proxy; parameterizes ReversibleProxy and is
# the return type of ProxyComparable._proxied.
_PT = TypeVar("_PT", bound=Any)


# Annotates ``cls`` in ReversibleProxy classmethods so that lookup /
# regeneration methods are typed as returning the subclass they were
# invoked on rather than the base ReversibleProxy.
SelfReversibleProxy = TypeVar(
    "SelfReversibleProxy", bound="ReversibleProxy[Any]"
)


class ReversibleProxy(Generic[_PT]):
    """Base for proxy objects that can be looked up again from their target.

    Calling :meth:`_assign_proxied` records a ``target -> proxy`` mapping in
    the class-level :attr:`_proxy_objects` registry.  Both the key and the
    value are weak references, so the registry never keeps a target or a
    proxy alive.  Weakref callbacks remove the entry when the target, or
    the proxy it maps to, is garbage collected.

    :meth:`_retrieve_proxy_for_target` performs the reverse lookup.  When no
    live proxy is registered, it can fall back to
    :meth:`_regenerate_proxy_for_target`, which this base class does not
    implement.
    """

    # shared registry: weakref(target) -> weakref(proxy)
    _proxy_objects: ClassVar[
        Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
    ] = {}

    # "__weakref__" is needed so weakref.ref(self) works on slotted instances
    __slots__ = ("__weakref__",)

    @overload
    def _assign_proxied(self, target: _PT) -> _PT:
        ...

    @overload
    def _assign_proxied(self, target: None) -> None:
        ...

    def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
        """Register this proxy as the proxy for ``target``.

        :param target: the object being proxied, or ``None``.  ``None`` is
         passed through without registering anything.
        :return: ``target``, unchanged.
        """
        if target is not None:
            # drop the entry once the target itself is collected ...
            target_ref: weakref.ref[_PT] = weakref.ref(
                target, ReversibleProxy._target_gced
            )
            # ... or once this proxy is collected; the callback is then
            # called as _target_gced(target_ref, proxy_ref)
            proxy_ref = weakref.ref(
                self,
                functools.partial(  # type: ignore
                    ReversibleProxy._target_gced, target_ref
                ),
            )
            ReversibleProxy._proxy_objects[target_ref] = proxy_ref

        return target

    @classmethod
    def _target_gced(
        cls: Type[SelfReversibleProxy],
        ref: weakref.ref[_PT],
        proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None,
    ) -> None:
        """Weakref callback that removes a registry entry.

        Called as ``_target_gced(ref)`` when a target is collected, and as
        ``_target_gced(target_ref, proxy_ref)`` when a proxy is collected.

        Weak references to the same live object hash and compare equal.  So
        assigning a second proxy to a target overwrites the value stored
        under the existing key.  When an older proxy for that target is
        later collected, its entry must not evict the newer proxy's
        mapping.  The entry is therefore removed only if it still refers to
        ``proxy_ref``.

        :param ref: weak reference to the target (the registry key).
        :param proxy_ref: weak reference to the collected proxy, or ``None``
         when the target itself was collected.
        """
        if (
            proxy_ref is not None
            and cls._proxy_objects.get(ref) is not proxy_ref
        ):
            # target has since been re-assigned to a different proxy
            return
        cls._proxy_objects.pop(ref, None)

    @classmethod
    def _regenerate_proxy_for_target(
        cls: Type[SelfReversibleProxy], target: _PT
    ) -> SelfReversibleProxy:
        """Create a new proxy for ``target``; subclasses must implement."""
        raise NotImplementedError()

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls: Type[SelfReversibleProxy],
        target: _PT,
        regenerate: Literal[True] = ...,
    ) -> SelfReversibleProxy:
        ...

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
    ) -> Optional[SelfReversibleProxy]:
        ...

    @classmethod
    def _retrieve_proxy_for_target(
        cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
    ) -> Optional[SelfReversibleProxy]:
        """Return the live proxy registered for ``target``.

        :param target: the proxied object to look up.
        :param regenerate: when no live proxy is registered, return
         :meth:`_regenerate_proxy_for_target(target)` if true, otherwise
         ``None``.
        """
        try:
            # a fresh weakref to a live object hashes/compares equal to the
            # reference stored as the registry key
            proxy_ref = cls._proxy_objects[weakref.ref(target)]
        except KeyError:
            pass
        else:
            proxy = proxy_ref()
            if proxy is not None:
                return proxy  # type: ignore

        if regenerate:
            return cls._regenerate_proxy_for_target(target)
        else:
            return None


# Annotates ``self`` in StartableContext methods; bound to StartableContext so
# it can stand for any subclass.
SelfStartableContext = TypeVar(
    "SelfStartableContext", bound="StartableContext"
)


class StartableContext(abc.ABC):
    """Abstract object that can be awaited or used with ``async with``.

    ``await obj`` runs ``obj.start()``.  ``async with obj as x`` runs
    ``obj.start(is_ctxmanager=True)`` and binds its result to ``x``.
    Concrete subclasses provide :meth:`start` and :meth:`__aexit__`.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def start(
        self: SelfStartableContext, is_ctxmanager: bool = False
    ) -> Any:
        """Begin the context.

        :param is_ctxmanager: ``True`` when invoked from ``__aenter__``,
         ``False`` (the default) when the object is simply awaited.
        """
        raise NotImplementedError()

    def __await__(self) -> Any:
        # awaiting the object is the same as awaiting start()
        coroutine = self.start()
        return coroutine.__await__()

    async def __aenter__(self: SelfStartableContext) -> Any:
        started = await self.start(is_ctxmanager=True)
        return started

    @abc.abstractmethod
    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
        """Leave the ``async with`` block; subclasses must implement."""

    def _raise_for_not_started(self) -> NoReturn:
        """Raise ``AsyncContextNotStarted`` naming this object's class."""
        class_name = self.__class__.__name__
        message = (
            "%s context has not been started and object has not been awaited."
            % (class_name,)
        )
        raise async_exc.AsyncContextNotStarted(message)


class ProxyComparable(ReversibleProxy[_PT]):
    """Reversible proxy whose equality is delegated to the proxied object.

    ``a == b`` holds when ``b`` is an instance of ``a``'s class and their
    ``_proxied`` objects compare equal.  Hashing is by proxy identity
    instead, so two equal proxies can hash differently.
    """

    __slots__ = ()

    @util.ro_non_memoized_property
    def _proxied(self) -> _PT:
        # subclasses supply the wrapped object
        raise NotImplementedError()

    def __hash__(self) -> int:
        # identity-based, independent of the proxied object
        return id(self)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return self._proxied == other._proxied

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return self._proxied != other._proxied
        return True