summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/buildstream/_cas/casserver.py7
-rw-r--r--src/buildstream/_signals.py19
2 files changed, 18 insertions, 8 deletions
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index 013fb07dd..04c5eb836 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -30,6 +30,7 @@ import grpc
import click
from .._protos.build.bazel.remote.asset.v1 import remote_asset_pb2_grpc
+from .. import _signals
from .._protos.build.bazel.remote.execution.v2 import (
remote_execution_pb2,
remote_execution_pb2_grpc,
@@ -137,7 +138,11 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Le
_ReferenceStorageServicer(casd_channel, root, enable_push=enable_push), server
)
- yield server
+ # Ensure we have the signal handler set for SIGTERM
+ # This allows threads from GRPC to call our methods that do register
+ # handlers at exit.
+ with _signals.terminator(lambda: None):
+ yield server
finally:
casd_channel.close()
diff --git a/src/buildstream/_signals.py b/src/buildstream/_signals.py
index 03b55b052..8032752a8 100644
--- a/src/buildstream/_signals.py
+++ b/src/buildstream/_signals.py
@@ -33,6 +33,9 @@ from typing import Callable, Deque
terminator_stack: Deque[Callable] = deque()
suspendable_stack: Deque[Callable] = deque()
+terminator_lock = threading.Lock()
+suspendable_lock = threading.Lock()
+
# Per process SIGTERM handler
def terminator_handler(signal_, frame):
@@ -80,13 +83,10 @@ def terminator_handler(signal_, frame):
def terminator(terminate_func):
global terminator_stack # pylint: disable=global-statement
- # Signal handling only works in the main thread
- if threading.current_thread() != threading.main_thread():
- yield
- return
-
outermost = bool(not terminator_stack)
+ assert threading.current_thread() == threading.main_thread() or not outermost
+
terminator_stack.append(terminate_func)
if outermost:
original_handler = signal.signal(signal.SIGTERM, terminator_handler)
@@ -96,7 +96,9 @@ def terminator(terminate_func):
finally:
if outermost:
signal.signal(signal.SIGTERM, original_handler)
- terminator_stack.pop()
+
+ with terminator_lock:
+ terminator_stack.remove(terminate_func)
# Just a simple object for holding on to two callbacks
@@ -146,6 +148,8 @@ def suspendable(suspend_callback, resume_callback):
global suspendable_stack # pylint: disable=global-statement
outermost = bool(not suspendable_stack)
+ assert threading.current_thread() == threading.main_thread() or not outermost
+
suspender = Suspender(suspend_callback, resume_callback)
suspendable_stack.append(suspender)
@@ -158,7 +162,8 @@ def suspendable(suspend_callback, resume_callback):
if outermost:
signal.signal(signal.SIGTSTP, original_stop)
- suspendable_stack.pop()
+ with suspendable_lock:
+ suspendable_stack.remove(suspender)
# blocked()