This is an automated email from the ASF dual-hosted git repository.

root pushed a commit to branch tlater/casd-socket-permissions
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 2a02568e0cbce765eb5d497afdb38df6d04d826a
Author: Tristan Maat <[email protected]>
AuthorDate: Tue Oct 15 17:44:46 2019 +0100

    casserver.py: Proxy CAS requests to buildbox-casd
---
 src/buildstream/_cas/casserver.py | 274 +++++++++++---------------------------
 tests/testutils/artifactshare.py  |   4 +-
 2 files changed, 77 insertions(+), 201 deletions(-)

diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index 4f07639..d1bef68 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -36,8 +36,7 @@ from google.protobuf.message import DecodeError
 import click
 
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.google.rpc import code_pb2
+from .._protos.google.bytestream import bytestream_pb2_grpc
 from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \
     artifact_pb2, artifact_pb2_grpc, source_pb2, source_pb2_grpc
 
@@ -49,6 +48,66 @@ from .. import utils
 _MAX_PAYLOAD_BYTES = 1024 * 1024
 
 
+# CASRemote:
+#
+# A class that handles connections to a CAS remote - this is a (very)
+# slimmed down version of BuildStream's CASRemote.
+#
+class CASRemote:
+    def __init__(self, url: str):
+        self._url = url
+
+        self._bytestream = None
+        self._cas = None
+
+        # FIXME: We should allow setting up a secure channel. This
+        # isn't currently required, since we will only proxy to a
+        # process on the same host, but if we ever allow proxying to
+        # external services this will need to change.
+        self._channel = None
+
+    def _initialize_remote(self):
+        if self._channel:
+            assert self._cas and self._bytestream, "Stubs seem to have disappeared"
+            return
+        assert not (self._cas or self._bytestream), "Our cas/bytestream stubs should not have been set"
+
+        # Set up the remote channel
+        self._channel = grpc.insecure_channel(self._url)
+
+        # Assert that we support all capabilities we need
+        capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._channel)
+        start_wait = time.time()
+        while True:
+            try:
+                capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest())
+                break
+            except grpc.RpcError as e:
+                if e.code() == grpc.StatusCode.UNAVAILABLE:
+                    # If connecting to casd, it may not be ready yet,
+                    # try again after a 10ms delay, but don't wait for
+                    # more than 15s
+                    if time.time() < start_wait + 15:
+                        time.sleep(1 / 100)
+                        continue
+
+                raise
+
+        # Set up the RPC stubs
+        self._bytestream = bytestream_pb2_grpc.ByteStreamStub(self._channel)
+        self._cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._channel)
+
+    def get_cas(self) -> remote_execution_pb2_grpc.ContentAddressableStorageStub:
+        self._initialize_remote()
+        assert self._cas is not None, "CAS stub was not initialized"
+        return self._cas
+
+    def get_bytestream(self) -> bytestream_pb2_grpc.ByteStreamStub:
+        self._initialize_remote()
+        assert self._bytestream is not None, "Bytestream stub was not initialized"
+        return self._bytestream
+
+
 # CASCache:
 #
 # A slimmed down version of `buildstream._cas.cascache.CASCache` -
@@ -378,6 +437,7 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.WA
     cas_runner = CASdRunner(os.path.abspath(repo), cache_quota=quota)
     cas_runner.start_casd()
     cas_cache = CASCache(os.path.abspath(repo))
+    cas = CASRemote('unix:' + cas_runner.get_socket_path())
 
     try:
         root = os.path.abspath(repo)
@@ -418,7 +478,6 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.WA
         yield server
 
     finally:
-        cas.release_resources()
         cas_runner.stop()
 
 
@@ -489,221 +548,41 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
 
 
 class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, remote, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.bytestream = remote.get_bytestream()
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
 
     def Read(self, request, context):
         self.logger.info("Read")
-        resource_name = request.resource_name
-        client_digest = _digest_from_download_resource_name(resource_name)
-        if client_digest is None:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
-            return
-
-        if request.read_offset > client_digest.size_bytes:
-            context.set_code(grpc.StatusCode.OUT_OF_RANGE)
-            return
-
-        try:
-            with open(self.cas.objpath(client_digest), 'rb') as f:
-                if os.fstat(f.fileno()).st_size != client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.NOT_FOUND)
-                    return
-
-                os.utime(f.fileno())
-
-                if request.read_offset > 0:
-                    f.seek(request.read_offset)
-
-                remaining = client_digest.size_bytes - request.read_offset
-                while remaining > 0:
-                    chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                    remaining -= chunk_size
-
-                    response = bytestream_pb2.ReadResponse()
-                    # max. 64 kB chunks
-                    response.data = f.read(chunk_size)
-                    yield response
-        except FileNotFoundError:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
+        return self.bytestream.Read(request)
 
     def Write(self, request_iterator, context):
         self.logger.info("Write")
