summary | refs | log | tree | commit | diff
path: root/src/apscheduler/datastores/mongodb.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/apscheduler/datastores/mongodb.py')
-rw-r--r-- src/apscheduler/datastores/mongodb.py 109
1 files changed, 44 insertions, 65 deletions
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 13299bb..4899f54 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -2,14 +2,13 @@ from __future__ import annotations
import operator
from collections import defaultdict
+from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
from typing import Any, Callable, ClassVar, Iterable
from uuid import UUID
import attrs
import pymongo
-import tenacity
from attrs.validators import instance_of
from bson import CodecOptions, UuidRepresentation
from bson.codec_options import TypeEncoder, TypeRegistry
@@ -35,10 +34,9 @@ from .._exceptions import (
SerializationError,
TaskLookupError,
)
-from .._structures import Job, JobResult, RetrySettings, Schedule, Task
-from ..abc import EventBroker, Serializer
-from ..serializers.pickle import PickleSerializer
-from .base import BaseDataStore
+from .._structures import Job, JobResult, Schedule, Task
+from ..abc import EventBroker
+from .base import BaseExternalDataStore
class CustomEncoder(TypeEncoder):
@@ -55,7 +53,7 @@ class CustomEncoder(TypeEncoder):
@attrs.define(eq=False)
-class MongoDBDataStore(BaseDataStore):
+class MongoDBDataStore(BaseExternalDataStore):
"""
Uses a MongoDB server to store data.
@@ -66,23 +64,11 @@ class MongoDBDataStore(BaseDataStore):
raises :exc:`pymongo.errors.ConnectionFailure`.
:param client: a PyMongo client
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
:param database: name of the database to use
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
- database connecitivty problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
"""
client: MongoClient = attrs.field(validator=instance_of(MongoClient))
- serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True)
database: str = attrs.field(default="apscheduler", kw_only=True)
- lock_expiration_delay: float = attrs.field(default=30, kw_only=True)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True)
- start_from_scratch: bool = attrs.field(default=False, kw_only=True)
_task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)]
_schedule_attrs: ClassVar[list[str]] = [
@@ -90,8 +76,8 @@ class MongoDBDataStore(BaseDataStore):
]
_job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)]
- _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
_local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict)
+ _temporary_failure_exceptions = (ConnectionFailure,)
def __attrs_post_init__(self) -> None:
type_registry = TypeRegistry(
@@ -118,25 +104,10 @@ class MongoDBDataStore(BaseDataStore):
client = MongoClient(uri)
return cls(client, **options)
- def _retry(self) -> tenacity.Retrying:
- return tenacity.Retrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(ConnectionFailure),
- after=self._after_attempt,
- reraise=True,
- )
-
- def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- "Temporary data store error (attempt %d): %s",
- retry_state.attempt_number,
- retry_state.outcome.exception(),
- )
-
- def start(self, event_broker: EventBroker) -> None:
- super().start(event_broker)
-
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ await super().start(exit_stack, event_broker)
server_info = self.client.server_info()
if server_info["versionArray"] < [4, 0]:
raise RuntimeError(
@@ -144,7 +115,7 @@ class MongoDBDataStore(BaseDataStore):
f"{server_info['version']}"
)
- for attempt in self._retry():
+ async for attempt in self._retry():
with attempt, self.client.start_session() as session:
if self.start_from_scratch:
self._tasks.delete_many({}, session=session)
@@ -159,7 +130,7 @@ class MongoDBDataStore(BaseDataStore):
self._jobs_results.create_index("finished_at", session=session)
self._jobs_results.create_index("expires_at", session=session)
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
for attempt in self._retry():
with attempt:
previous = self._tasks.find_one_and_update(
@@ -173,20 +144,20 @@ class MongoDBDataStore(BaseDataStore):
self._local_tasks[task.id] = task
if previous:
- self._events.publish(TaskUpdated(task_id=task.id))
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
for attempt in self._retry():
with attempt:
if not self._tasks.find_one_and_delete({"_id": task_id}):
raise TaskLookupError(task_id)
del self._local_tasks[task_id]
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
- def get_task(self, task_id: str) -> Task:
+ async def get_task(self, task_id: str) -> Task:
try:
return self._local_tasks[task_id]
except KeyError:
@@ -205,7 +176,7 @@ class MongoDBDataStore(BaseDataStore):
)
return task
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
for attempt in self._retry():
with attempt:
tasks: list[Task] = []
@@ -217,7 +188,7 @@ class MongoDBDataStore(BaseDataStore):
return tasks
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -237,7 +208,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
event: DataStoreEvent
document = schedule.marshal(self.serializer)
document["_id"] = document.pop("id")
@@ -258,14 +231,14 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
event = ScheduleAdded(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def remove_schedules(self, ids: Iterable[str]) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt, self.client.start_session() as session:
@@ -277,9 +250,9 @@ class MongoDBDataStore(BaseDataStore):
self._schedules.delete_many(filters, session=session)
for schedule_id in ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
schedules: list[Schedule] = []
@@ -318,7 +291,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
updated_schedules: list[tuple[str, datetime]] = []
finished_schedule_ids: list[str] = []
@@ -365,12 +340,12 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule_id, next_fire_time=next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
for schedule_id in finished_schedule_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_next_schedule_run_time(self) -> datetime | None:
+ async def get_next_schedule_run_time(self) -> datetime | None:
for attempt in self._retry():
with attempt:
document = self._schedules.find_one(
@@ -384,7 +359,7 @@ class MongoDBDataStore(BaseDataStore):
else:
return None
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
document = job.marshal(self.serializer)
document["_id"] = document.pop("id")
for attempt in self._retry():
@@ -397,9 +372,9 @@ class MongoDBDataStore(BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -419,7 +394,7 @@ class MongoDBDataStore(BaseDataStore):
return jobs
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
cursor = self._jobs.find(
@@ -488,11 +463,15 @@ class MongoDBDataStore(BaseDataStore):
# Publish the appropriate events
for job in acquired_jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return acquired_jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
# Record the job result
@@ -509,7 +488,7 @@ class MongoDBDataStore(BaseDataStore):
# Delete the job
self._jobs.delete_one({"_id": result.job_id}, session=session)
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
for attempt in self._retry():
with attempt:
document = self._jobs_results.find_one_and_delete({"_id": job_id})