diff options
-rw-r--r-- | src/buildstream/_cas/casserver.py | 7 | ||||
-rw-r--r-- | src/buildstream/_signals.py | 19 |
2 files changed, 18 insertions, 8 deletions
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py index 013fb07dd..04c5eb836 100644 --- a/src/buildstream/_cas/casserver.py +++ b/src/buildstream/_cas/casserver.py @@ -30,6 +30,7 @@ import grpc import click from .._protos.build.bazel.remote.asset.v1 import remote_asset_pb2_grpc +from .. import _signals from .._protos.build.bazel.remote.execution.v2 import ( remote_execution_pb2, remote_execution_pb2_grpc, @@ -137,7 +138,11 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Le _ReferenceStorageServicer(casd_channel, root, enable_push=enable_push), server ) - yield server + # Ensure we have the signal handler set for SIGTERM + # This allows threads from GRPC to call our methods that do register + # handlers at exit. + with _signals.terminator(lambda: None): + yield server finally: casd_channel.close() diff --git a/src/buildstream/_signals.py b/src/buildstream/_signals.py index 03b55b052..8032752a8 100644 --- a/src/buildstream/_signals.py +++ b/src/buildstream/_signals.py @@ -33,6 +33,9 @@ from typing import Callable, Deque terminator_stack: Deque[Callable] = deque() suspendable_stack: Deque[Callable] = deque() +terminator_lock = threading.Lock() +suspendable_lock = threading.Lock() + # Per process SIGTERM handler def terminator_handler(signal_, frame): @@ -80,13 +83,10 @@ def terminator_handler(signal_, frame): def terminator(terminate_func): global terminator_stack # pylint: disable=global-statement - # Signal handling only works in the main thread - if threading.current_thread() != threading.main_thread(): - yield - return - outermost = bool(not terminator_stack) + assert threading.current_thread() == threading.main_thread() or not outermost + terminator_stack.append(terminate_func) if outermost: original_handler = signal.signal(signal.SIGTERM, terminator_handler) @@ -96,7 +96,9 @@ def terminator(terminate_func): finally: if outermost: signal.signal(signal.SIGTERM, original_handler) - terminator_stack.pop() + + with terminator_lock: + terminator_stack.remove(terminate_func) # Just a simple object for holding on to two callbacks @@ -146,6 +148,8 @@ def suspendable(suspend_callback, resume_callback): global suspendable_stack # pylint: disable=global-statement outermost = bool(not suspendable_stack) + assert threading.current_thread() == threading.main_thread() or not outermost + suspender = Suspender(suspend_callback, resume_callback) suspendable_stack.append(suspender) @@ -158,7 +162,8 @@ def suspendable(suspend_callback, resume_callback): if outermost: signal.signal(signal.SIGTSTP, original_stop) - suspendable_stack.pop() + with suspendable_lock: + suspendable_stack.remove(suspender) # blocked() |