diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-12 16:11:13 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-13 01:07:44 +0300 |
commit | 9568bf2f1297c87ec1b93306b79de925fb2da08e (patch) | |
tree | 2969e7951f8883c41f1359ebbc98bcf85a3fdad6 | |
parent | 0a6b0f683edee8bf22d85dc655ad61a8285fd312 (diff) | |
download | apscheduler-9568bf2f1297c87ec1b93306b79de925fb2da08e.tar.gz |
Implemented one-shot event subscriptions
Such subscriptions are delivered the first matching event and then unsubscribed automatically.
-rw-r--r-- | src/apscheduler/abc.py | 5 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_local.py | 26 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/base.py | 14 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/local.py | 13 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 2 | ||||
-rw-r--r-- | src/apscheduler/workers/sync.py | 2 | ||||
-rw-r--r-- | tests/test_eventbrokers.py | 157 |
7 files changed, 93 insertions, 126 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index 3d85eed..58d74cd 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -92,13 +92,16 @@ class EventSource(metaclass=ABCMeta): @abstractmethod def subscribe( self, callback: Callable[[Event], Any], - event_types: Optional[Iterable[type[Event]]] = None + event_types: Optional[Iterable[type[Event]]] = None, + *, + one_shot: bool = False ) -> Subscription: """ Subscribe to events from this event source. :param callback: callable to be called with the event object when an event is published :param event_types: an iterable of concrete Event classes to subscribe to + :param one_shot: if ``True``, automatically unsubscribe after the first matching event """ diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py index 590f0cb..79030f3 100644 --- a/src/apscheduler/eventbrokers/async_local.py +++ b/src/apscheduler/eventbrokers/async_local.py @@ -36,15 +36,21 @@ class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): await self.publish_local(event) async def publish_local(self, event: Event) -> None: - async def deliver_event(func: Callable[[Event], Any]) -> None: - try: - retval = func(event) - if iscoroutine(retval): - await retval - except BaseException: - self._logger.exception('Error delivering %s event', event.__class__.__name__) - event_type = type(event) - for subscription in self._subscriptions.values(): + one_shot_tokens: list[object] = [] + for token, subscription in self._subscriptions.items(): if subscription.event_types is None or event_type in subscription.event_types: - self._task_group.start_soon(deliver_event, subscription.callback) + self._task_group.start_soon(self._deliver_event, subscription.callback, event) + if subscription.one_shot: + one_shot_tokens.append(subscription.token) + + for token in one_shot_tokens: + super().unsubscribe(token) + + async def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None: + try: + retval = func(event) + if iscoroutine(retval): + await retval + except BaseException: + self._logger.exception('Error delivering %s event', event.__class__.__name__) diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 23bae8a..9947f68 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -16,31 +16,33 @@ from ..exceptions import DeserializationError class LocalSubscription(Subscription): callback: Callable[[Event], Any] event_types: Optional[set[type[Event]]] + one_shot: bool + token: object _source: BaseEventBroker - _token: object def unsubscribe(self) -> None: - self._source.unsubscribe(self._token) + self._source.unsubscribe(self.token) @attr.define(eq=False) class BaseEventBroker(EventBroker): _logger: Logger = attr.field(init=False) - _subscriptions: dict[object, Subscription] = attr.field(init=False, factory=dict) + _subscriptions: dict[object, LocalSubscription] = attr.field(init=False, factory=dict) def __attrs_post_init__(self) -> None: self._logger = getLogger(self.__class__.__module__) def subscribe(self, callback: Callable[[Event], Any], - event_types: Optional[Iterable[type[Event]]] = None) -> Subscription: + event_types: Optional[Iterable[type[Event]]] = None, *, + one_shot: bool = False) -> Subscription: types = set(event_types) if event_types else None token = object() - subscription = LocalSubscription(callback, types, self, token) + subscription = LocalSubscription(callback, types, one_shot, token, self) self._subscriptions[token] = subscription return subscription def unsubscribe(self, token: object) -> None: - self._subscriptions.pop(token) + self._subscriptions.pop(token, None) class DistributedEventBrokerMixin: diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index 24de3eb..acf0c9a 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -31,13 +31,14 @@ class LocalEventBroker(BaseEventBroker): del self._executor def subscribe(self, callback: Callable[[Event], Any], - event_types: Optional[Iterable[type[Event]]] = None) -> Subscription: + event_types: Optional[Iterable[type[Event]]] = None, *, + one_shot: bool = False) -> Subscription: if iscoroutinefunction(callback): raise ValueError('Coroutine functions are not supported as callbacks on a synchronous ' 'event source') with self._subscriptions_lock: - return super().subscribe(callback, event_types) + return super().subscribe(callback, event_types, one_shot=one_shot) def unsubscribe(self, token: object) -> None: with self._subscriptions_lock: @@ -49,9 +50,15 @@ class LocalEventBroker(BaseEventBroker): def publish_local(self, event: Event) -> None: event_type = type(event) with self._subscriptions_lock: - for subscription in self._subscriptions.values(): + one_shot_tokens: list[object] = [] + for token, subscription in self._subscriptions.items(): if subscription.event_types is None or event_type in subscription.event_types: self._executor.submit(self._deliver_event, subscription.callback, event) + if subscription.one_shot: + one_shot_tokens.append(subscription.token) + + for token in one_shot_tokens: + super().unsubscribe(token) def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None: try: diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index dd3f37e..221b284 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -73,7 +73,7 @@ class Scheduler: # Start the scheduler and return when it has signalled readiness or raised an exception start_future: Future[Event] = Future() - with self._events.subscribe(start_future.set_result): + with self._events.subscribe(start_future.set_result, one_shot=True): run_future = self._executor.submit(self.run) wait([start_future, run_future], return_when=FIRST_COMPLETED) diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 824cce8..be805ee 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -64,7 +64,7 @@ class Worker: # Start the worker and return when it has signalled readiness or raised an exception start_future: Future[None] = Future() - with self._events.subscribe(start_future.set_result): + with self._events.subscribe(start_future.set_result, one_shot=True): self._executor = ThreadPoolExecutor(1) run_future = self._executor.submit(self.run) wait([start_future, run_future], return_when=FIRST_COMPLETED) diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py index 7097001..024b63d 100644 --- a/tests/test_eventbrokers.py +++ b/tests/test_eventbrokers.py @@ -79,16 +79,10 @@ def async_broker(request: FixtureRequest) -> Callable[[], AsyncEventBroker]: class TestEventBroker: def test_publish_subscribe(self, broker: EventBroker) -> None: - def subscriber1(event) -> None: - queue.put_nowait(event) - - def subscriber2(event) -> None: - queue.put_nowait(event) - queue = Queue() with broker: - broker.subscribe(subscriber1) - broker.subscribe(subscriber2) + broker.subscribe(queue.put_nowait) + broker.subscribe(queue.put_nowait) event = ScheduleAdded( schedule_id='schedule1', next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)) @@ -102,6 +96,25 @@ class TestEventBroker: assert event1.schedule_id == 'schedule1' assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc) + def test_subscribe_one_shot(self, broker: EventBroker) -> None: + queue = Queue() + with broker: + broker.subscribe(queue.put_nowait, one_shot=True) + event = ScheduleAdded( + schedule_id='schedule1', + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)) + broker.publish(event) + event = ScheduleAdded( + schedule_id='schedule2', + next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc)) + broker.publish(event) + received_event = queue.get(timeout=3) + with pytest.raises(Empty): + queue.get(timeout=0.1) + + assert isinstance(received_event, ScheduleAdded) + assert received_event.schedule_id == 'schedule1' + def test_unsubscribe(self, broker: EventBroker, caplog) -> None: queue = Queue() with broker: @@ -140,24 +153,18 @@ class TestEventBroker: @pytest.mark.anyio class TestAsyncEventBroker: async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None: - def subscriber1(event) -> None: - send.send_nowait(event) - - async def subscriber2(event) -> None: - await send.send(event) - send, receive = create_memory_object_stream(2) async with async_broker: - async_broker.subscribe(subscriber1) - async_broker.subscribe(subscriber2) + async_broker.subscribe(send.send) + async_broker.subscribe(send.send_nowait) event = ScheduleAdded( schedule_id='schedule1', next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)) await async_broker.publish(event) - with fail_after(3): - event1 = await receive.receive() - event2 = await receive.receive() + with fail_after(3): + event1 = await receive.receive() + event2 = await receive.receive() assert event1 == event2 assert isinstance(event1, ScheduleAdded) @@ -165,6 +172,28 @@ class TestAsyncEventBroker: assert event1.schedule_id == 'schedule1' assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc) + async def test_subscribe_one_shot(self, async_broker: AsyncEventBroker) -> None: + send, receive = create_memory_object_stream(2) + async with async_broker: + async_broker.subscribe(send.send, one_shot=True) + event = ScheduleAdded( + schedule_id='schedule1', + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)) + await async_broker.publish(event) + event = ScheduleAdded( + schedule_id='schedule2', + next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc)) + await async_broker.publish(event) + + with fail_after(3): + received_event = await receive.receive() + + with pytest.raises(TimeoutError), fail_after(0.1): + await receive.receive() + + assert isinstance(received_event, ScheduleAdded) + assert received_event.schedule_id == 'schedule1' + async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None: send, receive = create_memory_object_stream() async with async_broker: @@ -191,92 +220,12 @@ class TestAsyncEventBroker: raise Exception('foo') timestamp = datetime.now(timezone.utc) - events = [] + send, receive = create_memory_object_stream() async with async_broker: async_broker.subscribe(bad_subscriber) - async_broker.subscribe(events.append) + async_broker.subscribe(send.send) await async_broker.publish(Event(timestamp=timestamp)) - assert isinstance(events[0], Event) - assert events[0].timestamp == timestamp - assert 'Error delivering Event' in caplog.text -# -# def test_subscribe_coroutine_callback(self) -> None: -# async def callback(event: Event) -> None: -# pass -# -# with EventBroker() as eventhub: -# with pytest.raises(ValueError, match='Coroutine functions are not supported'): -# eventhub.subscribe(callback) -# -# def test_relay_events(self) -> None: -# timestamp = datetime.now(timezone.utc) -# events = [] -# with EventBroker() as eventhub1, EventBroker() as eventhub2: -# eventhub2.relay_events_from(eventhub1) -# eventhub2.subscribe(events.append) -# eventhub1.publish(Event(timestamp=timestamp)) -# -# assert isinstance(events[0], Event) -# assert events[0].timestamp == timestamp -# -# -# @pytest.mark.anyio -# class TestAsyncEventHub: -# async def test_publish(self) -> None: -# async def async_setitem(event: Event) -> None: -# events[1] = event -# -# timestamp = datetime.now(timezone.utc) -# events: List[Optional[Event]] = [None, None] -# async with AsyncEventBroker() as eventhub: -# eventhub.subscribe(partial(setitem, events, 0)) -# eventhub.subscribe(async_setitem) -# eventhub.publish(Event(timestamp=timestamp)) -# -# assert events[0] is events[1] -# assert isinstance(events[0], Event) -# assert events[0].timestamp == timestamp -# -# async def test_unsubscribe(self) -> None: -# timestamp = datetime.now(timezone.utc) -# events = [] -# async with AsyncEventBroker() as eventhub: -# token = eventhub.subscribe(events.append) -# eventhub.publish(Event(timestamp=timestamp)) -# eventhub.unsubscribe(token) -# eventhub.publish(Event(timestamp=timestamp)) -# -# assert len(events) == 1 -# -# async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None: -# async with AsyncEventBroker() as eventhub: -# eventhub.publish(Event(timestamp=datetime.now(timezone.utc))) -# -# assert not caplog.text -# -# async def test_publish_exception(self, caplog: LogCaptureFixture) -> None: -# def bad_subscriber(event: Event) -> None: -# raise Exception('foo') -# -# timestamp = datetime.now(timezone.utc) -# events = [] -# async with AsyncEventBroker() as eventhub: -# eventhub.subscribe(bad_subscriber) -# eventhub.subscribe(events.append) -# eventhub.publish(Event(timestamp=timestamp)) -# -# assert isinstance(events[0], Event) -# assert events[0].timestamp == timestamp -# assert 'Error delivering Event' in caplog.text -# -# async def test_relay_events(self) -> None: -# timestamp = datetime.now(timezone.utc) -# events = [] -# async with AsyncEventBroker() as eventhub1, AsyncEventBroker() as eventhub2: -# eventhub1.relay_events_from(eventhub2) -# eventhub1.subscribe(events.append) -# eventhub2.publish(Event(timestamp=timestamp)) -# -# assert isinstance(events[0], Event) -# assert events[0].timestamp == timestamp + received_event = await receive.receive() + assert received_event.timestamp == timestamp + assert 'Error delivering Event' in caplog.text |