summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--apscheduler/datastores/async_/postgresql.py11
-rw-r--r--tests/test_datastores.py1
2 files changed, 8 insertions, 4 deletions
diff --git a/apscheduler/datastores/async_/postgresql.py b/apscheduler/datastores/async_/postgresql.py
index 8744656..b91f9b9 100644
--- a/apscheduler/datastores/async_/postgresql.py
+++ b/apscheduler/datastores/async_/postgresql.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type
from uuid import UUID
import sniffio
-from anyio import create_task_group, sleep
+from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from anyio.abc import TaskGroup
from asyncpg import Connection, UniqueViolationError
from asyncpg.pool import Pool
@@ -112,7 +112,7 @@ class PostgresqlDataStore(AsyncDataStore):
if self.notify_channel:
self._task_group = create_task_group()
await self._task_group.__aenter__()
- self._task_group.start_soon(self._listen_notifications)
+ await self._task_group.start(self._listen_notifications)
return self
@@ -144,7 +144,7 @@ class PostgresqlDataStore(AsyncDataStore):
self._events.publish(event)
- async def _listen_notifications(self) -> None:
+ async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
def callback(connection, pid, channel: str, payload: str) -> None:
self._logger.debug('Received notification on channel %s: %s', channel, payload)
event_type, _, json_data = payload.partition(' ')
@@ -158,9 +158,14 @@ class PostgresqlDataStore(AsyncDataStore):
event = event_class(**event_data)
self._events.publish(event)
+ task_started_sent = False
while True:
async with self.pool.acquire() as conn:
await conn.add_listener(self.notify_channel, callback)
+ if not task_started_sent:
+ task_status.started()
+ task_started_sent = True
+
try:
while True:
await sleep(self.max_idle_time)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 9773228..b635e03 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -47,7 +47,6 @@ async def capture_events(
event_types: Optional[Set[Type[Event]]] = None
) -> AsyncGenerator[List[Event], None, None]:
def listener(event: Event) -> None:
- print('received', event)
events.append(event)
if len(events) == limit:
limit_event.set()