This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 06e103d87e8 Add ApplyBucketsWithInterpolation TFTransform (#31291)
06e103d87e8 is described below

commit 06e103d87e8ac883f606475dbadbefed4ba77c9a
Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com>
AuthorDate: Wed May 29 14:23:57 2024 -0400

    Add ApplyBucketsWithInterpolation TFTransform (#31291)
    
    * Add ApplyBucketsWithInterpolation TFTransform
    
    * Update sdks/python/apache_beam/ml/transforms/tft.py
    
    Co-authored-by: tvalentyn <tvalen...@users.noreply.github.com>
    
    * add tft documentation link
    
    * change docstring wording around bucket_boundaries
    
    * Update sdks/python/apache_beam/ml/transforms/tft.py
    
    Co-authored-by: tvalentyn <tvalen...@users.noreply.github.com>
    
    ---------
    
    Co-authored-by: tvalentyn <tvalen...@users.noreply.github.com>
---
 sdks/python/apache_beam/ml/transforms/tft.py      | 49 ++++++++++++++++++++---
 sdks/python/apache_beam/ml/transforms/tft_test.py | 30 ++++++++++++++
 2 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 370043bc0d9..e2f02971e7c 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -337,16 +337,16 @@ class ApplyBuckets(TFTOperation):
       name: Optional[str] = None):
     """
     This functions is used to map the element to a positive index i for
-    which bucket_boundaries[i-1] <= element < bucket_boundaries[i],
-    if it exists. If input < bucket_boundaries[0], then element is
-    mapped to 0. If element >= bucket_boundaries[-1], then element is
+    which `bucket_boundaries[i-1] <= element < bucket_boundaries[i]`,
+    if it exists. If `element < bucket_boundaries[0]`, then element is
+    mapped to 0. If `element >= bucket_boundaries[-1]`, then element is
     mapped to len(bucket_boundaries). NaNs are mapped to
     len(bucket_boundaries).
 
     Args:
       columns: A list of column names to apply the transformation on.
-      bucket_boundaries: A rank 2 Tensor or list representing the bucket
-        boundaries sorted in ascending order.
+      bucket_boundaries: An iterable of ints or floats representing the bucket
+        boundaries. Must be sorted in ascending order.
       name: (Optional) A string that specifies the name of the operation.
     """
     super().__init__(columns)
@@ -363,6 +363,45 @@ class ApplyBuckets(TFTOperation):
     return output
 
 
+@register_input_dtype(float)
+class ApplyBucketsWithInterpolation(TFTOperation):
+  def __init__(
+      self,
+      columns: List[str],
+      bucket_boundaries: Iterable[Union[int, float]],
+      name: Optional[str] = None):
+    """Interpolates values within the provided buckets and then normalizes to
+    [0, 1].
+    
+    Input values are bucketized by the provided boundaries, linearly
+    interpolated within their bucket, and the result is scaled so the whole
+    boundary range maps onto [0, 1]: values below `bucket_boundaries[0]` map
+    to 0, values at or above `bucket_boundaries[-1]` map to 1, and NaN values
+    map to 0.5.
+
+    For more information, see:
+    https://www.tensorflow.org/tfx/transform/api_docs/python/tft/apply_buckets_with_interpolation
+
+    Args:
+      columns: A list of column names to apply the transformation on.
+      bucket_boundaries: An iterable of ints or floats representing the bucket
+        boundaries. Must be sorted in ascending order.
+      name: (Optional) A string that specifies the name of the operation.
+    """
+    super().__init__(columns)
+    self.bucket_boundaries = [bucket_boundaries]  # tft takes rank-2 bounds.
+    self.name = name
+
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
+    output = {
+        output_column_name: tft.apply_buckets_with_interpolation(
+            x=data, bucket_boundaries=self.bucket_boundaries, name=self.name)
+    }
+    return output
+
+
 @register_input_dtype(float)
 class Bucketize(TFTOperation):
   def __init__(
diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py
index 5c42ecc012f..f5615e9d4ad 100644
--- a/sdks/python/apache_beam/ml/transforms/tft_test.py
+++ b/sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -364,6 +364,36 @@ class ApplyBucketsTest(unittest.TestCase):
           actual_output, equal_to(expected_output, equals_fn=np.array_equal))
 
 
+class ApplyBucketsWithInterpolationTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.artifact_location = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.artifact_location)
+
+  @parameterized.expand([
+      ([-1, 9, 10, 11], [10], [0., 0., 1., 1.]),  # below -> 0, at/above -> 1
+      ([15, 20, 25], [10, 20], [.5, 1, 1]),  # 15 is midway in [10, 20)
+  ])
+  def test_apply_buckets(self, test_inputs, bucket_boundaries, expected_values):
+    with beam.Pipeline() as p:
+      data = [{'x': [i]} for i in test_inputs]
+      result = (
+          p
+          | "Create" >> beam.Create(data)
+          | "MLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.ApplyBucketsWithInterpolation(
+                      columns=['x'], bucket_boundaries=bucket_boundaries)))
+      expected_output = []
+      for x in expected_values:
+        expected_output.append(np.array(x))
+
+      actual_output = (result | beam.Map(lambda x: x.x))
+      assert_that(
+          actual_output, equal_to(expected_output, equals_fn=np.allclose))
+
 class ComputeAndApplyVocabTest(unittest.TestCase):
   def setUp(self) -> None:
     self.artifact_location = tempfile.mkdtemp()

Reply via email to