This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/mlTransformArtifacts
in repository https://gitbox.apache.org/repos/asf/beam.git

commit b4b43caa8874b79bf1e1c00cd8f11b82ae19adf6
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Wed Aug 9 15:49:06 2023 -0400

    Update how we specify artifact modes
---
 .../examples/ml_transform/ml_transform_basic.py    | 28 +++++++--------
 sdks/python/apache_beam/ml/transforms/base.py      | 42 +++++++++++++++-------
 sdks/python/apache_beam/ml/transforms/base_test.py | 19 +++++-----
 sdks/python/apache_beam/ml/transforms/handlers.py  |  9 ++---
 4 files changed, 58 insertions(+), 40 deletions(-)

diff --git a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py
index 2166d0db366..8558215ec5b 100644
--- a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py
+++ b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py
@@ -61,27 +61,26 @@ def parse_args():
   return parser.parse_known_args()
 
 
-def preprocess_data_for_ml_training(train_data, artifact_mode, args):
+def preprocess_data_for_ml_training(train_data, args):
   """
   Preprocess the data for ML training. This method runs a pipeline to
-  preprocess the data needed for ML training. It produces artifacts that
-  can be used for ML inference later.
+  preprocess the data needed for ML training. It produces artifacts that can
+  be used for ML inference later.
   """
 
   with beam.Pipeline() as p:
     train_data_pcoll = (p | "CreateData" >> beam.Create(train_data))
 
-    # When 'artifact_mode' is set to 'produce', the ComputeAndApplyVocabulary
+    # When using write_artifact_location, the ComputeAndApplyVocabulary
     # function generates a vocabulary file. This file, stored in
-    # 'artifact_location', contains the vocabulary of the entire dataset.
+    # 'write_artifact_location', contains the vocabulary of the entire dataset.
     # This is considered as an artifact of ComputeAndApplyVocabulary transform.
     # The indices of the vocabulary in this file are returned as
     # the output of MLTransform.
     transformed_data_pcoll = (
         train_data_pcoll
         | 'MLTransform' >> MLTransform(
-            artifact_location=args.artifact_location,
-            artifact_mode=artifact_mode,
+            write_artifact_location=args.artifact_location,
         ).with_transform(ComputeAndApplyVocabulary(
             columns=['x'])).with_transform(TFIDF(columns=['x'])))
 
@@ -93,7 +92,7 @@ def preprocess_data_for_ml_training(train_data, artifact_mode, args):
     # 0.5008155 ], dtype=float32), x_vocab_index=array([ 0,  2,  3,  5, 21]))
 
 
-def preprocess_data_for_ml_inference(test_data, artifact_mode, args):
+def preprocess_data_for_ml_inference(test_data, args):
   """
   Preprocess the data for ML inference. This method runs a pipeline to
   preprocess the data needed for ML inference. It consumes the artifacts
@@ -108,8 +107,7 @@ def preprocess_data_for_ml_inference(test_data, artifact_mode, args):
     transformed_data_pcoll = (
         test_data_pcoll
         | "MLTransformOnTestData" >> MLTransform(
-            artifact_location=args.artifact_location,
-            artifact_mode=artifact_mode,
+            read_artifact_location=args.artifact_location,
             # We don't need to specify transforms as they are already saved
             # in the artifacts.
         ))
@@ -149,18 +147,16 @@ def run(args):
 
   # Preprocess the data for ML training.
   # For the data going into the ML model training, we want to produce the
-  # artifacts. So, we set artifact_mode to ArtifactMode.PRODUCE.
-  preprocess_data_for_ml_training(
-      train_data, artifact_mode=ArtifactMode.PRODUCE, args=args)
+  # artifacts.
+  preprocess_data_for_ml_training(train_data, args=args)
 
   # Do some ML model training here.
 
   # Preprocess the data for ML inference.
   # For the data going into the ML model inference, we want to consume the
   # artifacts produced during the stage where we preprocessed the data for ML
-  # training. So, we set artifact_mode to ArtifactMode.CONSUME.
-  preprocess_data_for_ml_inference(
-      test_data, artifact_mode=ArtifactMode.CONSUME, args=args)
+  # training.
+  preprocess_data_for_ml_inference(test_data, args=args)
 
   # To fetch the artifacts produced in MLTransform, you can use
   # ArtifactsFetcher for fetching vocab related artifacts. For
diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index 04aa387580a..9407e40530f 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -115,8 +115,8 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
   def __init__(
       self,
       *,
-      artifact_location: str,
-      artifact_mode: str = ArtifactMode.PRODUCE,
+      write_artifact_location: str = '',
+      read_artifact_location: str = '',
       transforms: Optional[Sequence[BaseOperation]] = None):
     """
     MLTransform is a Beam PTransform that can be used to apply
@@ -134,7 +134,7 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
     themselves.
 
     Args:
-      artifact_location: A storage location for artifacts resulting from
+      write_artifact_location: A storage location for artifacts resulting from
         MLTransform. These artifacts include transformations applied to
         the dataset and generated values like min, max from ScaleTo01,
         and mean, var from ScaleToZScore. Artifacts are produced and stored
@@ -143,23 +143,41 @@ class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
         retrieved from this location. Note that when consuming artifacts,
         it is not necessary to pass the transforms since they are inherently
         stored within the artifacts themselves. The value assigned to