-        response = bytestream_pb2.WriteResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        offset = 0
-        finished = False
-        resource_name = None
-        with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-            for request in request_iterator:
-                if finished or request.write_offset != offset:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                if resource_name is None:
-                    # First request
-                    resource_name = request.resource_name
-                    client_digest = _digest_from_upload_resource_name(resource_name)
-                    if client_digest is None:
-                        context.set_code(grpc.StatusCode.NOT_FOUND)
-                        return response
-
-                    while True:
-                        if client_digest.size_bytes == 0:
-                            break
-
-                        try:
-                            os.posix_fallocate(out.fileno(), 0, client_digest.size_bytes)
-                            break
-                        except OSError as e:
-                            # Multiple upload can happen in the same time
-                            if e.errno != errno.ENOSPC:
-                                raise
-
-                elif request.resource_name:
-                    # If it is set on subsequent calls, it **must** match the value of the first request.
-                    if request.resource_name != resource_name:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                if (offset + len(request.data)) > client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                out.write(request.data)
-                offset += len(request.data)
-                if request.finish_write:
-                    if client_digest.size_bytes != offset:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-                    out.flush()
-
-                    try:
-                        digest = self.cas.add_object(path=out.name, link_directly=True)
-                    except CASCacheError as e:
-                        if e.reason == "cache-too-full":
-                            context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
-                        else:
-                            context.set_code(grpc.StatusCode.INTERNAL)
-                        return response
-
-                    if digest.hash != client_digest.hash:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                    finished = True
-
-        assert finished
-
-        response.committed_size = offset
-        return response
+        return self.bytestream.Write(request_iterator)
 
 
 class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, remote, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
 
     def FindMissingBlobs(self, request, context):
         self.logger.info("FindMissingBlobs")
-        response = remote_execution_pb2.FindMissingBlobsResponse()
-        for digest in request.blob_digests:
-            objpath = self.cas.objpath(digest)
-            try:
-                os.utime(objpath)
-            except OSError as e:
-                if e.errno != errno.ENOENT:
-                    raise
-
-                d = response.missing_blob_digests.add()
-                d.hash = digest.hash
-                d.size_bytes = digest.size_bytes
-
-        return response
+        self.logger.debug(request.blob_digests)
+        return self.cas.FindMissingBlobs(request)
 
     def BatchReadBlobs(self, request, context):
         self.logger.info("BatchReadBlobs")
-        self.logger.debug(request.digests)
-        response = remote_execution_pb2.BatchReadBlobsResponse()
-        batch_size = 0
-
-        for digest in request.digests:
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-            try:
-                objpath = self.cas.objpath(digest)
-                with open(objpath, 'rb') as f:
-                    if os.fstat(f.fileno()).st_size != digest.size_bytes:
-                        blob_response.status.code = code_pb2.NOT_FOUND
-                        continue
-
-                    os.utime(f.fileno())
-
-                    blob_response.data = f.read(digest.size_bytes)
-            except FileNotFoundError:
-                blob_response.status.code = code_pb2.NOT_FOUND
-
-        return response
+        return self.cas.BatchReadBlobs(request)
 
     def BatchUpdateBlobs(self, request, context):
         self.logger.info("BatchUpdateBlobs")
         self.logger.debug([request.digest for request in request.requests])
-        response = remote_execution_pb2.BatchUpdateBlobsResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        batch_size = 0
-
-        for blob_request in request.requests:
-            digest = blob_request.digest
-
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-
-            if len(blob_request.data) != digest.size_bytes:
-                blob_response.status.code = code_pb2.FAILED_PRECONDITION
-                continue
-
-            with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-                out.write(blob_request.data)
-                out.flush()
-
-                try:
-                    server_digest = self.cas.add_object(path=out.name)
-                except CASCacheError as e:
-                    if e.reason == "cache-too-full":
-                        blob_response.status.code = code_pb2.RESOURCE_EXHAUSTED
-                    else:
-                        blob_response.status.code = code_pb2.INTERNAL
-                    continue
-
-                if server_digest.hash != digest.hash:
-                    blob_response.status.code = code_pb2.FAILED_PRECONDITION
-
-        return response
+        return self.cas.BatchUpdateBlobs(request)
 
 
 class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
@@ -728,9 +607,9 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
 
 
 class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
-    def __init__(self, cas, cas_cache, *, enable_push):
+    def __init__(self, remote, cas_cache, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.cas_cache = cas_cache
         self.enable_push = enable_push
         self.logger = logging.getLogger("casserver")
@@ -779,13 +658,12 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
 
 class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
 
-    def __init__(self, cas, root, cas_cache, *, update_cas=True):
+    def __init__(self, remote, root, cas_cache, *, update_cas=True):
         super().__init__()
-        self.cas = cas
+        self.cas = remote.get_cas()
         self.cas_cache = cas_cache
         self.artifactdir = os.path.join(root, 'artifacts', 'refs')
         self.update_cas = update_cas
-        os.makedirs(artifactdir, exist_ok=True)
         self.logger = logging.getLogger("casserver")
 
     def GetArtifact(self, request, context):
diff --git a/tests/testutils/artifactshare.py b/tests/testutils/artifactshare.py
index 18ecc5e..d86cafa 100644
--- a/tests/testutils/artifactshare.py
+++ b/tests/testutils/artifactshare.py
@@ -39,11 +39,9 @@ class ArtifactShare():
         # in tests as a remote artifact push/pull configuration
         #
         self.repodir = os.path.join(self.directory, 'repo')
-        os.makedirs(self.repodir)
-        self.artifactdir = os.path.join(self.repodir, 'artifacts', 'refs')
-        os.makedirs(self.artifactdir)
 
         self.cas = CASCache(self.repodir, casd=casd)
+        self.artifactdir = os.path.join(self.repodir, 'artifacts', 'refs')
 
         self.quota = quota
         self.index_only = index_only

Reply via email to