commit     4e2585a6f613905135164d3f6a5c6adf752ba441 (patch)
author     Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-06 22:46:49 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-06 22:46:49 +0300
tree       772163753df70b4a540d7c7678d05effbd3110b2  /src/apscheduler/datastores
parent     279d9c42059a29619c89553d5468b4b6ca43dd6d (diff)
download   apscheduler-4e2585a6f613905135164d3f6a5c6adf752ba441.tar.gz
Use the real UUID column type where supported
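In short: on PostgreSQL, the jobs.id and job_results.job_id columns now use the
native UUID column type, so uuid.UUID values are bound and returned directly; on
every other dialect, a small TypeDecorator keeps the previous 32-character hex
storage while converting at the driver boundary. A rough sketch of the dialect
dispatch (the helper name is illustrative, not from this commit):

    # Hypothetical helper illustrating the type selection this commit performs
    # inline in get_table_definitions().
    from sqlalchemy import Unicode
    from sqlalchemy.engine import Engine


    def job_id_column_type(engine: Engine):
        if engine.dialect.name == 'postgresql':
            from sqlalchemy.dialects import postgresql
            return postgresql.UUID(as_uuid=True)  # real UUID column
        # Elsewhere the commit wraps Unicode(32) in an EmulatedUUID TypeDecorator
        return Unicode(32)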
Diffstat (limited to 'src/apscheduler/datastores')
 src/apscheduler/datastores/async_/sqlalchemy.py | 46
 src/apscheduler/datastores/sync/sqlalchemy.py   | 46
 2 files changed, 64 insertions(+), 28 deletions(-)
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index 619b1c5..beaac4f 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -14,8 +14,9 @@ import sniffio
 from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
 from attr import asdict
 from sqlalchemy import (
-    Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, func, or_, select)
-from sqlalchemy.engine import URL
+    TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
+    bindparam, func, or_, select)
+from sqlalchemy.engine import URL, Dialect
 from sqlalchemy.exc import CompileError, IntegrityError
 from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
 from sqlalchemy.ext.asyncio.engine import AsyncEngine
@@ -59,6 +60,17 @@ def json_object_hook(obj: dict[str, Any]) -> Any:
     return obj
 
 
+class EmulatedUUID(TypeDecorator):
+    impl = Unicode(32)
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect: Dialect) -> Any:
+        return value.hex
+
+    def process_result_value(self, value: Any, dialect: Dialect):
+        return UUID(value) if value else None
+
+
 @reentrant
 @attr.define(eq=False)
 class SQLAlchemyDataStore(AsyncDataStore):
@@ -138,12 +150,18 @@ class SQLAlchemyDataStore(AsyncDataStore):
 
     def get_table_definitions(self) -> MetaData:
         if self.engine.dialect.name in ('mysql', 'mariadb'):
-            from sqlalchemy.dialects.mysql import TIMESTAMP
-            timestamp_type = TIMESTAMP(fsp=6)
+            from sqlalchemy.dialects import mysql
+            timestamp_type = mysql.TIMESTAMP(fsp=6)
         else:
-            from sqlalchemy.types import TIMESTAMP
             timestamp_type = TIMESTAMP(timezone=True)
 
+        if self.engine.dialect.name == 'postgresql':
+            from sqlalchemy.dialects import postgresql
+
+            job_id_type = postgresql.UUID(as_uuid=True)
+        else:
+            job_id_type = EmulatedUUID
+
         metadata = MetaData()
         Table(
             'metadata',
@@ -173,7 +191,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
         Table(
             'jobs',
             metadata,
-            Column('id', Unicode(32), primary_key=True),
+            Column('id', job_id_type, primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
             Column('serialized_data', LargeBinary, nullable=False),
             Column('created_at', timestamp_type, nullable=False),
@@ -183,7 +201,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
         Table(
             'job_results',
             metadata,
-            Column('job_id', Unicode(32), primary_key=True),
+            Column('job_id', job_id_type, primary_key=True),
             Column('finished_at', timestamp_type, index=True),
             Column('serialized_data', LargeBinary, nullable=False)
         )
@@ -465,7 +483,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
     async def add_job(self, job: Job) -> None:
         now = datetime.now(timezone.utc)
         serialized_data = self.serializer.serialize(job)
-        insert = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id,
+        insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
                                              created_at=now, serialized_data=serialized_data)
         async with self.engine.begin() as conn:
             await conn.execute(insert)
@@ -478,7 +496,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
         query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
             order_by(self.t_jobs.c.id)
         if ids:
-            job_ids = [job_id.hex for job_id in ids]
+            job_ids = [job_id for job_id in ids]
             query = query.where(self.t_jobs.c.id.in_(job_ids))
 
         async with self.engine.begin() as conn:
@@ -529,7 +547,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
 
             if acquired_jobs:
                 # Mark the acquired jobs as acquired by this worker
-                acquired_job_ids = [job.id.hex for job in acquired_jobs]
+                acquired_job_ids = [job.id for job in acquired_jobs]
                 update = self.t_jobs.update().\
                     values(acquired_by=worker_id, acquired_until=acquired_until).\
                     where(self.t_jobs.c.id.in_(acquired_job_ids))
@@ -553,7 +571,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
             now = datetime.now(timezone.utc)
             serialized_data = self.serializer.serialize(result)
             insert = self.t_job_results.insert().\
-                values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data)
+                values(job_id=job.id, finished_at=now, serialized_data=serialized_data)
             await conn.execute(insert)
 
             # Decrement the running jobs counter
@@ -563,17 +581,17 @@ class SQLAlchemyDataStore(AsyncDataStore):
             await conn.execute(update)
 
             # Delete the job
-            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
+            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
             await conn.execute(delete)
 
     async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         async with self.engine.begin() as conn:
             query = select(self.t_job_results.c.serialized_data).\
-                where(self.t_job_results.c.job_id == job_id.hex)
+                where(self.t_job_results.c.job_id == job_id)
             result = await conn.execute(query)
 
             delete = self.t_job_results.delete().\
-                where(self.t_job_results.c.job_id == job_id.hex)
+                where(self.t_job_results.c.job_id == job_id)
             await conn.execute(delete)
 
             serialized_data = result.scalar()
diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py
index d4c1d3f..4914c1b 100644
--- a/src/apscheduler/datastores/sync/sqlalchemy.py
+++ b/src/apscheduler/datastores/sync/sqlalchemy.py
@@ -8,8 +8,9 @@ from uuid import UUID
 
 import attr
 from sqlalchemy import (
-    Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, or_, select)
-from sqlalchemy.engine import URL
+    TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
+    bindparam, or_, select)
+from sqlalchemy.engine import URL, Dialect
 from sqlalchemy.exc import CompileError, IntegrityError
 from sqlalchemy.future import Engine, create_engine
 from sqlalchemy.sql.ddl import DropTable
@@ -28,6 +29,17 @@ from ...structures import JobResult, Task
 from ...util import reentrant
 
 
+class EmulatedUUID(TypeDecorator):
+    impl = Unicode(32)
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect: Dialect) -> Any:
+        return value.hex
+
+    def process_result_value(self, value: Any, dialect: Dialect):
+        return UUID(value) if value else None
+
+
 @reentrant
 @attr.define(eq=False)
 class SQLAlchemyDataStore(DataStore):
@@ -90,12 +102,18 @@ class SQLAlchemyDataStore(DataStore):
 
     def get_table_definitions(self) -> MetaData:
         if self.engine.dialect.name in ('mysql', 'mariadb'):
-            from sqlalchemy.dialects.mysql import TIMESTAMP
-            timestamp_type = TIMESTAMP(fsp=6)
+            from sqlalchemy.dialects import mysql
+            timestamp_type = mysql.TIMESTAMP(fsp=6)
         else:
-            from sqlalchemy.types import TIMESTAMP
             timestamp_type = TIMESTAMP(timezone=True)
 
+        if self.engine.dialect.name == 'postgresql':
+            from sqlalchemy.dialects import postgresql
+
+            job_id_type = postgresql.UUID(as_uuid=True)
+        else:
+            job_id_type = EmulatedUUID
+
         metadata = MetaData()
         Table(
             'metadata',
@@ -125,7 +143,7 @@ class SQLAlchemyDataStore(DataStore):
         Table(
             'jobs',
             metadata,
-            Column('id', Unicode(32), primary_key=True),
+            Column('id', job_id_type, primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
             Column('serialized_data', LargeBinary, nullable=False),
             Column('created_at', timestamp_type, nullable=False),
@@ -135,7 +153,7 @@ class SQLAlchemyDataStore(DataStore):
         Table(
             'job_results',
             metadata,
-            Column('job_id', Unicode(32), primary_key=True),
+            Column('job_id', job_id_type, primary_key=True),
             Column('finished_at', timestamp_type, index=True),
             Column('serialized_data', LargeBinary, nullable=False)
         )
@@ -371,7 +389,7 @@ class SQLAlchemyDataStore(DataStore):
     def add_job(self, job: Job) -> None:
         now = datetime.now(timezone.utc)
         serialized_data = self.serializer.serialize(job)
-        insert = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id,
+        insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
                                              created_at=now, serialized_data=serialized_data)
         with self.engine.begin() as conn:
             conn.execute(insert)
@@ -384,7 +402,7 @@ class SQLAlchemyDataStore(DataStore):
         query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
             order_by(self.t_jobs.c.id)
         if ids:
-            job_ids = [job_id.hex for job_id in ids]
+            job_ids = [job_id for job_id in ids]
             query = query.where(self.t_jobs.c.id.in_(job_ids))
 
         with self.engine.begin() as conn:
@@ -434,7 +452,7 @@ class SQLAlchemyDataStore(DataStore):
 
             if acquired_jobs:
                 # Mark the acquired jobs as acquired by this worker
-                acquired_job_ids = [job.id.hex for job in acquired_jobs]
+                acquired_job_ids = [job.id for job in acquired_jobs]
                 update = self.t_jobs.update().\
                     values(acquired_by=worker_id, acquired_until=acquired_until).\
                     where(self.t_jobs.c.id.in_(acquired_job_ids))
@@ -458,7 +476,7 @@ class SQLAlchemyDataStore(DataStore):
             now = datetime.now(timezone.utc)
             serialized_result = self.serializer.serialize(result)
             insert = self.t_job_results.insert().\
-                values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_result)
+                values(job_id=job.id, finished_at=now, serialized_data=serialized_result)
             conn.execute(insert)
 
             # Decrement the running jobs counter
@@ -468,19 +486,19 @@ class SQLAlchemyDataStore(DataStore):
             conn.execute(update)
 
             # Delete the job
-            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
+            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
             conn.execute(delete)
 
     def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         with self.engine.begin() as conn:
             # Retrieve the result
             query = select(self.t_job_results.c.serialized_data).\
-                where(self.t_job_results.c.job_id == job_id.hex)
+                where(self.t_job_results.c.job_id == job_id)
             result = conn.execute(query)
 
             # Delete the result
             delete = self.t_job_results.delete().\
-                where(self.t_job_results.c.job_id == job_id.hex)
+                where(self.t_job_results.c.job_id == job_id)
             conn.execute(delete)
 
             serialized_result = result.scalar()
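For reference, a self-contained round trip of the emulated type: a sketch assuming
SQLAlchemy 1.4+, where only the EmulatedUUID class matches the commit and the demo
table, engine, and variable names are made up for illustration.

    from uuid import UUID, uuid4

    from sqlalchemy import (
        Column, MetaData, Table, TypeDecorator, Unicode, create_engine, select)


    class EmulatedUUID(TypeDecorator):
        # Stores uuid.UUID values as 32-character hex strings (as in the commit).
        impl = Unicode(32)
        cache_ok = True

        def process_bind_param(self, value, dialect):
            return value.hex  # UUID -> hex string on the way in

        def process_result_value(self, value, dialect):
            return UUID(value) if value else None  # hex string -> UUID on the way out


    metadata = MetaData()
    demo = Table('demo', metadata, Column('id', EmulatedUUID, primary_key=True))

    engine = create_engine('sqlite://')  # in-memory SQLite: no native UUID type
    metadata.create_all(engine)

    job_id = uuid4()
    with engine.begin() as conn:
        conn.execute(demo.insert().values(id=job_id))  # no manual .hex anymore
        fetched = conn.execute(select(demo.c.id)).scalar()

    assert fetched == job_id and isinstance(fetched, UUID)

On PostgreSQL the same bind/fetch code works unchanged, since postgresql.UUID(as_uuid=True)
accepts and returns uuid.UUID objects natively; that symmetry is what lets the commit drop
the .hex conversions from both datastores.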