This is an automated email from the ASF dual-hosted git repository. github-bot pushed a commit to branch aevri/casdprocessmanager2 in repository https://gitbox.apache.org/repos/asf/buildstream.git
commit 84aab60bd8066439ed4971a23b21288d30729a87 Author: Angelos Evripiotis <[email protected]> AuthorDate: Mon Oct 14 13:53:00 2019 +0100 Extract casd_channel logic to CASDConnection Encapsulate the management of a connection to CASD, so we can hide the details of how it happens. This will make it easier to port to Windows, as we will have to take a different approach there. Also make get_local_cas() public, since it is already used outside of the CASCache class. --- src/buildstream/_cas/cascache.py | 98 ++++++++++++++---------------- src/buildstream/_cas/casdprocessmanager.py | 83 ++++++++++++++++++++++++- src/buildstream/_cas/casremote.py | 6 +- 3 files changed, 130 insertions(+), 57 deletions(-) diff --git a/src/buildstream/_cas/cascache.py b/src/buildstream/_cas/cascache.py index aefc1b9..65359ff 100644 --- a/src/buildstream/_cas/cascache.py +++ b/src/buildstream/_cas/cascache.py @@ -31,14 +31,14 @@ import time import grpc from .._protos.google.rpc import code_pb2 -from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc -from .._protos.build.buildgrid import local_cas_pb2, local_cas_pb2_grpc +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 +from .._protos.build.buildgrid import local_cas_pb2 from .. 
import _signals, utils from ..types import FastEnum from .._exceptions import CASCacheError -from .casdprocessmanager import CASDProcessManager +from .casdprocessmanager import CASDConnection, CASDProcessManager from .casremote import _CASBatchRead, _CASBatchUpdate _BUFFER_SIZE = 65536 @@ -74,9 +74,6 @@ class CASCache(): os.makedirs(os.path.join(self.casdir, 'objects'), exist_ok=True) os.makedirs(self.tmpdir, exist_ok=True) - self._casd_channel = None - self._casd_cas = None - self._local_cas = None self._cache_usage_monitor = None self._cache_usage_monitor_forbidden = False @@ -107,43 +104,12 @@ class CASCache(): return state - def _init_casd(self): - assert self._casd_process_manager, "CASCache was instantiated without buildbox-casd" - - if not self._casd_channel: - while not os.path.exists(self._casd_process_manager.socket_path): - # casd is not ready yet, try again after a 10ms delay, - # but don't wait for more than 15s - if time.time() > self._casd_process_manager.start_time + 15: - raise CASCacheError("Timed out waiting for buildbox-casd to become ready") - - time.sleep(0.01) - - self._casd_channel = grpc.insecure_channel('unix:' + self._casd_process_manager.socket_path) - self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel) - self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel) - - # Call GetCapabilities() to establish connection to casd - capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._casd_channel) - capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest()) - - # _get_cas(): - # - # Return ContentAddressableStorage stub for buildbox-casd channel. - # - def _get_cas(self): - if not self._casd_cas: - self._init_casd() - return self._casd_cas - - # _get_local_cas(): + # get_local_cas(): # # Return LocalCAS stub for buildbox-casd channel. 
# - def _get_local_cas(self): - if not self._local_cas: - self._init_casd() - return self._local_cas + def get_local_cas(self): + return self._casd_process_manager.get_connection().get_local_cas() # preflight(): # @@ -161,18 +127,17 @@ class CASCache(): # against fork() with open gRPC channels. # def has_open_grpc_channels(self): - return bool(self._casd_channel) + if self._casd_process_manager: + return self._casd_process_manager.has_open_grpc_channels() + return False # close_grpc_channels(): # # Close the casd channel if it exists # def close_grpc_channels(self): - if self._casd_channel: - self._local_cas = None - self._casd_cas = None - self._casd_channel.close() - self._casd_channel = None + if self._casd_process_manager: + self._casd_process_manager.close_grpc_channels() # release_resources(): # @@ -390,8 +355,7 @@ class CASCache(): request.path.append(path) - local_cas = self._get_local_cas() - + local_cas = self.get_local_cas() response = local_cas.CaptureFiles(request) if len(response.responses) != 1: @@ -417,7 +381,7 @@ class CASCache(): # (Digest): The digest of the imported directory # def import_directory(self, path): - local_cas = self._get_local_cas() + local_cas = self.get_local_cas() request = local_cas_pb2.CaptureTreeRequest() request.path.append(path) @@ -537,7 +501,7 @@ class CASCache(): # Returns: List of missing Digest objects # def remote_missing_blobs(self, remote, blobs): - cas = self._get_cas() + cas = self._casd_process_manager.get_connection().get_cas() instance_name = remote.local_cas_instance_name missing_blobs = dict() @@ -1032,7 +996,7 @@ class _CASCacheUsageMonitor: disk_usage = self._disk_usage disk_quota = self._disk_quota - local_cas = self.cas._get_local_cas() + local_cas = self.cas.get_local_cas() while True: try: @@ -1071,5 +1035,33 @@ def _grouper(iterable, n): # class _LimitedCASDProcessManagerProxy: def __init__(self, casd_process_manager): - self.socket_path = casd_process_manager.socket_path - self.start_time = 
casd_process_manager.start_time + self._casd_connection = None + self._connection_string = casd_process_manager.connection_string + self._start_time = casd_process_manager.start_time + self._socket_path = casd_process_manager.socket_path + + # get_connection(): + # + # Return the CASDConnection to buildbox-casd, creating it on first use. + # + def get_connection(self): + if not self._casd_connection: + self._casd_connection = CASDConnection( + self._socket_path, self._connection_string, self._start_time) + return self._casd_connection + + # has_open_grpc_channels(): + # + # Return whether there are gRPC channel instances. This is used to safeguard + # against fork() with open gRPC channels. + # + def has_open_grpc_channels(self): + return bool(self._casd_connection) + + # close_grpc_channels(): + # + # Close the casd channel if it exists. NOTE(review): _casd_connection is never reset to None, so has_open_grpc_channels() stays True and get_connection() returns the closed connection. + # + def close_grpc_channels(self): + if self._casd_connection: + self._casd_connection.close() diff --git a/src/buildstream/_cas/casdprocessmanager.py b/src/buildstream/_cas/casdprocessmanager.py index 3a434ad..c096db1 100644 --- a/src/buildstream/_cas/casdprocessmanager.py +++ b/src/buildstream/_cas/casdprocessmanager.py @@ -25,7 +25,13 @@ import subprocess import tempfile import time +import grpc + +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc +from .._protos.build.buildgrid import local_cas_pb2_grpc + from .. import _signals, utils +from .._exceptions import CASCacheError from .._message import Message, MessageType _CASD_MAX_LOGFILES = 10 @@ -47,13 +53,16 @@ class CASDProcessManager: def __init__(self, path, log_dir, log_level, cache_quota, protect_session_blobs): self._log_dir = log_dir + self._casd_connection = None + # Place socket in global/user temporary directory to avoid hitting # the socket path length limit.
self._socket_tempdir = tempfile.mkdtemp(prefix='buildstream') self.socket_path = os.path.join(self._socket_tempdir, 'casd.sock') + self.connection_string = "unix:" + self.socket_path casd_args = [utils.get_host_tool('buildbox-casd')] - casd_args.append('--bind=unix:' + self.socket_path) + casd_args.append('--bind=' + self.connection_string) casd_args.append('--log-level=' + log_level.value) if cache_quota is not None: @@ -215,3 +224,75 @@ class CASDProcessManager: assert self._failure_callback is not None self._process.returncode = returncode self._failure_callback() + + # get_connection(): + # + # Return the CASDConnection to buildbox-casd, creating it on first use. + # + def get_connection(self): + if not self._casd_connection: + self._casd_connection = CASDConnection( + self.socket_path, self.connection_string, self.start_time) + return self._casd_connection + + # has_open_grpc_channels(): + # + # Return whether there are gRPC channel instances. This is used to safeguard + # against fork() with open gRPC channels.
+ # + def has_open_grpc_channels(self): + return bool(self._casd_connection) + + # close_grpc_channels(): + # + # Close the casd channel if it exists. NOTE(review): _casd_connection is never reset to None, so has_open_grpc_channels() stays True and get_connection() returns the closed connection. + # + def close_grpc_channels(self): + if self._casd_connection: + self._casd_connection.close() + + +class CASDConnection: + def __init__(self, socket_path, connection_string, start_time): + while not os.path.exists(socket_path): + # casd is not ready yet, try again after a 10ms delay, + # but don't wait for more than 15s + if time.time() > start_time + 15: + raise CASCacheError("Timed out waiting for buildbox-casd to become ready") + + time.sleep(0.01) + + self._casd_channel = grpc.insecure_channel(connection_string) + self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel) + self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel) + + # Call GetCapabilities() to establish connection to casd + capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._casd_channel) + capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest()) + + # get_cas(): + # + # Return ContentAddressableStorage stub for buildbox-casd channel. + # + def get_cas(self): + assert self._casd_channel is not None + return self._casd_cas + + # get_local_cas(): + # + # Return LocalCAS stub for buildbox-casd channel. + # + def get_local_cas(self): + assert self._casd_channel is not None + return self._local_cas + + # close(): + # + # Close the casd channel; the connection cannot be reused afterwards (get_cas(), get_local_cas() and close() assert on a closed channel). + # + def close(self): + assert self._casd_channel is not None + self._local_cas = None + self._casd_cas = None + self._casd_channel.close() + self._casd_channel = None diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py index a054b28..c89ea9f 100644 --- a/src/buildstream/_cas/casremote.py +++ b/src/buildstream/_cas/casremote.py @@ -55,7 +55,7 @@ class CASRemote(BaseRemote): # be called outside of init().
# def _configure_protocols(self): - local_cas = self.cascache._get_local_cas() + local_cas = self.cascache.get_local_cas() request = local_cas_pb2.GetInstanceNameForRemoteRequest() request.url = self.spec.url if self.spec.instance_name: @@ -115,7 +115,7 @@ class _CASBatchRead(): if not self._requests: return - local_cas = self._remote.cascache._get_local_cas() + local_cas = self._remote.cascache.get_local_cas() for request in self._requests: batch_response = local_cas.FetchMissingBlobs(request) @@ -163,7 +163,7 @@ class _CASBatchUpdate(): if not self._requests: return - local_cas = self._remote.cascache._get_local_cas() + local_cas = self._remote.cascache.get_local_cas() for request in self._requests: batch_response = local_cas.UploadMissingBlobs(request)
