This is an automated email from the ASF dual-hosted git repository.

root pushed a commit to branch tlater/casd-socket-permissions
in repository https://gitbox.apache.org/repos/asf/buildstream.git
commit 2a02568e0cbce765eb5d497afdb38df6d04d826a
Author: Tristan Maat <[email protected]>
AuthorDate: Tue Oct 15 17:44:46 2019 +0100

    casserver.py: Proxy CAS requests to buildbox-casd
---
 src/buildstream/_cas/casserver.py | 274 +++++++++++---------------------------
 tests/testutils/artifactshare.py  |   4 +-
 2 files changed, 77 insertions(+), 201 deletions(-)

diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index 4f07639..d1bef68 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -36,8 +36,7 @@ from google.protobuf.message import DecodeError
 import click
 
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.google.rpc import code_pb2
+from .._protos.google.bytestream import bytestream_pb2_grpc
 from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \
     artifact_pb2, artifact_pb2_grpc, source_pb2, source_pb2_grpc
 
@@ -49,6 +48,66 @@ from .. import utils
 _MAX_PAYLOAD_BYTES = 1024 * 1024
 
 
+# CASRemote:
+#
+# A class that handles connections to a CAS remote - this is a (very)
+# slimmed down version of BuildStream's CASRemote.
+#
+class CASRemote:
+    def __init__(self, url: str):
+        self._url = url
+
+        self._bytestream = None
+        self._cas = None
+
+        # FIXME: We should allow setting up a secure channel. This
+        # isn't currently required, since we will only proxy to a
+        # process on the same host, but if we ever allow proxying to
+        # external services this will need to change.
+        self._channel = None
+
+    def _initialize_remote(self):
+        if self._channel:
+            assert self._cas and self._bytestream, "Stubs seem to have disappeared"
+            return
+        assert not (self._cas or self._bytestream), "Our cas/bytestream stubs should not have been set"
+
+        # Set up the remote channel
+        self._channel = grpc.insecure_channel(self._url)
+
+        # Assert that we support all capabilities we need
+        capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._channel)
+        start_wait = time.time()
+        while True:
+            try:
+                capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest())
+                break
+            except grpc.RpcError as e:
+                if e.code() == grpc.StatusCode.UNAVAILABLE:
+                    # If connecting to casd, it may not be ready yet,
+                    # try again after a 10ms delay, but don't wait for
+                    # more than 15s
+                    if time.time() < start_wait + 15:
+                        time.sleep(1 / 100)
+                        continue
+
+                raise
+
+        # Set up the RPC stubs
+        self._bytestream = bytestream_pb2_grpc.ByteStreamStub(self._channel)
+        self._cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._channel)
+
+    def get_cas(self) -> remote_execution_pb2_grpc.ContentAddressableStorageStub:
+        self._initialize_remote()
+        assert self._cas is not None, "CAS stub was not initialized"
+        return self._cas
+
+    def get_bytestream(self) -> bytestream_pb2_grpc.ByteStreamStub:
+        self._initialize_remote()
+        assert self._bytestream is not None, "Bytestream stub was not initialized"
+        return self._bytestream
+
+
 # CASCache:
 #
 # A slimmed down version of `buildstream._cas.cascache.CASCache` -
@@ -378,6 +437,7 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.WA
     cas_runner = CASdRunner(os.path.abspath(repo), cache_quota=quota)
     cas_runner.start_casd()
     cas_cache = CASCache(os.path.abspath(repo))
+    cas = CASRemote('unix:' + cas_runner.get_socket_path())
 
     try:
         root = os.path.abspath(repo)
@@ -418,7 +478,6 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.WA
         yield server
 
     finally:
-        cas.release_resources()
         cas_runner.stop()
 
 
@@ -489,221 +548,41 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
 
 
 class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, remote, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.bytestream = remote.get_bytestream()
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
 
     def Read(self, request, context):
         self.logger.info("Read")
-        resource_name = request.resource_name
-        client_digest = _digest_from_download_resource_name(resource_name)
-        if client_digest is None:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
-            return
-
-        if request.read_offset > client_digest.size_bytes:
-            context.set_code(grpc.StatusCode.OUT_OF_RANGE)
-            return
-
-        try:
-            with open(self.cas.objpath(client_digest), 'rb') as f:
-                if os.fstat(f.fileno()).st_size != client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.NOT_FOUND)
-                    return
-
-                os.utime(f.fileno())
-
-                if request.read_offset > 0:
-                    f.seek(request.read_offset)
-
-                remaining = client_digest.size_bytes - request.read_offset
-                while remaining > 0:
-                    chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                    remaining -= chunk_size
-
-                    response = bytestream_pb2.ReadResponse()
-                    # max. 64 kB chunks
-                    response.data = f.read(chunk_size)
-                    yield response
-        except FileNotFoundError:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
+        return self.bytestream.Read(request)
 
     def Write(self, request_iterator, context):
         self.logger.info("Write")
-        response = bytestream_pb2.WriteResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        offset = 0
-        finished = False
-        resource_name = None
-        with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-            for request in request_iterator:
-                if finished or request.write_offset != offset:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                if resource_name is None:
-                    # First request
-                    resource_name = request.resource_name
-                    client_digest = _digest_from_upload_resource_name(resource_name)
-                    if client_digest is None:
-                        context.set_code(grpc.StatusCode.NOT_FOUND)
-                        return response
-
-                    while True:
-                        if client_digest.size_bytes == 0:
-                            break
-
-                        try:
-                            os.posix_fallocate(out.fileno(), 0, client_digest.size_bytes)
-                            break
-                        except OSError as e:
-                            # Multiple upload can happen in the same time
-                            if e.errno != errno.ENOSPC:
-                                raise
-
-                elif request.resource_name:
-                    # If it is set on subsequent calls, it **must** match the value of the first request.
-                    if request.resource_name != resource_name:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                if (offset + len(request.data)) > client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                out.write(request.data)
-                offset += len(request.data)
-                if request.finish_write:
-                    if client_digest.size_bytes != offset:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-                    out.flush()
-
-                    try:
-                        digest = self.cas.add_object(path=out.name, link_directly=True)
-                    except CASCacheError as e:
-                        if e.reason == "cache-too-full":
-                            context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
-                        else:
-                            context.set_code(grpc.StatusCode.INTERNAL)
-                        return response
-
-                    if digest.hash != client_digest.hash:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                    finished = True
-
-        assert finished
-
-        response.committed_size = offset
-        return response
+        return self.bytestream.Write(request_iterator)
 
 
 class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, remote, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
 
     def FindMissingBlobs(self, request, context):
         self.logger.info("FindMissingBlobs")
-        response = remote_execution_pb2.FindMissingBlobsResponse()
-        for digest in request.blob_digests:
-            objpath = self.cas.objpath(digest)
-            try:
-                os.utime(objpath)
-            except OSError as e:
-                if e.errno != errno.ENOENT:
-                    raise
-
-                d = response.missing_blob_digests.add()
-                d.hash = digest.hash
-                d.size_bytes = digest.size_bytes
-
-        return response
+        self.logger.debug(request.blob_digests)
+        return self.cas.FindMissingBlobs(request)
 
     def BatchReadBlobs(self, request, context):
         self.logger.info("BatchReadBlobs")
-        self.logger.debug(request.digests)
-        response = remote_execution_pb2.BatchReadBlobsResponse()
-        batch_size = 0
-
-        for digest in request.digests:
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-            try:
-                objpath = self.cas.objpath(digest)
-                with open(objpath, 'rb') as f:
-                    if os.fstat(f.fileno()).st_size != digest.size_bytes:
-                        blob_response.status.code = code_pb2.NOT_FOUND
-                        continue
-
-                    os.utime(f.fileno())
-
-                    blob_response.data = f.read(digest.size_bytes)
-            except FileNotFoundError:
-                blob_response.status.code = code_pb2.NOT_FOUND
-
-        return response
+        return self.cas.BatchReadBlobs(request)
 
     def BatchUpdateBlobs(self, request, context):
         self.logger.info("BatchUpdateBlobs")
         self.logger.debug([request.digest for request in request.requests])
-        response = remote_execution_pb2.BatchUpdateBlobsResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        batch_size = 0
-
-        for blob_request in request.requests:
-            digest = blob_request.digest
-
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-
-            if len(blob_request.data) != digest.size_bytes:
-                blob_response.status.code = code_pb2.FAILED_PRECONDITION
-                continue
-
-            with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-                out.write(blob_request.data)
-                out.flush()
-
-                try:
-                    server_digest = self.cas.add_object(path=out.name)
-                except CASCacheError as e:
-                    if e.reason == "cache-too-full":
-                        blob_response.status.code = code_pb2.RESOURCE_EXHAUSTED
-                    else:
-                        blob_response.status.code = code_pb2.INTERNAL
-                    continue
-
-                if server_digest.hash != digest.hash:
-                    blob_response.status.code = code_pb2.FAILED_PRECONDITION
-
-        return response
+        return self.cas.BatchUpdateBlobs(request)
 
 
 class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
@@ -728,9 +607,9 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
 
 
 class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
-    def __init__(self, cas, cas_cache, *, enable_push):
+    def __init__(self, remote, cas_cache, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.cas_cache = cas_cache
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
@@ -779,13 +658,12 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
 
 
 class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
-    def __init__(self, cas, root, cas_cache, *, update_cas=True):
+    def __init__(self, remote, root, cas_cache, *, update_cas=True):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.cas_cache = cas_cache
         self.artifactdir = os.path.join(root, 'artifacts', 'refs')
         self.update_cas = update_cas
-        os.makedirs(artifactdir, exist_ok=True)
         self.logger = logging.getLogger("casserver")
 
     def GetArtifact(self, request, context):
diff --git a/tests/testutils/artifactshare.py b/tests/testutils/artifactshare.py
index 18ecc5e..d86cafa 100644
--- a/tests/testutils/artifactshare.py
+++ b/tests/testutils/artifactshare.py
@@ -39,11 +39,9 @@ class ArtifactShare():
         # in tests as a remote artifact push/pull configuration
         #
         self.repodir = os.path.join(self.directory, 'repo')
-        os.makedirs(self.repodir)
-        self.artifactdir = os.path.join(self.repodir, 'artifacts', 'refs')
-        os.makedirs(self.artifactdir)
 
         self.cas = CASCache(self.repodir, casd=casd)
+        self.artifactdir = os.path.join(self.repodir, 'artifacts', 'refs')
 
         self.quota = quota
         self.index_only = index_only
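For context, a minimal sketch (not part of the commit) of the proxy
pattern this patch adopts: the servicers now hold gRPC stubs from a
single CASRemote and forward requests to buildbox-casd verbatim. The
socket path and digest values below are hypothetical placeholders:

    import grpc

    from buildstream._protos.build.bazel.remote.execution.v2 import (
        remote_execution_pb2,
        remote_execution_pb2_grpc,
    )

    # Connect to a (hypothetical) casd socket, as create_server() now
    # does with 'unix:' + cas_runner.get_socket_path().
    channel = grpc.insecure_channel("unix:/tmp/casd/casd.sock")
    cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)

    # Forward a FindMissingBlobs request, mirroring what the slimmed-down
    # _ContentAddressableStorageServicer.FindMissingBlobs() now does.
    request = remote_execution_pb2.FindMissingBlobsRequest()
    digest = request.blob_digests.add()
    digest.hash = "0" * 64  # hypothetical SHA-256 digest
    digest.size_bytes = 0
    response = cas.FindMissingBlobs(request)
    print(response.missing_blob_digests)

Because every request is delegated this way, error handling, digest
verification and quota enforcement become casd's responsibility, which
is what lets this patch delete roughly 200 lines of hand-rolled CAS
logic from casserver.py.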
