diff options
author | Tristan Van Berkom <tristan.van.berkom@gmail.com> | 2018-10-03 08:05:47 +0000 |
---|---|---|
committer | Tristan Van Berkom <tristan.van.berkom@gmail.com> | 2018-10-03 08:05:47 +0000 |
commit | 244e3c7c5992fc863f0b363f86c4d3e32be894fb (patch) | |
tree | 676165f4d633eaab70a5dddd42f0b715a046999e | |
parent | 6e820362700e4083cb57a8b8603a20e373f30cee (diff) | |
parent | f585b23314535af51fbd1eb80d695550188cfa99 (diff) | |
download | buildstream-244e3c7c5992fc863f0b363f86c4d3e32be894fb.tar.gz |
Merge branch 'juerg/cas-batch-1.2' into 'bst-1.2'
CAS: Implement BatchUpdateBlobs support
See merge request BuildStream/buildstream!844
-rw-r--r-- | buildstream/_artifactcache/cascache.py | 275 | ||||
-rw-r--r-- | buildstream/_artifactcache/casserver.py | 45 |
2 files changed, 224 insertions, 96 deletions
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py index e2c0d44b5..14932fba2 100644 --- a/buildstream/_artifactcache/cascache.py +++ b/buildstream/_artifactcache/cascache.py @@ -81,6 +81,7 @@ class CASCache(ArtifactCache): ################################################ # Implementation of abstract methods # ################################################ + def contains(self, element, key): refpath = self._refpath(self.get_artifact_fullname(element, key)) @@ -156,6 +157,7 @@ class CASCache(ArtifactCache): q = multiprocessing.Queue() for remote_spec in remote_specs: # Use subprocess to avoid creation of gRPC threads in main BuildStream process + # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details p = multiprocessing.Process(target=self._initialize_remote, args=(remote_spec, q)) try: @@ -268,109 +270,69 @@ class CASCache(ArtifactCache): self.set_ref(newref, tree) + def _push_refs_to_remote(self, refs, remote): + skipped_remote = True + try: + for ref in refs: + tree = self.resolve_ref(ref) + + # Check whether ref is already on the server in which case + # there is no need to push the artifact + try: + request = buildstream_pb2.GetReferenceRequest() + request.key = ref + response = remote.ref_storage.GetReference(request) + + if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes: + # ref is already on the server with the same tree + continue + + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.NOT_FOUND: + # Intentionally re-raise RpcError for outer except block. + raise + + self._send_directory(remote, tree) + + request = buildstream_pb2.UpdateReferenceRequest() + request.keys.append(ref) + request.digest.hash = tree.hash + request.digest.size_bytes = tree.size_bytes + remote.ref_storage.UpdateReference(request) + + skipped_remote = False + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: + raise ArtifactError("Failed to push artifact {}: {}".format(refs, e), temporary=True) from e + + return not skipped_remote + def push(self, element, keys): - refs = [self.get_artifact_fullname(element, key) for key in keys] + + refs = [self.get_artifact_fullname(element, key) for key in list(keys)] project = element._get_project() push_remotes = [r for r in self._remotes[project] if r.spec.push] pushed = False - display_key = element._get_brief_display_key() + for remote in push_remotes: remote.init() - skipped_remote = True + display_key = element._get_brief_display_key() element.status("Pushing artifact {} -> {}".format(display_key, remote.spec.url)) - try: - for ref in refs: - tree = self.resolve_ref(ref) - - # Check whether ref is already on the server in which case - # there is no need to push the artifact - try: - request = buildstream_pb2.GetReferenceRequest() - request.key = ref - response = remote.ref_storage.GetReference(request) - - if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes: - # ref is already on the server with the same tree - continue - - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.NOT_FOUND: - # Intentionally re-raise RpcError for outer except block. - raise - - missing_blobs = {} - required_blobs = self._required_blobs(tree) - - # Limit size of FindMissingBlobs request - for required_blobs_group in _grouper(required_blobs, 512): - request = remote_execution_pb2.FindMissingBlobsRequest() - - for required_digest in required_blobs_group: - d = request.blob_digests.add() - d.hash = required_digest.hash - d.size_bytes = required_digest.size_bytes - - response = remote.cas.FindMissingBlobs(request) - for digest in response.missing_blob_digests: - d = remote_execution_pb2.Digest() - d.hash = digest.hash - d.size_bytes = digest.size_bytes - missing_blobs[d.hash] = d - - # Upload any blobs missing on the server - skipped_remote = False - for digest in missing_blobs.values(): - uuid_ = uuid.uuid4() - resource_name = '/'.join(['uploads', str(uuid_), 'blobs', - digest.hash, str(digest.size_bytes)]) - - def request_stream(resname): - with open(self.objpath(digest), 'rb') as f: - assert os.fstat(f.fileno()).st_size == digest.size_bytes - offset = 0 - finished = False - remaining = digest.size_bytes - while not finished: - chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) - remaining -= chunk_size - - request = bytestream_pb2.WriteRequest() - request.write_offset = offset - # max. _MAX_PAYLOAD_BYTES chunks - request.data = f.read(chunk_size) - request.resource_name = resname - request.finish_write = remaining <= 0 - yield request - offset += chunk_size - finished = request.finish_write - response = remote.bytestream.Write(request_stream(resource_name)) - - request = buildstream_pb2.UpdateReferenceRequest() - request.keys.append(ref) - request.digest.hash = tree.hash - request.digest.size_bytes = tree.size_bytes - remote.ref_storage.UpdateReference(request) - - pushed = True - - if not skipped_remote: - element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url)) - - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: - raise ArtifactError("Failed to push artifact {}: {}".format(refs, e), temporary=True) from e - - if skipped_remote: + if self._push_refs_to_remote(refs, remote): + element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url)) + pushed = True + else: self.context.message(Message( None, MessageType.INFO, "Remote ({}) already has {} cached".format( remote.spec.url, element._get_brief_display_key()) )) + return pushed ################################################ @@ -599,6 +561,7 @@ class CASCache(ArtifactCache): ################################################ # Local Private Methods # ################################################ + def _checkout(self, dest, tree): os.makedirs(dest, exist_ok=True) @@ -776,16 +739,16 @@ class CASCache(ArtifactCache): # q.put(str(e)) - def _required_blobs(self, tree): + def _required_blobs(self, directory_digest): # parse directory, and recursively add blobs d = remote_execution_pb2.Digest() - d.hash = tree.hash - d.size_bytes = tree.size_bytes + d.hash = directory_digest.hash + d.size_bytes = directory_digest.size_bytes yield d directory = remote_execution_pb2.Directory() - with open(self.objpath(tree), 'rb') as f: + with open(self.objpath(directory_digest), 'rb') as f: directory.ParseFromString(f.read()) for filenode in directory.files: @@ -797,16 +760,16 @@ class CASCache(ArtifactCache): for dirnode in directory.directories: yield from self._required_blobs(dirnode.digest) - def _fetch_blob(self, remote, digest, out): + def _fetch_blob(self, remote, digest, stream): resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)]) request = bytestream_pb2.ReadRequest() request.resource_name = resource_name request.read_offset = 0 for response in remote.bytestream.Read(request): - out.write(response.data) + stream.write(response.data) + stream.flush() - out.flush() - assert digest.size_bytes == os.fstat(out.fileno()).st_size + assert digest.size_bytes == os.fstat(stream.fileno()).st_size # _ensure_blob(): # @@ -922,6 +885,79 @@ class CASCache(ArtifactCache): # Fetch final batch self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue) + def _send_blob(self, remote, digest, stream, u_uid=uuid.uuid4()): + resource_name = '/'.join(['uploads', str(u_uid), 'blobs', + digest.hash, str(digest.size_bytes)]) + + def request_stream(resname, instream): + offset = 0 + finished = False + remaining = digest.size_bytes + while not finished: + chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) + remaining -= chunk_size + + request = bytestream_pb2.WriteRequest() + request.write_offset = offset + # max. _MAX_PAYLOAD_BYTES chunks + request.data = instream.read(chunk_size) + request.resource_name = resname + request.finish_write = remaining <= 0 + + yield request + + offset += chunk_size + finished = request.finish_write + + response = remote.bytestream.Write(request_stream(resource_name, stream)) + + assert response.committed_size == digest.size_bytes + + def _send_directory(self, remote, digest, u_uid=uuid.uuid4()): + required_blobs = self._required_blobs(digest) + + missing_blobs = dict() + # Limit size of FindMissingBlobs request + for required_blobs_group in _grouper(required_blobs, 512): + request = remote_execution_pb2.FindMissingBlobsRequest() + + for required_digest in required_blobs_group: + d = request.blob_digests.add() + d.hash = required_digest.hash + d.size_bytes = required_digest.size_bytes + + response = remote.cas.FindMissingBlobs(request) + for missing_digest in response.missing_blob_digests: + d = remote_execution_pb2.Digest() + d.hash = missing_digest.hash + d.size_bytes = missing_digest.size_bytes + missing_blobs[d.hash] = d + + # Upload any blobs missing on the server + self._send_blobs(remote, missing_blobs.values(), u_uid) + + def _send_blobs(self, remote, digests, u_uid=uuid.uuid4()): + batch = _CASBatchUpdate(remote) + + for digest in digests: + with open(self.objpath(digest), 'rb') as f: + assert os.fstat(f.fileno()).st_size == digest.size_bytes + + if (digest.size_bytes >= remote.max_batch_total_size_bytes or + not remote.batch_update_supported): + # Too large for batch request, upload in independent request. + self._send_blob(remote, digest, f, u_uid=u_uid) + else: + if not batch.add(digest, f): + # Not enough space left in batch request. + # Complete pending batch first. + batch.send() + batch = _CASBatchUpdate(remote) + batch.add(digest, f) + + # Send final batch + batch.send() + # Represents a single remote CAS cache. # @@ -995,6 +1031,17 @@ class _CASRemote(): if e.code() != grpc.StatusCode.UNIMPLEMENTED: raise + # Check whether the server supports BatchUpdateBlobs() + self.batch_update_supported = False + try: + request = remote_execution_pb2.BatchUpdateBlobsRequest() + response = self.cas.BatchUpdateBlobs(request) + self.batch_update_supported = True + except grpc.RpcError as e: + if (e.code() != grpc.StatusCode.UNIMPLEMENTED and + e.code() != grpc.StatusCode.PERMISSION_DENIED): + raise + self._initialized = True @@ -1042,6 +1089,46 @@ class _CASBatchRead(): yield (response.digest, response.data) +# Represents a batch of blobs queued for upload. +# +class _CASBatchUpdate(): + def __init__(self, remote): + self._remote = remote + self._max_total_size_bytes = remote.max_batch_total_size_bytes + self._request = remote_execution_pb2.BatchUpdateBlobsRequest() + self._size = 0 + self._sent = False + + def add(self, digest, stream): + assert not self._sent + + new_batch_size = self._size + digest.size_bytes + if new_batch_size > self._max_total_size_bytes: + # Not enough space left in current batch + return False + + blob_request = self._request.requests.add() + blob_request.digest.hash = digest.hash + blob_request.digest.size_bytes = digest.size_bytes + blob_request.data = stream.read(digest.size_bytes) + self._size = new_batch_size + return True + + def send(self): + assert not self._sent + self._sent = True + + if len(self._request.requests) == 0: + return + + batch_response = self._remote.cas.BatchUpdateBlobs(self._request) + + for response in batch_response.responses: + if response.status.code != grpc.StatusCode.OK.value[0]: + raise ArtifactError("Failed to upload blob {}: {}".format( + response.digest.hash, response.status.code)) + + def _grouper(iterable, n): while True: try: diff --git a/buildstream/_artifactcache/casserver.py b/buildstream/_artifactcache/casserver.py index d833878d5..62d06f3ce 100644 --- a/buildstream/_artifactcache/casserver.py +++ b/buildstream/_artifactcache/casserver.py @@ -70,7 +70,7 @@ def create_server(repo, *, enable_push): _ByteStreamServicer(artifactcache, enable_push=enable_push), server) remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server( - _ContentAddressableStorageServicer(artifactcache), server) + _ContentAddressableStorageServicer(artifactcache, enable_push=enable_push), server) remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server( _CapabilitiesServicer(), server) @@ -224,9 +224,10 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer): class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer): - def __init__(self, cas): + def __init__(self, cas, *, enable_push): super().__init__() self.cas = cas + self.enable_push = enable_push def FindMissingBlobs(self, request, context): response = remote_execution_pb2.FindMissingBlobsResponse() @@ -262,6 +263,46 @@ class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddres return response + def BatchUpdateBlobs(self, request, context): + response = remote_execution_pb2.BatchUpdateBlobsResponse() + + if not self.enable_push: + context.set_code(grpc.StatusCode.PERMISSION_DENIED) + return response + + batch_size = 0 + + for blob_request in request.requests: + digest = blob_request.digest + + batch_size += digest.size_bytes + if batch_size > _MAX_PAYLOAD_BYTES: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + return response + + blob_response = response.responses.add() + blob_response.digest.hash = digest.hash + blob_response.digest.size_bytes = digest.size_bytes + + if len(blob_request.data) != digest.size_bytes: + blob_response.status.code = grpc.StatusCode.FAILED_PRECONDITION + continue + + try: + _clean_up_cache(self.cas, digest.size_bytes) + + with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out: + out.write(blob_request.data) + out.flush() + server_digest = self.cas.add_object(path=out.name) + if server_digest.hash != digest.hash: + blob_response.status.code = grpc.StatusCode.FAILED_PRECONDITION + + except ArtifactTooLargeException: + blob_response.status.code = grpc.StatusCode.RESOURCE_EXHAUSTED + + return response + class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer): def GetCapabilities(self, request, context): |