damccorm commented on code in PR #36213: URL: https://github.com/apache/beam/pull/36213#discussion_r2370092144
########## sdks/python/apache_beam/transforms/util.py: ########## @@ -317,6 +326,191 @@ def RemoveDuplicates(pcoll): return pcoll | 'RemoveDuplicates' >> Distinct() +class Secret(): + """A secret management class used for handling sensitive data. + + This class provides a generic interface for secret management. Implementations + of this class should handle fetching secrets from a secret management system. + """ + def get_secret_bytes(self) -> bytes: + """Returns the secret as a byte string.""" + raise NotImplementedError() + + @staticmethod + def generate_secret_bytes() -> bytes: + """Generates a new secret key.""" + return Fernet.generate_key() + + +class GcpSecret(Secret): + """A secret manager implementation that retrieves secrets from Google Cloud + Secret Manager. + """ + def __init__(self, version_name: str): + """Initializes a GcpSecret object. + + Args: + version_name: The full version name of the secret in Google Cloud Secret + Manager. For example: + projects/<id>/secrets/<secret_name>/versions/1. 
+ For more info, see + https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version + """ + self._version_name = version_name + + def get_secret_bytes(self) -> bytes: + try: + from google.cloud import secretmanager + client = secretmanager.SecretManagerServiceClient() + response = client.access_secret_version( + request={"name": self._version_name}) + secret = response.payload.data + return secret + except Exception as e: + raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}') + + +class _EncryptMessage(DoFn): + """A DoFn that encrypts the key and value of each element.""" + def __init__( + self, + hmac_key_secret: Secret, + key_coder: coders.Coder, + value_coder: coders.Coder): + self.hmac_key_secret = hmac_key_secret + self.key_coder = key_coder + self.value_coder = value_coder + + def setup(self): + self._hmac_key = self.hmac_key_secret.get_secret_bytes() + self.fernet = Fernet(self._hmac_key) + + def process(self, + element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]: + """Encrypts the key and value of an element. + + Args: + element: A tuple containing the key and value to be encrypted. + + Yields: + A tuple containing the HMAC of the encoded key, and a tuple of the + encrypted key and value. 
+ """ + k, v = element + encoded_key = self.key_coder.encode(k) + encoded_value = self.value_coder.encode(v) + hmac_encoded_key = hmac.new(self._hmac_key, encoded_key, + hashlib.sha256).digest() + out_element = ( + hmac_encoded_key, + (self.fernet.encrypt(encoded_key), self.fernet.encrypt(encoded_value))) + yield out_element + + +class _DecryptMessage(DoFn): + """A DoFn that decrypts the key and value of each element.""" + def __init__( + self, + hmac_key_secret: Secret, + key_coder: coders.Coder, + value_coder: coders.Coder): + self.hmac_key_secret = hmac_key_secret + self.key_coder = key_coder + self.value_coder = value_coder + + def setup(self): + hmac_key = self.hmac_key_secret.get_secret_bytes() + self.fernet = Fernet(hmac_key) + + def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any: + encrypted_value = encoded_element[1] + encoded_value = self.fernet.decrypt(encrypted_value) + real_val = self.value_coder.decode(encoded_value) + return real_val + + def filter_elements_by_key( + self, + encrypted_key: bytes, + encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]: + for e in encoded_elements: + if encrypted_key == self.fernet.decrypt(e[0]): + yield self.decode_value(e) + + # Right now, GBK always returns a list of elements, so we match this behavior + # here. This does mean that the whole list will be materialized every time, + # but passing an Iterable containing an Iterable breaks when pickling happens + def process( + self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]] + ) -> Iterable[Tuple[Any, List[Any]]]: + """Decrypts the key and values of an element. + + Args: + element: A tuple containing the HMAC of the encoded key and an iterable + of tuples of encrypted keys and values. + + Yields: + A tuple containing the decrypted key and a list of decrypted values. 
+ """ + unused_hmac_encoded_key, encoded_elements = element + seen_keys = set() + + # Since there could be hmac collisions, we will use the fernet encrypted + # key to confirm that the mapping is actually correct. + for e in encoded_elements: + encrypted_key, unused_encrypted_value = e + encoded_key = self.fernet.decrypt(encrypted_key) + if encoded_key in seen_keys: + continue + seen_keys.add(encoded_key) + real_key = self.key_coder.decode(encoded_key) + + yield ( + real_key, + list(self.filter_elements_by_key(encoded_key, encoded_elements))) + + +@typehints.with_input_types(Tuple[K, V]) +@typehints.with_output_types(Tuple[K, Iterable[V]]) +class GroupByEncryptedKey(PTransform): + """A PTransform that provides a secure alternative to GroupByKey. + + This transform encrypts the keys of the input PCollection, performs a + GroupByKey on the encrypted keys, and then decrypts the keys in the output. + This is useful when the keys contain sensitive data that should not be + stored at rest by the runner. + + """ + def __init__(self, hmac_key: Secret): + """Initializes a GroupByEncryptedKey transform. + + Args: + hmac_key: A Secret object that provides the secret key for HMAC and + encryption. For example, a GcpSecret can be used to access a secret + stored in GCP Secret Manager + """ + self._hmac_key = hmac_key + + def expand(self, pcoll): + kv_type_hint = pcoll.element_type + if kv_type_hint and kv_type_hint != typehints.Any: + coder = coders.registry.get_coder(kv_type_hint) + if not coder.is_kv_coder(): + raise ValueError( + 'Input elements to the transform %s with stateful DoFn must be ' + 'key-value pairs.' % self) + key_coder = coder.key_coder() Review Comment: I think defining this as a coder could be messy because then you would need to access the secret at construction time and include that as part of the serialized graph definition.
This would then not provide sufficient security guarantees since the graph itself would have all information needed to decrypt the value. Potentially you could include the work of downloading the secret in the coder definition, but I don't think we gain much from this today (and errors might be messy to debug). It also seems nice that we have the actual transform definition which helps make it obvious what is happening from the graph. ########## sdks/python/apache_beam/transforms/util_test.py: ########## @@ -85,8 +92,10 @@ try: import dill + from google.cloud import secretmanager Review Comment: Thanks, updated. It should run in the precommit, I think, or at least in the postcommit. I'm not 100% sure where all we have GCP deps installed. I'll verify it runs correctly this time, though, thanks for flagging it. ########## sdks/python/apache_beam/transforms/util.py: ########## @@ -317,6 +326,191 @@ def RemoveDuplicates(pcoll): return pcoll | 'RemoveDuplicates' >> Distinct() +class Secret(): + """A secret management class used for handling sensitive data. + + This class provides a generic interface for secret management. Implementations + of this class should handle fetching secrets from a secret management system. + """ + def get_secret_bytes(self) -> bytes: + """Returns the secret as a byte string.""" + raise NotImplementedError() + + @staticmethod + def generate_secret_bytes() -> bytes: + """Generates a new secret key.""" + return Fernet.generate_key() + + +class GcpSecret(Secret): + """A secret manager implementation that retrieves secrets from Google Cloud + Secret Manager. + """ + def __init__(self, version_name: str): + """Initializes a GcpSecret object. + + Args: + version_name: The full version name of the secret in Google Cloud Secret + Manager. For example: + projects/<id>/secrets/<secret_name>/versions/1.
+ For more info, see + https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version + """ + self._version_name = version_name + + def get_secret_bytes(self) -> bytes: + try: + from google.cloud import secretmanager + client = secretmanager.SecretManagerServiceClient() + response = client.access_secret_version( + request={"name": self._version_name}) + secret = response.payload.data + return secret + except Exception as e: + raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}') + + +class _EncryptMessage(DoFn): + """A DoFn that encrypts the key and value of each element.""" + def __init__( + self, + hmac_key_secret: Secret, + key_coder: coders.Coder, + value_coder: coders.Coder): + self.hmac_key_secret = hmac_key_secret + self.key_coder = key_coder + self.value_coder = value_coder + + def setup(self): + self._hmac_key = self.hmac_key_secret.get_secret_bytes() + self.fernet = Fernet(self._hmac_key) + + def process(self, + element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]: + """Encrypts the key and value of an element. + + Args: + element: A tuple containing the key and value to be encrypted. + + Yields: + A tuple containing the HMAC of the encoded key, and a tuple of the + encrypted key and value. 
+ """ + k, v = element + encoded_key = self.key_coder.encode(k) + encoded_value = self.value_coder.encode(v) + hmac_encoded_key = hmac.new(self._hmac_key, encoded_key, + hashlib.sha256).digest() + out_element = ( + hmac_encoded_key, + (self.fernet.encrypt(encoded_key), self.fernet.encrypt(encoded_value))) + yield out_element + + +class _DecryptMessage(DoFn): + """A DoFn that decrypts the key and value of each element.""" + def __init__( + self, + hmac_key_secret: Secret, + key_coder: coders.Coder, + value_coder: coders.Coder): + self.hmac_key_secret = hmac_key_secret + self.key_coder = key_coder + self.value_coder = value_coder + + def setup(self): + hmac_key = self.hmac_key_secret.get_secret_bytes() + self.fernet = Fernet(hmac_key) + + def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any: + encrypted_value = encoded_element[1] + encoded_value = self.fernet.decrypt(encrypted_value) + real_val = self.value_coder.decode(encoded_value) + return real_val + + def filter_elements_by_key( + self, + encrypted_key: bytes, + encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]: + for e in encoded_elements: + if encrypted_key == self.fernet.decrypt(e[0]): + yield self.decode_value(e) + + # Right now, GBK always returns a list of elements, so we match this behavior + # here. This does mean that the whole list will be materialized every time, + # but passing an Iterable containing an Iterable breaks when pickling happens + def process( + self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]] + ) -> Iterable[Tuple[Any, List[Any]]]: + """Decrypts the key and values of an element. + + Args: + element: A tuple containing the HMAC of the encoded key and an iterable + of tuples of encrypted keys and values. + + Yields: + A tuple containing the decrypted key and a list of decrypted values. 
+ """ + unused_hmac_encoded_key, encoded_elements = element + seen_keys = set() + + # Since there could be hmac collisions, we will use the fernet encrypted + # key to confirm that the mapping is actually correct. + for e in encoded_elements: + encrypted_key, unused_encrypted_value = e + encoded_key = self.fernet.decrypt(encrypted_key) + if encoded_key in seen_keys: + continue + seen_keys.add(encoded_key) + real_key = self.key_coder.decode(encoded_key) + + yield ( + real_key, + list(self.filter_elements_by_key(encoded_key, encoded_elements))) + + +@typehints.with_input_types(Tuple[K, V]) +@typehints.with_output_types(Tuple[K, Iterable[V]]) +class GroupByEncryptedKey(PTransform): + """A PTransform that provides a secure alternative to GroupByKey. + + This transform encrypts the keys of the input PCollection, performs a + GroupByKey on the encrypted keys, and then decrypts the keys in the output. + This is useful when the keys contain sensitive data that should not be + stored at rest by the runner. + + """ + def __init__(self, hmac_key: Secret): + """Initializes a GroupByEncryptedKey transform. + + Args: + hmac_key: A Secret object that provides the secret key for HMAC and + encryption. For example, a GcpSecret can be used to access a secret + stored in GCP Secret Manager + """ + self._hmac_key = hmac_key + + def expand(self, pcoll): + kv_type_hint = pcoll.element_type + if kv_type_hint and kv_type_hint != typehints.Any: + coder = coders.registry.get_coder(kv_type_hint) + if not coder.is_kv_coder(): + raise ValueError( + 'Input elements to the transform %s with stateful DoFn must be ' + 'key-value pairs.' % self) + key_coder = coder.key_coder() Review Comment: Yeah, this is a good call. Interestingly, we do check this for the direct runner GBK implementation, but not more broadly as far as I can tell. But we should definitely be verifying here. Updated -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