-        `artifact_location` should be a valid storage path where the artifacts
-        can be written to or read from.
+        `write_artifact_location` should be a valid storage path where the
+        artifacts can be written to.
+      read_artifact_location: A storage location to read artifacts resulting
+        from a previous MLTransform. These artifacts include transformations
+        applied to the dataset and generated values like min, max from
+        ScaleTo01, and mean, var from ScaleToZScore. Note that when consuming
+        artifacts, it is not necessary to pass the transforms since they are
+        inherently stored within the artifacts themselves. The value assigned
+        to `read_artifact_location` should be a valid storage path where the
+        artifacts can be read from.
       transforms: A list of transforms to apply to the data. All the transforms
         are applied in the order they are specified. The input of the
         i-th transform is the output of the (i-1)-th transform. Multi-input
         transforms are not supported yet.
-      artifact_mode: Whether to produce or consume artifacts. If set to
-        'consume', MLTransform will assume that the artifacts are already
-        computed and stored in the artifact_location. Pass the same artifact
-        location that was passed during produce phase to ensure that the
-        right artifacts are read. If set to 'produce', MLTransform
-        will compute the artifacts and store them in the artifact_location.
-        The artifacts will be read from this location during the consume phase.
     """
     if transforms:
       _ = [self._validate_transform(transform) for transform in transforms]
 
+    if len(read_artifact_location) > 0 and len(write_artifact_location) > 0:
+      raise ValueError(
+          'Only one of read_artifact_location or write_artifact_location can '
+          ' bespecified to initialize MLTransform')
+
+    if len(read_artifact_location) == 0 and len(write_artifact_location) == 0:
+      raise ValueError(
+          'Either a read_artifact_location or write_artifact_location must be '
+          'specified to initialize MLTransform')
+
+    artifact_location = write_artifact_location
+    artifact_mode = ArtifactMode.PRODUCE
+
+    if len(read_artifact_location) > 0:
+      artifact_location = read_artifact_location
+      artifact_mode = ArtifactMode.CONSUME
+
     # avoid circular import
     # pylint: disable=wrong-import-order, wrong-import-position
     from apache_beam.ml.transforms.handlers import TFTProcessHandler
diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py
index 09f4ddfa53f..df7a6d26b47 100644
--- a/sdks/python/apache_beam/ml/transforms/base_test.py
+++ b/sdks/python/apache_beam/ml/transforms/base_test.py
@@ -62,7 +62,7 @@ class BaseMLTransformTest(unittest.TestCase):
     fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x'])
     transforms = [fake_fn_1]
     ml_transform = base.MLTransform(
-        transforms=transforms, artifact_location=self.artifact_location)
+        transforms=transforms, write_artifact_location=self.artifact_location)
     ml_transform = ml_transform.with_transform(
         transform=_FakeOperation(name='fake_fn_2', columns=['x']))
 
@@ -80,7 +80,8 @@ class BaseMLTransformTest(unittest.TestCase):
           p
           | beam.Create(data)
           | base.MLTransform(
-              artifact_location=self.artifact_location, transforms=transforms))
+              write_artifact_location=self.artifact_location,
+              transforms=transforms))
       expected_output = [
           np.array([0.0], dtype=np.float32),
           np.array([1.0], dtype=np.float32),
@@ -97,7 +98,8 @@ class BaseMLTransformTest(unittest.TestCase):
           p
           | beam.Create(data)
           | base.MLTransform(
-              transforms=transforms, artifact_location=self.artifact_location))
+              transforms=transforms,
+              write_artifact_location=self.artifact_location))
       expected_output = [
           np.array([0, 0.2, 0.4], dtype=np.float32),
           np.array([0.6, 0.8, 1], dtype=np.float32),
@@ -170,7 +172,7 @@ class BaseMLTransformTest(unittest.TestCase):
               beam.row_type.RowTypeConstraint.from_fields(
                   list(input_types.items()))))
       transformed_data = schema_data | base.MLTransform(
-          artifact_location=self.artifact_location, transforms=transforms)
+          write_artifact_location=self.artifact_location, transforms=transforms)
       for name, typ in transformed_data.element_type._fields:
         if name in expected_dtype:
           self.assertEqual(expected_dtype[name], typ)
@@ -187,8 +189,7 @@ class BaseMLTransformTest(unittest.TestCase):
             | beam.WindowInto(beam.window.FixedWindows(1))
             | base.MLTransform(
                 transforms=transforms,
-                artifact_location=self.artifact_location,
-                artifact_mode=base.ArtifactMode.PRODUCE,
+                write_artifact_location=self.artifact_location,
             ))
 
   def test_ml_transform_on_multiple_columns_single_transform(self):
@@ -199,7 +200,8 @@ class BaseMLTransformTest(unittest.TestCase):
           p
           | beam.Create(data)
           | base.MLTransform(
-              transforms=transforms, artifact_location=self.artifact_location))
+              transforms=transforms,
+              write_artifact_location=self.artifact_location))
       expected_output_x = [
           np.array([0, 0.5, 1], dtype=np.float32),
       ]
@@ -225,7 +227,8 @@ class BaseMLTransformTest(unittest.TestCase):
           p
           | beam.Create(data)
           | base.MLTransform(
-              transforms=transforms, artifact_location=self.artifact_location))
+              transforms=transforms,
+              write_artifact_location=self.artifact_location))
       expected_output_x = [
           np.array([0, 0.5, 1], dtype=np.float32),
       ]
diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 0b00610684e..8695d5146ef 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -437,10 +437,11 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
       if not os.path.exists(os.path.join(
           self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)):
         raise FileNotFoundError(
-            "Artifacts not found at location: %s when artifact_mode=consume."
-            "Make sure you've run the pipeline in `produce` mode using "
-            "this artifact location before setting artifact_mode to `consume`."
-            % os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+            "Artifacts not found at location: %s when using "
+            "read_artifact_location. Make sure you've run the pipeline with "
+            "write_artifact_location using this artifact location before "
+            "running with read_artifact_location set." %
+            os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 

Reply via email to