diff options
-rw-r--r-- | buildstream/_cas/cascache.py | 18 | ||||
-rw-r--r-- | buildstream/_cas/casremote.py | 72 |
2 files changed, 77 insertions, 13 deletions
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index 5a6251815..f24194309 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -33,7 +33,7 @@ from .._protos.buildstream.v2 import buildstream_pb2
 from .. import utils
 from .._exceptions import CASCacheError
 
-from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate
+from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate, _retry
 
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -203,7 +203,9 @@ class CASCache():
 
             request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
             request.key = ref
-            response = remote.ref_storage.GetReference(request)
+            for attempt in _retry():
+                with attempt:
+                    response = remote.ref_storage.GetReference(request)
 
             tree = remote_execution_pb2.Digest()
             tree.hash = response.digest.hash
@@ -288,7 +290,9 @@ class CASCache():
                 try:
                     request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
                     request.key = ref
-                    response = remote.ref_storage.GetReference(request)
+                    for attempt in _retry():
+                        with attempt:
+                            response = remote.ref_storage.GetReference(request)
 
                     if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes:
                         # ref is already on the server with the same tree
@@ -305,7 +309,9 @@ class CASCache():
                 request.keys.append(ref)
                 request.digest.hash = tree.hash
                 request.digest.size_bytes = tree.size_bytes
-                remote.ref_storage.UpdateReference(request)
+                for attempt in _retry():
+                    with attempt:
+                        remote.ref_storage.UpdateReference(request)
 
                 skipped_remote = False
         except grpc.RpcError as e:
@@ -983,7 +989,9 @@ class CASCache():
                 d.hash = required_digest.hash
                 d.size_bytes = required_digest.size_bytes
 
-            response = remote.cas.FindMissingBlobs(request)
+            for attempt in _retry():
+                with attempt:
+                    response = remote.cas.FindMissingBlobs(request)
             for missing_digest in response.missing_blob_digests:
                 d = remote_execution_pb2.Digest()
                 d.hash = missing_digest.hash
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index 56ba4c5d8..ba8ae7365 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -23,6 +23,46 @@ from .. import utils
 _MAX_PAYLOAD_BYTES = 1024 * 1024
 
 
+class _Attempt():
+
+    def __init__(self, last_attempt=False):
+        self.__passed = None
+        self.__last_attempt = last_attempt
+
+    def passed(self):
+        return self.__passed
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            if exc_type is None:
+                self.__passed = True
+            else:
+                self.__passed = False
+                if exc_value is not None:
+                    raise exc_value
+        except grpc.RpcError as e:
+            if e.code() == grpc.StatusCode.UNAVAILABLE:
+                return not self.__last_attempt
+            elif e.code() == grpc.StatusCode.ABORTED:
+                raise CASRemoteError("grpc aborted: {}".format(str(e)),
+                                     detail=e.details(),
+                                     temporary=True) from e
+            else:
+                return False
+        return False
+
+
+def _retry(tries=5):
+    for a in range(tries):
+        attempt = _Attempt(last_attempt=(a == tries - 1))
+        yield attempt
+        if attempt.passed():
+            break
+
+
 class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
 
     # _new_from_config_node
@@ -133,7 +173,9 @@ class CASRemote():
             self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
             try:
                 request = remote_execution_pb2.GetCapabilitiesRequest()
-                response = self.capabilities.GetCapabilities(request)
+                for attempt in _retry():
+                    with attempt:
+                        response = self.capabilities.GetCapabilities(request)
                 server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
                 if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
                     self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
@@ -146,7 +188,9 @@ class CASRemote():
             self.batch_read_supported = False
             try:
                 request = remote_execution_pb2.BatchReadBlobsRequest()
-                response = self.cas.BatchReadBlobs(request)
+                for attempt in _retry():
+                    with attempt:
+                        response = self.cas.BatchReadBlobs(request)
                 self.batch_read_supported = True
             except grpc.RpcError as e:
                 if e.code() != grpc.StatusCode.UNIMPLEMENTED:
@@ -156,7 +200,9 @@ class CASRemote():
             self.batch_update_supported = False
             try:
                 request = remote_execution_pb2.BatchUpdateBlobsRequest()
-                response = self.cas.BatchUpdateBlobs(request)
+                for attempt in _retry():
+                    with attempt:
+                        response = self.cas.BatchUpdateBlobs(request)
                 self.batch_update_supported = True
             except grpc.RpcError as e:
                 if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
@@ -180,7 +226,9 @@ class CASRemote():
                 remote.init()
 
                 request = buildstream_pb2.StatusRequest()
-                response = remote.ref_storage.Status(request)
+                for attempt in _retry():
+                    with attempt:
+                        response = remote.ref_storage.Status(request)
 
                 if remote_spec.push and not response.allow_updates:
                     q.put('CAS server does not allow push')
@@ -226,7 +274,9 @@ class CASRemote():
         request = remote_execution_pb2.FindMissingBlobsRequest()
         request.blob_digests.extend([digest])
 
-        response = self.cas.FindMissingBlobs(request)
+        for attempt in _retry():
+            with attempt:
+                response = self.cas.FindMissingBlobs(request)
         if digest in response.missing_blob_digests:
             return False
 
@@ -292,7 +342,9 @@ class CASRemote():
                 offset += chunk_size
                 finished = request.finish_write
 
-        response = self.bytestream.Write(request_stream(resource_name, stream))
+        for attempt in _retry():
+            with attempt:
+                response = self.bytestream.Write(request_stream(resource_name, stream))
 
         assert response.committed_size == digest.size_bytes
 
@@ -328,7 +380,9 @@ class _CASBatchRead():
         if not self._request.digests:
             return
 
-        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+        for attempt in _retry():
+            with attempt:
+                batch_response = self._remote.cas.BatchReadBlobs(self._request)
 
         for response in batch_response.responses:
             if response.status.code == code_pb2.NOT_FOUND:
@@ -376,7 +430,9 @@ class _CASBatchUpdate():
         if not self._request.requests:
             return
 
-        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+        for attempt in _retry():
+            with attempt:
+                batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
 
         for response in batch_response.responses:
             if response.status.code != code_pb2.OK: