diff options
author | Jürg Billeter <j@bitron.ch> | 2018-09-25 09:31:55 +0000 |
---|---|---|
committer | Jürg Billeter <j@bitron.ch> | 2018-09-25 09:31:55 +0000 |
commit | 81c51dbfc472e331c67ca656c54ca4055c7f91fe (patch) | |
tree | 8012bcd42ad23919061af531fec013f932ebd54c | |
parent | a76339deb7303901e8602635523b1776b4a1bb0c (diff) | |
parent | 697d10f298c0d63d15e8e7c9e19f3a581ed9fd50 (diff) | |
download | buildstream-81c51dbfc472e331c67ca656c54ca4055c7f91fe.tar.gz |
Merge branch 'juerg/cas-batch' into 'master'
_artifactcache/cascache.py: Use BatchReadBlobs
Closes #554
See merge request BuildStream/buildstream!813
-rw-r--r-- | buildstream/_artifactcache/cascache.py | 206 | ||||
-rw-r--r-- | buildstream/_artifactcache/casserver.py | 11 |
2 files changed, 174 insertions, 43 deletions
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py index 50f74a927..56c70ba9b 100644 --- a/buildstream/_artifactcache/cascache.py +++ b/buildstream/_artifactcache/cascache.py @@ -44,6 +44,11 @@ from .._exceptions import ArtifactError from . import ArtifactCache +# The default limit for gRPC messages is 4 MiB. +# Limit payload to 1 MiB to leave sufficient headroom for metadata. +_MAX_PAYLOAD_BYTES = 1024 * 1024 + + # A CASCache manages artifacts in a CAS repository as specified in the # Remote Execution API. # @@ -854,6 +859,80 @@ class CASCache(ArtifactCache): assert digest.size_bytes == os.fstat(stream.fileno()).st_size + # _ensure_blob(): + # + # Fetch and add blob if it's not already local. + # + # Args: + # remote (Remote): The remote to use. + # digest (Digest): Digest object for the blob to fetch. + # + # Returns: + # (str): The path of the object + # + def _ensure_blob(self, remote, digest): + objpath = self.objpath(digest) + if os.path.exists(objpath): + # already in local repository + return objpath + + with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f: + self._fetch_blob(remote, digest, f) + + added_digest = self.add_object(path=f.name) + assert added_digest.hash == digest.hash + + return objpath + + def _batch_download_complete(self, batch): + for digest, data in batch.send(): + with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f: + f.write(data) + f.flush() + + added_digest = self.add_object(path=f.name) + assert added_digest.hash == digest.hash + + # Helper function for _fetch_directory(). + def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue): + self._batch_download_complete(batch) + + # All previously scheduled directories are now locally available, + # move them to the processing queue. + fetch_queue.extend(fetch_next_queue) + fetch_next_queue.clear() + return _CASBatchRead(remote) + + # Helper function for _fetch_directory(). + def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False): + in_local_cache = os.path.exists(self.objpath(digest)) + + if in_local_cache: + # Skip download, already in local cache. + pass + elif (digest.size_bytes >= remote.max_batch_total_size_bytes or + not remote.batch_read_supported): + # Too large for batch request, download in independent request. + self._ensure_blob(remote, digest) + in_local_cache = True + else: + if not batch.add(digest): + # Not enough space left in batch request. + # Complete pending batch first. + batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue) + batch.add(digest) + + if recursive: + if in_local_cache: + # Add directory to processing queue. + fetch_queue.append(digest) + else: + # Directory will be available after completing pending batch. + # Add directory to deferred processing queue. + fetch_next_queue.append(digest) + + return batch + # _fetch_directory(): # # Fetches remote directory and adds it to content addressable store. @@ -867,39 +946,32 @@ class CASCache(ArtifactCache): # dir_digest (Digest): Digest object for the directory to fetch. # def _fetch_directory(self, remote, dir_digest): - objpath = self.objpath(dir_digest) - if os.path.exists(objpath): - # already in local cache - return - - with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out: - self._fetch_blob(remote, dir_digest, out) - - directory = remote_execution_pb2.Directory() + fetch_queue = [dir_digest] + fetch_next_queue = [] + batch = _CASBatchRead(remote) - with open(out.name, 'rb') as f: - directory.ParseFromString(f.read()) + while len(fetch_queue) + len(fetch_next_queue) > 0: + if len(fetch_queue) == 0: + batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue) - for filenode in directory.files: - fileobjpath = self.objpath(filenode.digest) - if os.path.exists(fileobjpath): - # already in local cache - continue + dir_digest = fetch_queue.pop(0) - with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f: - self._fetch_blob(remote, filenode.digest, f) + objpath = self._ensure_blob(remote, dir_digest) - digest = self.add_object(path=f.name) - assert digest.hash == filenode.digest.hash + directory = remote_execution_pb2.Directory() + with open(objpath, 'rb') as f: + directory.ParseFromString(f.read()) for dirnode in directory.directories: - self._fetch_directory(remote, dirnode.digest) + batch = self._fetch_directory_node(remote, dirnode.digest, batch, + fetch_queue, fetch_next_queue, recursive=True) + + for filenode in directory.files: + batch = self._fetch_directory_node(remote, filenode.digest, batch, + fetch_queue, fetch_next_queue) - # Place directory blob only in final location when we've - # downloaded all referenced blobs to avoid dangling - # references in the repository. - digest = self.add_object(path=out.name) - assert digest.hash == dir_digest.hash + # Fetch final batch + self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue) def _fetch_tree(self, remote, digest): # download but do not store the Tree object @@ -914,16 +986,7 @@ class CASCache(ArtifactCache): tree.children.extend([tree.root]) for directory in tree.children: for filenode in directory.files: - fileobjpath = self.objpath(filenode.digest) - if os.path.exists(fileobjpath): - # already in local cache - continue - - with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f: - self._fetch_blob(remote, filenode.digest, f) - - added_digest = self.add_object(path=f.name) - assert added_digest.hash == filenode.digest.hash + self._ensure_blob(remote, filenode.digest) # place directory blob only in final location when we've downloaded # all referenced blobs to avoid dangling references in the repository @@ -942,12 +1005,12 @@ class CASCache(ArtifactCache): finished = False remaining = digest.size_bytes while not finished: - chunk_size = min(remaining, 64 * 1024) + chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) remaining -= chunk_size request = bytestream_pb2.WriteRequest() request.write_offset = offset - # max. 64 kB chunks + # max. _MAX_PAYLOAD_BYTES chunks request.data = instream.read(chunk_size) request.resource_name = resname request.finish_write = remaining <= 0 @@ -1035,11 +1098,78 @@ class _CASRemote(): self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel) self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) + self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) + self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES + try: + request = remote_execution_pb2.GetCapabilitiesRequest() + response = self.capabilities.GetCapabilities(request) + server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes + if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: + self.max_batch_total_size_bytes = server_max_batch_total_size_bytes + except grpc.RpcError as e: + # Simply use the defaults for servers that don't implement GetCapabilities() + if e.code() != grpc.StatusCode.UNIMPLEMENTED: + raise + + # Check whether the server supports BatchReadBlobs() + self.batch_read_supported = False + try: + request = remote_execution_pb2.BatchReadBlobsRequest() + response = self.cas.BatchReadBlobs(request) + self.batch_read_supported = True + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.UNIMPLEMENTED: + raise + self._initialized = True +# Represents a batch of blobs queued for fetching. +# +class _CASBatchRead(): + def __init__(self, remote): + self._remote = remote + self._max_total_size_bytes = remote.max_batch_total_size_bytes + self._request = remote_execution_pb2.BatchReadBlobsRequest() + self._size = 0 + self._sent = False + + def add(self, digest): + assert not self._sent + + new_batch_size = self._size + digest.size_bytes + if new_batch_size > self._max_total_size_bytes: + # Not enough space left in current batch + return False + + request_digest = self._request.digests.add() + request_digest.hash = digest.hash + request_digest.size_bytes = digest.size_bytes + self._size = new_batch_size + return True + + def send(self): + assert not self._sent + self._sent = True + + if len(self._request.digests) == 0: + return + + batch_response = self._remote.cas.BatchReadBlobs(self._request) + + for response in batch_response.responses: + if response.status.code != grpc.StatusCode.OK.value[0]: + raise ArtifactError("Failed to download blob {}: {}".format( + response.digest.hash, response.status.code)) + if response.digest.size_bytes != len(response.data): + raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format( + response.digest.hash, response.digest.size_bytes, len(response.data))) + + yield (response.digest, response.data) + + def _grouper(iterable, n): while True: try: diff --git a/buildstream/_artifactcache/casserver.py b/buildstream/_artifactcache/casserver.py index 8c3ece27d..d833878d5 100644 --- a/buildstream/_artifactcache/casserver.py +++ b/buildstream/_artifactcache/casserver.py @@ -38,8 +38,9 @@ from .._context import Context from .cascache import CASCache -# The default limit for gRPC messages is 4 MiB -_MAX_BATCH_TOTAL_SIZE_BYTES = 4 * 1024 * 1024 +# The default limit for gRPC messages is 4 MiB. +# Limit payload to 1 MiB to leave sufficient headroom for metadata. +_MAX_PAYLOAD_BYTES = 1024 * 1024 # Trying to push an artifact that is too large @@ -158,7 +159,7 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer): remaining = client_digest.size_bytes - request.read_offset while remaining > 0: - chunk_size = min(remaining, 64 * 1024) + chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) remaining -= chunk_size response = bytestream_pb2.ReadResponse() @@ -242,7 +243,7 @@ class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddres for digest in request.digests: batch_size += digest.size_bytes - if batch_size > _MAX_BATCH_TOTAL_SIZE_BYTES: + if batch_size > _MAX_PAYLOAD_BYTES: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) return response @@ -269,7 +270,7 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer): cache_capabilities = response.cache_capabilities cache_capabilities.digest_function.append(remote_execution_pb2.SHA256) cache_capabilities.action_cache_update_capabilities.update_enabled = False - cache_capabilities.max_batch_total_size_bytes = _MAX_BATCH_TOTAL_SIZE_BYTES + cache_capabilities.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES cache_capabilities.symlink_absolute_path_strategy = remote_execution_pb2.CacheCapabilities.ALLOWED response.deprecated_api_version.major = 2 |