summaryrefslogtreecommitdiff
path: root/test/lib/ansible_test/_internal/timeout.py
blob: 2fb2f44a27d22432224b0e0f062937188d3b08c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Timeout management for tests."""
from __future__ import annotations

import dataclasses
import datetime
import functools
import os
import signal
import time
import typing as t

from .io import (
    read_json_file,
)

from .config import (
    CommonConfig,
    TestConfig,
)

from .util import (
    display,
    TimeoutExpiredError,
)

from .thread import (
    WrappedThread,
)

from .constants import (
    TIMEOUT_PATH,
)

from .test import (
    TestTimeout,
)


@dataclasses.dataclass(frozen=True)
class TimeoutDetail:
    """Immutable description of the test timeout: when it fires and how long it was configured for.

    Instances can be persisted with ``to_dict`` and restored with ``from_dict``.
    """

    # Serialization format for ``deadline``; kept unchanged so data written by earlier ansible-test releases stays readable.
    _DEADLINE_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

    deadline: datetime.datetime  # UTC instant at which the timeout fires (tz-aware when built by create/from_dict)
    duration: int | float  # configured timeout length, in minutes

    @property
    def remaining(self) -> datetime.timedelta:
        """Time left until ``deadline``, measured against the current UTC time truncated to whole seconds.

        The result is negative once the deadline has passed.
        """
        now = datetime.datetime.now(tz=datetime.timezone.utc).replace(microsecond=0)
        return self.deadline - now

    def to_dict(self) -> dict[str, t.Any]:
        """Serialize this instance into a JSON-friendly dictionary (the inverse of ``from_dict``)."""
        serialized_deadline = self.deadline.strftime(self._DEADLINE_FORMAT)
        return {'deadline': serialized_deadline, 'duration': self.duration}

    @staticmethod
    def from_dict(value: dict[str, t.Any]) -> TimeoutDetail:
        """Rebuild an instance from a dictionary produced by ``to_dict``; the parsed deadline is tagged as UTC."""
        parsed = datetime.datetime.strptime(value['deadline'], TimeoutDetail._DEADLINE_FORMAT)
        return TimeoutDetail(deadline=parsed.replace(tzinfo=datetime.timezone.utc), duration=value['duration'])

    @staticmethod
    def create(duration: int | float) -> TimeoutDetail | None:
        """Start a new timeout lasting ``duration`` minutes from now.

        Returns None when ``duration`` is falsy (e.g. 0), meaning no timeout is configured.
        """
        if not duration:
            return None

        # A whole-number float such as 10.0 is stored as the int 10 so it is reported and serialized as "10".
        normalized = int(duration) if duration == int(duration) else duration

        start = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
        offset = datetime.timedelta(seconds=int(normalized * 60))  # fractional seconds are truncated

        return TimeoutDetail(deadline=start + offset, duration=normalized)


def get_timeout() -> TimeoutDetail | None:
    """Load the timeout persisted at TIMEOUT_PATH.

    Returns None when that file does not exist, i.e. when no timeout has been configured.
    """
    try:
        raw = read_json_file(TIMEOUT_PATH)
    except FileNotFoundError:
        return None

    return TimeoutDetail.from_dict(raw)


def configure_timeout(args: CommonConfig) -> None:
    """Set up the timeout for the current command; only test commands (TestConfig) are affected."""
    if not isinstance(args, TestConfig):
        return  # non-test commands are never subject to the timeout

    configure_test_timeout(args)


def configure_test_timeout(args: TestConfig) -> None:
    """Arm the test timeout, if one has been configured.

    If the deadline has already passed, the timeout is recorded with ``TestTimeout.write`` and
    ``TimeoutExpiredError`` is raised immediately. Otherwise a SIGUSR1 handler is installed and a
    daemon thread is started that sends SIGUSR1 to this process once the remaining time elapses.
    The handler records the timeout and raises ``TimeoutExpiredError``.
    """
    timeout = get_timeout()

    if timeout is None:
        return  # no timeout configured

    remaining = timeout.remaining
    test_timeout = TestTimeout(timeout.duration)

    if remaining <= datetime.timedelta():
        test_timeout.write(args)
        overdue = -remaining
        raise TimeoutExpiredError(f'The {timeout.duration} minute test timeout expired {overdue} ago at {timeout.deadline}.')

    display.info(f'The {timeout.duration} minute test timeout expires in {remaining} at {timeout.deadline}.', verbosity=1)

    def on_sigusr1(_signum: t.Any, _frame: t.Any) -> None:
        """SIGUSR1 handler: record the timeout, then abort the tests by raising TimeoutExpiredError."""
        test_timeout.write(args)
        raise TimeoutExpiredError(f'Tests aborted after exceeding the {timeout.duration} minute time limit.')

    def wait_then_signal(delay: float) -> None:
        """Sleep for ``delay`` seconds, then send SIGUSR1 to this process so ``on_sigusr1`` runs."""
        time.sleep(delay)
        os.kill(os.getpid(), signal.SIGUSR1)

    signal.signal(signal.SIGUSR1, on_sigusr1)

    # total_seconds() yields a float; time.sleep accepts fractional seconds.
    waiter = WrappedThread(functools.partial(wait_then_signal, remaining.total_seconds()))
    # Daemon so the waiter does not keep the process alive after tests finish (assumes WrappedThread follows threading.Thread semantics).
    waiter.daemon = True
    waiter.start()