This is an automated email from the ASF dual-hosted git repository. jrmccluskey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new c230655a1e5 Implement the hash_words TFT operation (#31249)
c230655a1e5 is described below

commit c230655a1e599e25a69b77d8a85fad45e5fc7587
Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com>
AuthorDate: Mon May 13 14:22:25 2024 -0400

    Implement the hash_words TFT operation (#31249)

    * Implement the hash-words TFT operation

    * linting

    * Update sdks/python/apache_beam/ml/transforms/tft.py

    Co-authored-by: tvalentyn <tvalen...@users.noreply.github.com>

    * tighten key type hint

    * formatting

    ---------

    Co-authored-by: tvalentyn <tvalen...@users.noreply.github.com>
---
 sdks/python/apache_beam/ml/transforms/tft.py      | 44 +++++++++++++
 sdks/python/apache_beam/ml/transforms/tft_test.py | 77 +++++++++++++++++++++++
 2 files changed, 121 insertions(+)

diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 9b02cf8b75c..550dbedbc7b 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -637,3 +637,47 @@ class BagOfWords(TFTOperation):
 def count_unique_words(
     data: tf.SparseTensor, output_vocab_name: Optional[str]) -> None:
   tft.count_per_key(data, key_vocabulary_filename=output_vocab_name)
+
+
+@register_input_dtype(str)
+class HashStrings(TFTOperation):
+  def __init__(
+      self,
+      columns: List[str],
+      hash_buckets: int,
+      key: Optional[Tuple[int, int]] = None,
+      name: Optional[str] = None):
+    '''Hashes strings into the provided number of buckets.
+
+    Args:
+      columns: A list of the column names to apply the transformation on.
+      hash_buckets: the number of buckets to hash the strings into.
+      key: optional. An array of two Python `uint64`. If passed, output will be
+        a deterministic function of `strings` and `key`. Note that hashing will
+        be slower if this value is specified.
+      name: optional. A name for this operation.
+
+    Raises:
+      ValueError if `hash_buckets` is not a positive and non-zero integer.
+    '''
+    self.hash_buckets = hash_buckets
+    self.key = key
+    self.name = name
+
+    if hash_buckets < 1:
+      raise ValueError(
+          'number of hash buckets must be positive, got ', hash_buckets)
+
+    super().__init__(columns)
+
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_col_name: str) -> Dict[str, common_types.TensorType]:
+    output_dict = {
+        output_col_name: tft.hash_strings(
+            strings=data,
+            hash_buckets=self.hash_buckets,
+            key=self.key,
+            name=self.name)
+    }
+    return output_dict
diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py
index ed7c301f3b8..6763032a8eb 100644
--- a/sdks/python/apache_beam/ml/transforms/tft_test.py
+++ b/sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -932,5 +932,82 @@ class BagOfWordsTest(unittest.TestCase):
     self.assertEqual(expected_data, actual_data)
 
 
+class HashWordsTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.artifact_location = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.artifact_location)
+
+  def test_single_bucket(self):
+    list_data = [{'x': 'this is a test string'}]
+    expected_values = [np.array([0])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(list_data)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.HashStrings(columns=['x'], hash_buckets=1)))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
+
+  def test_multi_bucket_one_string(self):
+    list_data = [{'x': 'this is a test string'}]
+    expected_values = [np.array([1])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(list_data)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.HashStrings(columns=['x'], hash_buckets=2)))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
+
+  def test_one_bucket_multi_string(self):
+    list_data = [{'x': ['these', 'are', 'test', 'strings']}]
+    expected_values = [np.array([0, 0, 0, 0])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(list_data)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.HashStrings(columns=['x'], hash_buckets=1)))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
+
+  def test_two_bucket_multi_string(self):
+    list_data = [{'x': ['these', 'are', 'test', 'strings']}]
+    expected_values = [np.array([1, 0, 1, 0])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(list_data)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.HashStrings(
+                      columns=['x'], hash_buckets=2, key=[123, 456])))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
+
+  def test_multi_buckets_multi_string(self):
+    # This is a recreation of one of the TFT test cases from
+    # https://github.com/tensorflow/transform/blob/master/tensorflow_transform/mappers_test.py
+    list_data = [{'x': ['Cake', 'Pie', 'Sundae']}]
+    expected_values = [np.array([6, 5, 6])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(list_data)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.HashStrings(
+                      columns=['x'], hash_buckets=11, key=[123, 456])))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
+
+
 if __name__ == '__main__':
   unittest.main()