This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a5779fe43f Refactor MLTransform BaseOperation class (#27389)
5a5779fe43f is described below

commit 5a5779fe43f5547f6227f5ce0c4dbca304eee3e1
Author: Anand Inguva <34158215+ananding...@users.noreply.github.com>
AuthorDate: Mon Jul 10 20:58:23 2023 +0000

    Refactor MLTransform BaseOperation class (#27389)
    
    * Refactor code. Make get_artifacts abstract method of BaseOperation
    
    * rename apply to apply_transform
    
    * Add Docstring to MLTransform
    
    * Add _validate_transform method
    
    * Fix mypy
    
    * Provide default value while fetching vocab size from utils
    
    * Fix base test
---
 sdks/python/apache_beam/ml/transforms/base.py      | 59 +++++++++++++++++++---
 sdks/python/apache_beam/ml/transforms/base_test.py |  2 +-
 sdks/python/apache_beam/ml/transforms/handlers.py  |  8 +--
 .../apache_beam/ml/transforms/handlers_test.py     | 13 ++---
 sdks/python/apache_beam/ml/transforms/tft.py       | 40 +++++++--------
 sdks/python/apache_beam/ml/transforms/utils.py     |  3 +-
 6 files changed, 79 insertions(+), 46 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index f2906409484..04aa387580a 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import abc
+from typing import Dict
 from typing import Generic
 from typing import List
 from typing import Optional
@@ -55,16 +56,38 @@ class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
     self.columns = columns
 
   @abc.abstractmethod
-  def apply(
-      self, data: OperationInputT, output_column_name: str) -> OperationOutputT:
+  def apply_transform(self, data: OperationInputT,
+                      output_column_name: str) -> Dict[str, OperationOutputT]:
     """
-    Define any processing logic in the apply() method.
+    Define any processing logic in the apply_transform() method.
     processing logics are applied on inputs and returns a transformed
     output.
     Args:
       inputs: input data.
     """
 
+  @abc.abstractmethod
+  def get_artifacts(
+      self, data: OperationInputT,
+      output_column_prefix: str) -> Optional[Dict[str, OperationOutputT]]:
+    """
+    If the operation generates any artifacts, they can be returned from this
+    method.
+    """
+    pass
+
+  def __call__(self, data: OperationInputT,
+               output_column_name: str) -> Dict[str, OperationOutputT]:
+    """
+    This method is called when the instance of the class is called.
+    This method will invoke the apply() method of the class.
+    """
+    transformed_data = self.apply_transform(data, output_column_name)
+    artifacts = self.get_artifacts(data, output_column_name)
+    if artifacts:
+      transformed_data = {**transformed_data, **artifacts}
+    return transformed_data
+
 
 class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
   """
@@ -96,6 +119,20 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
       artifact_mode: str = ArtifactMode.PRODUCE,
       transforms: Optional[Sequence[BaseOperation]] = None):
     """
+    MLTransform is a Beam PTransform that can be used to apply
+    transformations to the data. MLTransform is used to wrap the
+    data processing transforms provided by Apache Beam. MLTransform
+    works in two modes: produce and consume. In the produce mode,
+    MLTransform will apply the transforms to the data and store the
+    artifacts in the artifact_location. In the consume mode, MLTransform
+    will read the artifacts from the artifact_location and apply the
+    transforms to the data. The artifact_location should be a valid
+    storage path where the artifacts can be written to or read from.
+
+    Note that when consuming artifacts, it is not necessary to pass the
+    transforms since they are inherently stored within the artifacts
+    themselves.
+
     Args:
       artifact_location: A storage location for artifacts resulting from
         MLTransform. These artifacts include transformations applied to
@@ -113,15 +150,16 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
         i-th transform is the output of the (i-1)-th transform. Multi-input
         transforms are not supported yet.
       artifact_mode: Whether to produce or consume artifacts. If set to
-        'consume', the handler will assume that the artifacts are already
+        'consume', MLTransform will assume that the artifacts are already
         computed and stored in the artifact_location. Pass the same artifact
         location that was passed during produce phase to ensure that the
-        right artifacts are read. If set to 'produce', the handler
+        right artifacts are read. If set to 'produce', MLTransform
         will compute the artifacts and store them in the artifact_location.
         The artifacts will be read from this location during the consume phase.
-        There is no need to pass the transforms in this case since they are
-        already embedded in the stored artifacts.
     """
+    if transforms:
+      _ = [self._validate_transform(transform) for transform in transforms]
+
     # avoid circular import
     # pylint: disable=wrong-import-order, wrong-import-position
     from apache_beam.ml.transforms.handlers import TFTProcessHandler
@@ -161,5 +199,12 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
     Returns:
       A MLTransform instance.
     """
+    self._validate_transform(transform)
     self._process_handler.append_transform(transform)
     return self
+
+  def _validate_transform(self, transform):
+    if not isinstance(transform, BaseOperation):
+      raise TypeError(
+          'transform must be a subclass of BaseOperation. '
+          'Got: %s instead.' % type(transform))
diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py
index be208c93426..3ac59ff98a7 100644
--- a/sdks/python/apache_beam/ml/transforms/base_test.py
+++ b/sdks/python/apache_beam/ml/transforms/base_test.py
@@ -47,7 +47,7 @@ class _FakeOperation(TFTOperation):
     super().__init__(*args, **kwargs)
     self.name = name
 
-  def apply(self, inputs, output_column_name, **kwargs):
+  def apply_transform(self, inputs, output_column_name, **kwargs):
     return {output_column_name: inputs}
 
 
diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 9754204f9fe..09eabe7e3e6 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -126,17 +126,14 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
       *,
       artifact_location: str,
       transforms: Optional[Sequence[TFTOperation]] = None,
-      preprocessing_fn: typing.Optional[typing.Callable] = None,
       artifact_mode: str = ArtifactMode.PRODUCE):
     """
     A handler class for processing data with TensorFlow Transform (TFT)
-    operations. This class is intended to be subclassed, with subclasses
-    implementing the `preprocessing_fn` method.
+    operations.
     """
     self.transforms = transforms if transforms else []
     self.transformed_schema: Dict[str, type] = {}
     self.artifact_location = artifact_location
-    self.preprocessing_fn = preprocessing_fn
     self.artifact_mode = artifact_mode
     if artifact_mode not in ['produce', 'consume']:
       raise ValueError('artifact_mode must be either `produce` or `consume`.')
@@ -291,8 +288,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
     for transform in self.transforms:
       columns = transform.columns
       for col in columns:
-        intermediate_result = transform.apply(
-            outputs[col], output_column_name=col)
+        intermediate_result = transform(outputs[col], output_column_name=col)
         for key, value in intermediate_result.items():
           outputs[key] = value
     return outputs
diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py
index 878006550dc..4abcfee0a6e 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers_test.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py
@@ -49,23 +49,18 @@ if not tft:
 
 
 class _AddOperation(TFTOperation):
-  def apply(self, inputs, output_column_name, **kwargs):
+  def apply_transform(self, inputs, output_column_name, **kwargs):
     return {output_column_name: inputs + 1}
 
 
 class _MultiplyOperation(TFTOperation):
-  def apply(self, inputs, output_column_name, **kwargs):
+  def apply_transform(self, inputs, output_column_name, **kwargs):
     return {output_column_name: inputs * 10}
 
 
 class _FakeOperationWithArtifacts(TFTOperation):
-  def apply(self, inputs, output_column_name, **kwargs):
-    return {
-        **{
-            output_column_name: inputs
-        },
-        **(self.get_artifacts(inputs, 'artifact'))
-    }
+  def apply_transform(self, inputs, output_column_name, **kwargs):
+    return {output_column_name: inputs}
 
   def get_artifacts(self, data, col_name):
     return {'artifact': tf.convert_to_tensor([1])}
diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 329a10a74ca..c96290d0440 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -79,9 +79,9 @@ class TFTOperation(BaseOperation[common_types.TensorType,
     """
     Base Operation class for TFT data processing transformations.
     Processing logic for the transformation is defined in the
-    apply() method. If you have a custom transformation that is not
+    apply_transform() method. If you have a custom transformation that is not
     supported by the existing transforms, you can extend this class
-    and implement the apply() method.
+    and implement the apply_transform() method.
     Args:
       columns: List of column names to apply the transformation.
     """
@@ -141,8 +141,9 @@ class ComputeAndApplyVocabulary(TFTOperation):
         'compute_and_apply_vocab')
     self._name = name
 
-  def apply(self, data: common_types.TensorType,
-            output_column_name: str) -> Dict[str, common_types.TensorType]:
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     return {
         output_column_name: tft.compute_and_apply_vocabulary(
             x=data,
@@ -186,15 +187,13 @@ class ScaleToZScore(TFTOperation):
     self.elementwise = elementwise
     self.name = name
 
-  def apply(self, data: common_types.TensorType,
-            output_column_name: str) -> Dict[str, common_types.TensorType]:
-    artifacts = self.get_artifacts(data, output_column_name)
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     output_dict = {
         output_column_name: tft.scale_to_z_score(
             x=data, elementwise=self.elementwise, name=self.name)
     }
-    if artifacts is not None:
-      output_dict.update(artifacts)
     return output_dict
 
   def get_artifacts(self, data: common_types.TensorType,
@@ -245,15 +244,13 @@ class ScaleTo01(TFTOperation):
         col_name + '_max': tf.broadcast_to(tft.max(data), shape)
     }
 
-  def apply(self, data: common_types.TensorType,
-            output_column_name: str) -> Dict[str, common_types.TensorType]:
-    artifacts = self.get_artifacts(data, output_column_name)
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     output = tft.scale_to_0_1(
         x=data, elementwise=self.elementwise, name=self.name)
 
     output_dict = {output_column_name: output}
-    if artifacts is not None:
-      output_dict.update(artifacts)
     return output_dict
 
 
@@ -282,8 +279,9 @@ class ApplyBuckets(TFTOperation):
     self.bucket_boundaries = [bucket_boundaries]
     self.name = name
 
-  def apply(self, data: common_types.TensorType,
-            output_column_name: str) -> Dict[str, common_types.TensorType]:
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     output = {
         output_column_name: tft.apply_buckets(
             x=data, bucket_boundaries=self.bucket_boundaries, name=self.name)
@@ -354,9 +352,9 @@ class Bucketize(TFTOperation):
     # Should we change the prefix _quantiles to _bucket_boundaries?
     return {col_name + '_quantiles': tf.broadcast_to(quantiles, shape)}
 
-  def apply(self, data: common_types.TensorType,
-            output_column_name: str) -> Dict[str, common_types.TensorType]:
-    artifacts = self.get_artifacts(data, output_column_name)
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     output = {
         output_column_name: tft.bucketize(
             x=data,
@@ -365,8 +363,6 @@ class Bucketize(TFTOperation):
             elementwise=self.elementwise,
             name=self.name)
     }
-    if artifacts is not None:
-      output.update(artifacts)
     return output
 
 
@@ -408,7 +404,7 @@ class TFIDF(TFTOperation):
     self.name = name
     self.tfidf_weight = None
 
-  def apply(
+  def apply_transform(
       self, data: tf.SparseTensor, output_column_name: str) -> tf.SparseTensor:
 
     if self.vocab_size is None:
diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py
index 1f1fa729b16..19bb02c5ae1 100644
--- a/sdks/python/apache_beam/ml/transforms/utils.py
+++ b/sdks/python/apache_beam/ml/transforms/utils.py
@@ -52,5 +52,6 @@ class ArtifactsFetcher():
     """
     return self.transform_output.vocabulary_file_by_name(vocab_filename)
 
-  def get_vocab_size(self, vocab_filename: str) -> int:
+  def get_vocab_size(
+      self, vocab_filename: str = 'compute_and_apply_vocab') -> int:
     return self.transform_output.vocabulary_size_by_name(vocab_filename)

Reply via email to