diff options
author | Raoul Hidalgo Charman <raoul.hidalgocharman@codethink.co.uk> | 2018-12-07 17:51:20 +0000 |
---|---|---|
committer | Raoul Hidalgo Charman <raoul.hidalgocharman@codethink.co.uk> | 2019-01-16 11:55:07 +0000 |
commit | d2cc4798f5b32956bcbfcfc91ff1fc8069c446be (patch) | |
tree | 950c45612481e78de04ddbb2130a3b05280367a8 | |
parent | 2683f98ae2aa90e3f23c2750224fe8a7410f336c (diff) | |
download | buildstream-d2cc4798f5b32956bcbfcfc91ff1fc8069c446be.tar.gz |
casremote.py: Move remote CAS classes into its own file
Part of #802
-rw-r--r-- | buildstream/_cas/__init__.py | 3 | ||||
-rw-r--r-- | buildstream/_cas/cascache.py | 275 | ||||
-rw-r--r-- | buildstream/_cas/casremote.py | 247 | ||||
-rw-r--r-- | buildstream/_exceptions.py | 15 |
4 files changed, 283 insertions, 257 deletions
diff --git a/buildstream/_cas/__init__.py b/buildstream/_cas/__init__.py index 7386109bc..a88e41371 100644 --- a/buildstream/_cas/__init__.py +++ b/buildstream/_cas/__init__.py @@ -17,4 +17,5 @@ # Authors: # Tristan Van Berkom <tristan.vanberkom@codethink.co.uk> -from .cascache import CASCache, CASRemote, CASRemoteSpec +from .cascache import CASCache +from .casremote import CASRemote, CASRemoteSpec diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py index 482d4006f..5f62e6105 100644 --- a/buildstream/_cas/cascache.py +++ b/buildstream/_cas/cascache.py @@ -17,7 +17,6 @@ # Authors: # Jürg Billeter <juerg.billeter@codethink.co.uk> -from collections import namedtuple import hashlib import itertools import io @@ -26,76 +25,17 @@ import stat import tempfile import uuid import contextlib -from urllib.parse import urlparse import grpc -from .._protos.google.rpc import code_pb2 -from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc -from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc -from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc +from .._protos.google.bytestream import bytestream_pb2 +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 +from .._protos.buildstream.v2 import buildstream_pb2 from .. import utils -from .._exceptions import CASError, LoadError, LoadErrorReason -from .. import _yaml +from .._exceptions import CASCacheError - -# The default limit for gRPC messages is 4 MiB. -# Limit payload to 1 MiB to leave sufficient headroom for metadata. 
-_MAX_PAYLOAD_BYTES = 1024 * 1024 - - -class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')): - - # _new_from_config_node - # - # Creates an CASRemoteSpec() from a YAML loaded node - # - @staticmethod - def _new_from_config_node(spec_node, basedir=None): - _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name']) - url = _yaml.node_get(spec_node, str, 'url') - push = _yaml.node_get(spec_node, bool, 'push', default_value=False) - if not url: - provenance = _yaml.node_get_provenance(spec_node, 'url') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: empty artifact cache URL".format(provenance)) - - instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None) - - server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None) - if server_cert and basedir: - server_cert = os.path.join(basedir, server_cert) - - client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None) - if client_key and basedir: - client_key = os.path.join(basedir, client_key) - - client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None) - if client_cert and basedir: - client_cert = os.path.join(basedir, client_cert) - - if client_key and not client_cert: - provenance = _yaml.node_get_provenance(spec_node, 'client-key') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: 'client-key' was specified without 'client-cert'".format(provenance)) - - if client_cert and not client_key: - provenance = _yaml.node_get_provenance(spec_node, 'client-cert') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: 'client-cert' was specified without 'client-key'".format(provenance)) - - return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name) - - -CASRemoteSpec.__new__.__defaults__ = (None, None, None, None) - - -class BlobNotFound(CASError): - - def __init__(self, blob, msg): - self.blob = blob - 
super().__init__(msg) +from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES # A CASCache manages a CAS repository as specified in the Remote Execution API. @@ -120,7 +60,7 @@ class CASCache(): headdir = os.path.join(self.casdir, 'refs', 'heads') objdir = os.path.join(self.casdir, 'objects') if not (os.path.isdir(headdir) and os.path.isdir(objdir)): - raise CASError("CAS repository check failed for '{}'".format(self.casdir)) + raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir)) # contains(): # @@ -169,7 +109,7 @@ class CASCache(): # subdir (str): Optional specific dir to extract # # Raises: - # CASError: In cases there was an OSError, or if the ref did not exist. + # CASCacheError: In cases there was an OSError, or if the ref did not exist. # # Returns: path to extracted directory # @@ -201,7 +141,7 @@ class CASCache(): # Another process beat us to rename pass except OSError as e: - raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e + raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e return originaldest @@ -306,7 +246,7 @@ class CASCache(): return True except grpc.RpcError as e: if e.code() != grpc.StatusCode.NOT_FOUND: - raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e + raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e else: return False except BlobNotFound as e: @@ -360,7 +300,7 @@ class CASCache(): # (bool): True if any remote was updated, False if no pushes were required # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push(self, refs, remote): skipped_remote = True @@ -395,7 +335,7 @@ class CASCache(): skipped_remote = False except grpc.RpcError as e: if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: - raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e + raise CASCacheError("Failed to push ref 
{}: {}".format(refs, e), temporary=True) from e return not skipped_remote @@ -408,7 +348,7 @@ class CASCache(): # directory (Directory): A virtual directory object to push. # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push_directory(self, remote, directory): remote.init() @@ -424,7 +364,7 @@ class CASCache(): # message (Message): A protobuf message to push. # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push_message(self, remote, message): @@ -531,7 +471,7 @@ class CASCache(): pass except OSError as e: - raise CASError("Failed to hash object: {}".format(e)) from e + raise CASCacheError("Failed to hash object: {}".format(e)) from e return digest @@ -572,7 +512,7 @@ class CASCache(): return digest except FileNotFoundError as e: - raise CASError("Attempt to access unavailable ref: {}".format(e)) from e + raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e # update_mtime() # @@ -585,7 +525,7 @@ class CASCache(): try: os.utime(self._refpath(ref)) except FileNotFoundError as e: - raise CASError("Attempt to access unavailable ref: {}".format(e)) from e + raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e # calculate_cache_size() # @@ -676,7 +616,7 @@ class CASCache(): # Remove cache ref refpath = self._refpath(ref) if not os.path.exists(refpath): - raise CASError("Could not find ref '{}'".format(ref)) + raise CASCacheError("Could not find ref '{}'".format(ref)) os.unlink(refpath) @@ -792,7 +732,7 @@ class CASCache(): # The process serving the socket can't be cached anyway pass else: - raise CASError("Unsupported file type for {}".format(full_path)) + raise CASCacheError("Unsupported file type for {}".format(full_path)) return self.add_object(digest=dir_digest, buffer=directory.SerializeToString()) @@ -811,7 +751,7 @@ class CASCache(): if dirnode.name == name: return dirnode.digest - raise CASError("Subdirectory {} 
not found".format(name)) + raise CASCacheError("Subdirectory {} not found".format(name)) def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""): dir_a = remote_execution_pb2.Directory() @@ -1150,183 +1090,6 @@ class CASCache(): batch.send() -# Represents a single remote CAS cache. -# -class CASRemote(): - def __init__(self, spec): - self.spec = spec - self._initialized = False - self.channel = None - self.bytestream = None - self.cas = None - self.ref_storage = None - self.batch_update_supported = None - self.batch_read_supported = None - self.capabilities = None - self.max_batch_total_size_bytes = None - - def init(self): - if not self._initialized: - url = urlparse(self.spec.url) - if url.scheme == 'http': - port = url.port or 80 - self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port)) - elif url.scheme == 'https': - port = url.port or 443 - - if self.spec.server_cert: - with open(self.spec.server_cert, 'rb') as f: - server_cert_bytes = f.read() - else: - server_cert_bytes = None - - if self.spec.client_key: - with open(self.spec.client_key, 'rb') as f: - client_key_bytes = f.read() - else: - client_key_bytes = None - - if self.spec.client_cert: - with open(self.spec.client_cert, 'rb') as f: - client_cert_bytes = f.read() - else: - client_cert_bytes = None - - credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes, - private_key=client_key_bytes, - certificate_chain=client_cert_bytes) - self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials) - else: - raise CASError("Unsupported URL: {}".format(self.spec.url)) - - self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel) - self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) - self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) - self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) - - self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES - 
try: - request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name) - response = self.capabilities.GetCapabilities(request) - server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes - if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: - self.max_batch_total_size_bytes = server_max_batch_total_size_bytes - except grpc.RpcError as e: - # Simply use the defaults for servers that don't implement GetCapabilities() - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchReadBlobs() - self.batch_read_supported = False - try: - request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name) - response = self.cas.BatchReadBlobs(request) - self.batch_read_supported = True - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchUpdateBlobs() - self.batch_update_supported = False - try: - request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name) - response = self.cas.BatchUpdateBlobs(request) - self.batch_update_supported = True - except grpc.RpcError as e: - if (e.code() != grpc.StatusCode.UNIMPLEMENTED and - e.code() != grpc.StatusCode.PERMISSION_DENIED): - raise - - self._initialized = True - - -# Represents a batch of blobs queued for fetching. 
-# -class _CASBatchRead(): - def __init__(self, remote): - self._remote = remote - self._max_total_size_bytes = remote.max_batch_total_size_bytes - self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name) - self._size = 0 - self._sent = False - - def add(self, digest): - assert not self._sent - - new_batch_size = self._size + digest.size_bytes - if new_batch_size > self._max_total_size_bytes: - # Not enough space left in current batch - return False - - request_digest = self._request.digests.add() - request_digest.hash = digest.hash - request_digest.size_bytes = digest.size_bytes - self._size = new_batch_size - return True - - def send(self): - assert not self._sent - self._sent = True - - if not self._request.digests: - return - - batch_response = self._remote.cas.BatchReadBlobs(self._request) - - for response in batch_response.responses: - if response.status.code == code_pb2.NOT_FOUND: - raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format( - response.digest.hash, response.status.code)) - if response.status.code != code_pb2.OK: - raise CASError("Failed to download blob {}: {}".format( - response.digest.hash, response.status.code)) - if response.digest.size_bytes != len(response.data): - raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format( - response.digest.hash, response.digest.size_bytes, len(response.data))) - - yield (response.digest, response.data) - - -# Represents a batch of blobs queued for upload. 
-# -class _CASBatchUpdate(): - def __init__(self, remote): - self._remote = remote - self._max_total_size_bytes = remote.max_batch_total_size_bytes - self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name) - self._size = 0 - self._sent = False - - def add(self, digest, stream): - assert not self._sent - - new_batch_size = self._size + digest.size_bytes - if new_batch_size > self._max_total_size_bytes: - # Not enough space left in current batch - return False - - blob_request = self._request.requests.add() - blob_request.digest.hash = digest.hash - blob_request.digest.size_bytes = digest.size_bytes - blob_request.data = stream.read(digest.size_bytes) - self._size = new_batch_size - return True - - def send(self): - assert not self._sent - self._sent = True - - if not self._request.requests: - return - - batch_response = self._remote.cas.BatchUpdateBlobs(self._request) - - for response in batch_response.responses: - if response.status.code != code_pb2.OK: - raise CASError("Failed to upload blob {}: {}".format( - response.digest.hash, response.status.code)) - - def _grouper(iterable, n): while True: try: diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py new file mode 100644 index 000000000..59eb7e363 --- /dev/null +++ b/buildstream/_cas/casremote.py @@ -0,0 +1,247 @@ +from collections import namedtuple +import os +from urllib.parse import urlparse + +import grpc + +from .. import _yaml +from .._protos.google.rpc import code_pb2 +from .._protos.google.bytestream import bytestream_pb2_grpc +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc +from .._protos.buildstream.v2 import buildstream_pb2_grpc + +from .._exceptions import CASRemoteError, LoadError, LoadErrorReason + +# The default limit for gRPC messages is 4 MiB. +# Limit payload to 1 MiB to leave sufficient headroom for metadata. 
+_MAX_PAYLOAD_BYTES = 1024 * 1024 + + +class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')): + + # _new_from_config_node + # + # Creates an CASRemoteSpec() from a YAML loaded node + # + @staticmethod + def _new_from_config_node(spec_node, basedir=None): + _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name']) + url = _yaml.node_get(spec_node, str, 'url') + push = _yaml.node_get(spec_node, bool, 'push', default_value=False) + if not url: + provenance = _yaml.node_get_provenance(spec_node, 'url') + raise LoadError(LoadErrorReason.INVALID_DATA, + "{}: empty artifact cache URL".format(provenance)) + + instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None) + + server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None) + if server_cert and basedir: + server_cert = os.path.join(basedir, server_cert) + + client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None) + if client_key and basedir: + client_key = os.path.join(basedir, client_key) + + client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None) + if client_cert and basedir: + client_cert = os.path.join(basedir, client_cert) + + if client_key and not client_cert: + provenance = _yaml.node_get_provenance(spec_node, 'client-key') + raise LoadError(LoadErrorReason.INVALID_DATA, + "{}: 'client-key' was specified without 'client-cert'".format(provenance)) + + if client_cert and not client_key: + provenance = _yaml.node_get_provenance(spec_node, 'client-cert') + raise LoadError(LoadErrorReason.INVALID_DATA, + "{}: 'client-cert' was specified without 'client-key'".format(provenance)) + + return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name) + + +CASRemoteSpec.__new__.__defaults__ = (None, None, None, None) + + +class BlobNotFound(CASRemoteError): + + def __init__(self, blob, msg): + self.blob = blob 
+ super().__init__(msg) + + +# Represents a single remote CAS cache. +# +class CASRemote(): + def __init__(self, spec): + self.spec = spec + self._initialized = False + self.channel = None + self.bytestream = None + self.cas = None + self.ref_storage = None + self.batch_update_supported = None + self.batch_read_supported = None + self.capabilities = None + self.max_batch_total_size_bytes = None + + def init(self): + if not self._initialized: + url = urlparse(self.spec.url) + if url.scheme == 'http': + port = url.port or 80 + self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port)) + elif url.scheme == 'https': + port = url.port or 443 + + if self.spec.server_cert: + with open(self.spec.server_cert, 'rb') as f: + server_cert_bytes = f.read() + else: + server_cert_bytes = None + + if self.spec.client_key: + with open(self.spec.client_key, 'rb') as f: + client_key_bytes = f.read() + else: + client_key_bytes = None + + if self.spec.client_cert: + with open(self.spec.client_cert, 'rb') as f: + client_cert_bytes = f.read() + else: + client_cert_bytes = None + + credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes, + private_key=client_key_bytes, + certificate_chain=client_cert_bytes) + self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials) + else: + raise CASRemoteError("Unsupported URL: {}".format(self.spec.url)) + + self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel) + self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) + self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) + self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) + + self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES + try: + request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name) + response = self.capabilities.GetCapabilities(request) + server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes + if 0 < 
server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: + self.max_batch_total_size_bytes = server_max_batch_total_size_bytes + except grpc.RpcError as e: + # Simply use the defaults for servers that don't implement GetCapabilities() + if e.code() != grpc.StatusCode.UNIMPLEMENTED: + raise + + # Check whether the server supports BatchReadBlobs() + self.batch_read_supported = False + try: + request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name) + response = self.cas.BatchReadBlobs(request) + self.batch_read_supported = True + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.UNIMPLEMENTED: + raise + + # Check whether the server supports BatchUpdateBlobs() + self.batch_update_supported = False + try: + request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name) + response = self.cas.BatchUpdateBlobs(request) + self.batch_update_supported = True + except grpc.RpcError as e: + if (e.code() != grpc.StatusCode.UNIMPLEMENTED and + e.code() != grpc.StatusCode.PERMISSION_DENIED): + raise + + self._initialized = True + + +# Represents a batch of blobs queued for fetching. 
+# +class _CASBatchRead(): + def __init__(self, remote): + self._remote = remote + self._max_total_size_bytes = remote.max_batch_total_size_bytes + self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name) + self._size = 0 + self._sent = False + + def add(self, digest): + assert not self._sent + + new_batch_size = self._size + digest.size_bytes + if new_batch_size > self._max_total_size_bytes: + # Not enough space left in current batch + return False + + request_digest = self._request.digests.add() + request_digest.hash = digest.hash + request_digest.size_bytes = digest.size_bytes + self._size = new_batch_size + return True + + def send(self): + assert not self._sent + self._sent = True + + if not self._request.digests: + return + + batch_response = self._remote.cas.BatchReadBlobs(self._request) + + for response in batch_response.responses: + if response.status.code == code_pb2.NOT_FOUND: + raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format( + response.digest.hash, response.status.code)) + if response.status.code != code_pb2.OK: + raise CASRemoteError("Failed to download blob {}: {}".format( + response.digest.hash, response.status.code)) + if response.digest.size_bytes != len(response.data): + raise CASRemoteError("Failed to download blob {}: expected {} bytes, received {} bytes".format( + response.digest.hash, response.digest.size_bytes, len(response.data))) + + yield (response.digest, response.data) + + +# Represents a batch of blobs queued for upload. 
+# +class _CASBatchUpdate(): + def __init__(self, remote): + self._remote = remote + self._max_total_size_bytes = remote.max_batch_total_size_bytes + self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name) + self._size = 0 + self._sent = False + + def add(self, digest, stream): + assert not self._sent + + new_batch_size = self._size + digest.size_bytes + if new_batch_size > self._max_total_size_bytes: + # Not enough space left in current batch + return False + + blob_request = self._request.requests.add() + blob_request.digest.hash = digest.hash + blob_request.digest.size_bytes = digest.size_bytes + blob_request.data = stream.read(digest.size_bytes) + self._size = new_batch_size + return True + + def send(self): + assert not self._sent + self._sent = True + + if not self._request.requests: + return + + batch_response = self._remote.cas.BatchUpdateBlobs(self._request) + + for response in batch_response.responses: + if response.status.code != code_pb2.OK: + raise CASRemoteError("Failed to upload blob {}: {}".format( + response.digest.hash, response.status.code)) diff --git a/buildstream/_exceptions.py b/buildstream/_exceptions.py index ba0b9fabb..ea5ea62f2 100644 --- a/buildstream/_exceptions.py +++ b/buildstream/_exceptions.py @@ -284,6 +284,21 @@ class CASError(BstError): super().__init__(message, detail=detail, domain=ErrorDomain.CAS, reason=reason, temporary=True) +# CASRemoteError +# +# Raised when errors are encountered in the remote CAS +class CASRemoteError(CASError): + pass + + +# CASCacheError +# +# Raised when errors are encountered in the local CAS cache +# +class CASCacheError(CASError): + pass + + # PipelineError # # Raised from pipeline operations |