This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 5a5779fe43f  Refactor MLTransform BaseOperation class (#27389)
5a5779fe43f is described below

commit 5a5779fe43f5547f6227f5ce0c4dbca304eee3e1
Author: Anand Inguva <34158215+ananding...@users.noreply.github.com>
AuthorDate: Mon Jul 10 20:58:23 2023 +0000

    Refactor MLTransform BaseOperation class (#27389)

    * Refactor code. Make get_artifacts abstract method of BaseOperation
    * rename apply to apply_transform
    * Add Docstring to MLTransform
    * Add _validate_transform method
    * Fix mypy
    * Provide default value while fetching vocab size from utils
    * Fix base test
---
 sdks/python/apache_beam/ml/transforms/base.py      | 59 +++++++++++++++++++---
 sdks/python/apache_beam/ml/transforms/base_test.py |  2 +-
 sdks/python/apache_beam/ml/transforms/handlers.py  |  8 +--
 .../apache_beam/ml/transforms/handlers_test.py     | 13 ++---
 sdks/python/apache_beam/ml/transforms/tft.py       | 40 +++++++--------
 sdks/python/apache_beam/ml/transforms/utils.py     |  3 +-
 6 files changed, 79 insertions(+), 46 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index f2906409484..04aa387580a 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import abc
+from typing import Dict
 from typing import Generic
 from typing import List
 from typing import Optional
@@ -55,16 +56,38 @@ class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
     self.columns = columns
 
   @abc.abstractmethod
-  def apply(
-      self, data: OperationInputT, output_column_name: str) -> OperationOutputT:
+  def apply_transform(self, data: OperationInputT,
+                      output_column_name: str) -> Dict[str, OperationOutputT]:
     """
-    Define any processing logic in the apply() method.
+    Define any processing logic in the apply_transform() method.
     processing logics are applied on inputs and returns a transformed
     output.
     Args:
       inputs: input data.
     """
 
+  @abc.abstractmethod
+  def get_artifacts(
+      self, data: OperationInputT,
+      output_column_prefix: str) -> Optional[Dict[str, OperationOutputT]]:
+    """
+    If the operation generates any artifacts, they can be returned from this
+    method.
+    """
+    pass
+
+  def __call__(self, data: OperationInputT,
+               output_column_name: str) -> Dict[str, OperationOutputT]:
+    """
+    This method is called when the instance of the class is called.
+    This method will invoke the apply_transform() method of the class.
+    """
+    transformed_data = self.apply_transform(data, output_column_name)
+    artifacts = self.get_artifacts(data, output_column_name)
+    if artifacts:
+      transformed_data = {**transformed_data, **artifacts}
+    return transformed_data
+
 
 class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
   """
@@ -96,6 +119,20 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
       artifact_mode: str = ArtifactMode.PRODUCE,
       transforms: Optional[Sequence[BaseOperation]] = None):
     """
+    MLTransform is a Beam PTransform that can be used to apply
+    transformations to the data. MLTransform is used to wrap the
+    data processing transforms provided by Apache Beam. MLTransform
+    works in two modes: produce and consume. In the produce mode,
+    MLTransform will apply the transforms to the data and store the
+    artifacts in the artifact_location. In the consume mode, MLTransform
+    will read the artifacts from the artifact_location and apply the
+    transforms to the data. The artifact_location should be a valid
+    storage path where the artifacts can be written to or read from.
+
+    Note that when consuming artifacts, it is not necessary to pass the
+    transforms since they are inherently stored within the artifacts
+    themselves.
+
     Args:
       artifact_location: A storage location for artifacts resulting from
         MLTransform. These artifacts include transformations applied to
@@ -113,15 +150,16 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
         i-th transform is the output of the (i-1)-th transform. Multi-input
         transforms are not supported yet.
       artifact_mode: Whether to produce or consume artifacts. If set to
-        'consume', the handler will assume that the artifacts are already
+        'consume', MLTransform will assume that the artifacts are already
        computed and stored in the artifact_location. Pass the same artifact
        location that was passed during produce phase to ensure that the
-        right artifacts are read. If set to 'produce', the handler
+        right artifacts are read. If set to 'produce', MLTransform
        will compute the artifacts and store them in the artifact_location.
        The artifacts will be read from this location during the consume phase.
-        There is no need to pass the transforms in this case since they are
-        already embedded in the stored artifacts.
     """
+    if transforms:
+      _ = [self._validate_transform(transform) for transform in transforms]
+
     # avoid circular import
     # pylint: disable=wrong-import-order, wrong-import-position
     from apache_beam.ml.transforms.handlers import TFTProcessHandler
@@ -161,5 +199,12 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
     Returns:
       A MLTransform instance.
     """
+    self._validate_transform(transform)
     self._process_handler.append_transform(transform)
     return self
+
+  def _validate_transform(self, transform):
+    if not isinstance(transform, BaseOperation):
+      raise TypeError(
+          'transform must be a subclass of BaseOperation. '
+          'Got: %s instead.' % type(transform))
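To make the refactored BaseOperation contract above concrete, here is a minimal sketch of a custom operation: apply_transform now returns only the transformed column, get_artifacts reports any artifacts separately, and __call__ merges the two. The operation, its threshold parameter, and the sample values are invented for illustration, and the sketch assumes BaseOperation's constructor simply stores the column list, as the diff context suggests.

from typing import Dict
from typing import List
from typing import Optional

from apache_beam.ml.transforms.base import BaseOperation


class _CapValues(BaseOperation[int, int]):
  """Illustrative operation that caps integers at a fixed threshold."""
  def __init__(self, columns: List[str], threshold: int = 10):
    super().__init__(columns)
    self.threshold = threshold

  def apply_transform(self, data: int,
                      output_column_name: str) -> Dict[str, int]:
    # Only the transformed column is returned; artifacts are no longer
    # merged here.
    return {output_column_name: min(data, self.threshold)}

  def get_artifacts(self, data: int,
                    output_column_prefix: str) -> Optional[Dict[str, int]]:
    # Artifacts are reported separately; BaseOperation.__call__ folds them
    # into the result.
    return {output_column_prefix + '_threshold': self.threshold}


op = _CapValues(columns=['x'])
print(op(7, 'x'))  # {'x': 7, 'x_threshold': 10}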
diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py
index be208c93426..3ac59ff98a7 100644
--- a/sdks/python/apache_beam/ml/transforms/base_test.py
+++ b/sdks/python/apache_beam/ml/transforms/base_test.py
@@ -47,7 +47,7 @@ class _FakeOperation(TFTOperation):
     super().__init__(*args, **kwargs)
     self.name = name
 
-  def apply(self, inputs, output_column_name, **kwargs):
+  def apply_transform(self, inputs, output_column_name, **kwargs):
     return {output_column_name: inputs}
 
 
diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 9754204f9fe..09eabe7e3e6 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -126,17 +126,14 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
       *,
       artifact_location: str,
       transforms: Optional[Sequence[TFTOperation]] = None,
-      preprocessing_fn: typing.Optional[typing.Callable] = None,
       artifact_mode: str = ArtifactMode.PRODUCE):
     """
     A handler class for processing data with TensorFlow Transform (TFT)
-    operations. This class is intended to be subclassed, with subclasses
-    implementing the `preprocessing_fn` method.
+    operations.
     """
     self.transforms = transforms if transforms else []
     self.transformed_schema: Dict[str, type] = {}
     self.artifact_location = artifact_location
-    self.preprocessing_fn = preprocessing_fn
     self.artifact_mode = artifact_mode
     if artifact_mode not in ['produce', 'consume']:
       raise ValueError('artifact_mode must be either `produce` or `consume`.')
@@ -291,8 +288,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
     for transform in self.transforms:
       columns = transform.columns
       for col in columns:
-        intermediate_result = transform.apply(
-            outputs[col], output_column_name=col)
+        intermediate_result = transform(outputs[col], output_column_name=col)
         for key, value in intermediate_result.items():
           outputs[key] = value
     return outputs
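The handler change above works because an operation instance is now callable. For context, a rough sketch of the produce/consume workflow described in the new MLTransform docstring follows; the data, column name, and artifact path are invented, and the sketch assumes the with_transform chaining method shown at the end of base.py and that artifact_mode accepts the plain strings validated in handlers.py.

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary

ARTIFACT_LOCATION = '/tmp/ml_transform_artifacts'  # assumed writable path

# Produce mode (default): compute the vocabulary and write the artifacts.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'label': 'yes'}, {'label': 'no'}, {'label': 'yes'}])
      | MLTransform(artifact_location=ARTIFACT_LOCATION).with_transform(
          ComputeAndApplyVocabulary(columns=['label'])))

# Consume mode: reuse the stored artifacts; no transforms are passed.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'label': 'no'}])
      | MLTransform(
          artifact_location=ARTIFACT_LOCATION, artifact_mode='consume'))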
""" self.transforms = transforms if transforms else [] self.transformed_schema: Dict[str, type] = {} self.artifact_location = artifact_location - self.preprocessing_fn = preprocessing_fn self.artifact_mode = artifact_mode if artifact_mode not in ['produce', 'consume']: raise ValueError('artifact_mode must be either `produce` or `consume`.') @@ -291,8 +288,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type, for transform in self.transforms: columns = transform.columns for col in columns: - intermediate_result = transform.apply( - outputs[col], output_column_name=col) + intermediate_result = transform(outputs[col], output_column_name=col) for key, value in intermediate_result.items(): outputs[key] = value return outputs diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 878006550dc..4abcfee0a6e 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -49,23 +49,18 @@ if not tft: class _AddOperation(TFTOperation): - def apply(self, inputs, output_column_name, **kwargs): + def apply_transform(self, inputs, output_column_name, **kwargs): return {output_column_name: inputs + 1} class _MultiplyOperation(TFTOperation): - def apply(self, inputs, output_column_name, **kwargs): + def apply_transform(self, inputs, output_column_name, **kwargs): return {output_column_name: inputs * 10} class _FakeOperationWithArtifacts(TFTOperation): - def apply(self, inputs, output_column_name, **kwargs): - return { - **{ - output_column_name: inputs - }, - **(self.get_artifacts(inputs, 'artifact')) - } + def apply_transform(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs} def get_artifacts(self, data, col_name): return {'artifact': tf.convert_to_tensor([1])} diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 329a10a74ca..c96290d0440 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -79,9 +79,9 @@ class TFTOperation(BaseOperation[common_types.TensorType, """ Base Operation class for TFT data processing transformations. Processing logic for the transformation is defined in the - apply() method. If you have a custom transformation that is not + apply_transform() method. If you have a custom transformation that is not supported by the existing transforms, you can extend this class - and implement the apply() method. + and implement the apply_transform() method. Args: columns: List of column names to apply the transformation. 
""" @@ -141,8 +141,9 @@ class ComputeAndApplyVocabulary(TFTOperation): 'compute_and_apply_vocab') self._name = name - def apply(self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + def apply_transform( + self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: return { output_column_name: tft.compute_and_apply_vocabulary( x=data, @@ -186,15 +187,13 @@ class ScaleToZScore(TFTOperation): self.elementwise = elementwise self.name = name - def apply(self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: - artifacts = self.get_artifacts(data, output_column_name) + def apply_transform( + self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_z_score( x=data, elementwise=self.elementwise, name=self.name) } - if artifacts is not None: - output_dict.update(artifacts) return output_dict def get_artifacts(self, data: common_types.TensorType, @@ -245,15 +244,13 @@ class ScaleTo01(TFTOperation): col_name + '_max': tf.broadcast_to(tft.max(data), shape) } - def apply(self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: - artifacts = self.get_artifacts(data, output_column_name) + def apply_transform( + self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: output = tft.scale_to_0_1( x=data, elementwise=self.elementwise, name=self.name) output_dict = {output_column_name: output} - if artifacts is not None: - output_dict.update(artifacts) return output_dict @@ -282,8 +279,9 @@ class ApplyBuckets(TFTOperation): self.bucket_boundaries = [bucket_boundaries] self.name = name - def apply(self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + def apply_transform( + self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -354,9 +352,9 @@ class Bucketize(TFTOperation): # Should we change the prefix _quantiles to _bucket_boundaries? 
diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py
index 1f1fa729b16..19bb02c5ae1 100644
--- a/sdks/python/apache_beam/ml/transforms/utils.py
+++ b/sdks/python/apache_beam/ml/transforms/utils.py
@@ -52,5 +52,6 @@ class ArtifactsFetcher():
     """
     return self.transform_output.vocabulary_file_by_name(vocab_filename)
 
-  def get_vocab_size(self, vocab_filename: str) -> int:
+  def get_vocab_size(
+      self, vocab_filename: str = 'compute_and_apply_vocab') -> int:
     return self.transform_output.vocabulary_size_by_name(vocab_filename)
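Finally, the utils.py change gives get_vocab_size a default filename matching the default name used by ComputeAndApplyVocabulary ('compute_and_apply_vocab'). A small sketch of how that might be used, assuming the ArtifactsFetcher constructor takes the artifact location written by an earlier MLTransform produce run (the argument name and path are assumptions):

from apache_beam.ml.transforms.utils import ArtifactsFetcher

# Points at artifacts written by a previous MLTransform produce run.
fetcher = ArtifactsFetcher(artifact_location='/tmp/ml_transform_artifacts')

# These two calls are now equivalent thanks to the new default argument.
size_default = fetcher.get_vocab_size()
size_explicit = fetcher.get_vocab_size('compute_and_apply_vocab')
assert size_default == size_explicit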