summaryrefslogtreecommitdiff
path: root/src/apscheduler/marshalling.py
blob: 357a973b451f8127913ca12f1a568c7a5f319dc0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import annotations

import sys
from datetime import date, datetime, tzinfo
from functools import partial
from importlib import import_module
from types import ModuleType
from typing import Any, Callable, overload

from ._exceptions import DeserializationError, SerializationError

if sys.version_info >= (3, 9):
    from zoneinfo import ZoneInfo
else:
    from backports.zoneinfo import ZoneInfo


def marshal_object(obj) -> tuple[str, Any]:
    """
    Split an object into a class reference and its state.

    :param obj: the object to marshal; its class must provide ``__getstate__()``
    :return: a ``(ref, state)`` tuple where ``ref`` is ``"module:qualname"`` of the
        object's class and ``state`` is whatever ``obj.__getstate__()`` returned

    """
    klass = obj.__class__
    class_ref = f"{klass.__module__}:{klass.__qualname__}"
    object_state = obj.__getstate__()
    return class_ref, object_state


def unmarshal_object(ref: str, state):
    """
    Recreate an object from the output of :func:`marshal_object`.

    The class is resolved from ``ref`` with :func:`callable_from_ref`, an instance is
    created through ``__new__()`` (so ``__init__()`` is not run) and its state is then
    restored with ``__setstate__(state)``.

    :param ref: a ``"module:qualname"`` reference to the class
    :param state: the state previously produced by the object's ``__getstate__()``
    :return: the restored instance

    """
    target_class = callable_from_ref(ref)
    restored = target_class.__new__(target_class)
    restored.__setstate__(state)
    return restored


@overload
def marshal_date(value: None) -> None:
    ...


@overload
def marshal_date(value: date) -> str:
    ...


def marshal_date(value):
    """
    Convert a date or datetime to its ISO 8601 string form.

    :param value: the date/datetime to convert, or ``None``
    :return: ``value.isoformat()``, or ``None`` when ``value`` is ``None``

    """
    if value is None:
        return None

    return value.isoformat()


@overload
def unmarshal_date(value: None) -> None:
    ...


@overload
def unmarshal_date(value: str) -> date:
    ...


def unmarshal_date(value):
    """
    Parse an ISO 8601 string back into a date or datetime.

    A 10 character string (``YYYY-MM-DD``) is parsed as a :class:`~datetime.date`;
    any other string is parsed as a :class:`~datetime.datetime`.

    :param value: the ISO 8601 string to parse, or ``None``
    :return: the parsed value, or ``None`` when ``value`` is ``None``

    """
    if value is None:
        return None

    parse = date.fromisoformat if len(value) == 10 else datetime.fromisoformat
    return parse(value)


def marshal_timezone(value: tzinfo) -> str:
    """
    Return the name of the given time zone so it can be serialized.

    :param value: a ``zoneinfo.ZoneInfo`` instance or a pytz time zone
    :return: the zone's key (``ZoneInfo.key``) or name (pytz ``.zone``)
    :raises SerializationError: if the time zone is of any other kind

    """
    if isinstance(value, ZoneInfo):
        zone_name = value.key
    elif hasattr(value, "zone"):  # pytz time zones expose their name as .zone
        zone_name = value.zone
    else:
        raise SerializationError(
            f"Unserializable time zone: {value!r}\n"
            "Only time zones from the zoneinfo or pytz modules can be serialized."
        )

    return zone_name


def unmarshal_timezone(value: str) -> ZoneInfo:
    """
    Turn a serialized time zone key back into a :class:`ZoneInfo` instance.

    :param value: an IANA time zone key such as ``"Europe/Helsinki"``
    :return: the corresponding ``ZoneInfo`` object

    """
    zone = ZoneInfo(value)
    return zone


def callable_to_ref(func: Callable) -> str:
    """
    Return a reference to the given callable.

    The reference has the form ``module:qualname`` and is meant to be resolved back
    with :func:`callable_from_ref`.

    :raises SerializationError: if the given object is not callable, is a partial(),
        lambda or local function or does not have the ``__module__`` and
        ``__qualname__`` attributes

    """
    if isinstance(func, partial):
        raise SerializationError("Cannot create a reference to a partial()")
    # Reject non-callables up front: callable_from_ref() refuses to return them, so a
    # reference produced here for such an object could never be resolved
    if not callable(func):
        raise SerializationError("Cannot create a reference to a non-callable object")

    if not hasattr(func, "__module__"):
        raise SerializationError("Callable has no __module__ attribute")
    if not hasattr(func, "__qualname__"):
        raise SerializationError("Callable has no __qualname__ attribute")
    # Lambdas and nested functions cannot be looked up by module + qualname
    if "<lambda>" in func.__qualname__:
        raise SerializationError("Cannot create a reference to a lambda")
    if "<locals>" in func.__qualname__:
        raise SerializationError("Cannot create a reference to a nested function")

    return f"{func.__module__}:{func.__qualname__}"


def _resolve_attribute(obj: Any, name: str) -> Any:
    """
    Return the attribute ``name`` of ``obj``.

    If ``obj`` is a module lacking that attribute, ``name`` is imported as a submodule
    of it instead, so references such as ``package:submodule.func`` resolve even when
    the submodule has not been imported yet.

    """
    try:
        return getattr(obj, name)
    except AttributeError:
        if not isinstance(obj, ModuleType):
            raise

    return import_module(f"{obj.__name__}.{name}")


def callable_from_ref(ref: str) -> Callable:
    """
    Return the callable pointed to by ``ref``.

    ``ref`` must be in the ``module:qualname`` format produced by
    :func:`callable_to_ref`, where ``qualname`` may be a dotted attribute path (for
    example ``package.module:Class.method``).

    :raises ValueError: if ``ref`` does not contain a ``:`` separator
    :raises LookupError: if the module part of ``ref`` could not be imported
    :raises DeserializationError: if the attribute path could not be resolved or the
        looked up object is not callable

    """
    if ":" not in ref:
        raise ValueError(f"Invalid reference: {ref}")

    modulename, rest = ref.split(":", 1)
    try:
        # Import exactly the named module; the attribute path is walked separately
        # below so a dotted qualname (Class.method) is never mistaken for a
        # submodule name by the import machinery
        obj: Any = import_module(modulename)
    except ImportError as exc:
        raise LookupError(
            f"Error resolving reference {ref!r}: could not import module"
        ) from exc

    for name in rest.split("."):
        try:
            obj = _resolve_attribute(obj, name)
        except Exception as exc:
            # Attribute access can run arbitrary code (properties, __getattr__,
            # submodule imports), so any failure is reported uniformly
            raise DeserializationError(
                f"Error resolving reference {ref!r}: error looking up object"
            ) from exc

    if not callable(obj):
        raise DeserializationError(
            f"{ref!r} points to an object of type "
            f"{obj.__class__.__qualname__} which is not callable"
        )

    return obj