robertwb commented on code in PR #26795:
URL: https://github.com/apache/beam/pull/26795#discussion_r1224839602
########## sdks/python/apache_beam/ml/transforms/tft_transforms.py: ########## @@ -0,0 +1,416 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module defines a set of data processing transforms that can be used +to perform common data transformations on a dataset. These transforms are +implemented using the TensorFlow Transform (TFT) library. The transforms +in this module are intended to be used in conjunction with the +beam.ml.MLTransform class, which provides a convenient interface for +applying a sequence of data processing transforms to a dataset with the +help of the ProcessHandler class. + +See the documentation for beam.ml.MLTransform for more details. + +Since the transforms in this module are implemented using TFT, they +should be wrapped inside a TFTProcessHandler object before being passed +to the beam.ml.MLTransform class. The ProcessHandler will let MLTransform +know which type of input is expected and infers the relevant schema required +for the TFT library. + +Note: The data processing transforms defined in this module don't +perform the transformation immediately. Instead, it returns a +configured operation object, which encapsulates the details of the +transformation. The actual computation takes place later in the Apache Beam +pipeline, after all transformations are set up and the pipeline is run. +""" + +import logging +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +from apache_beam.ml.transforms.base import BaseOperation +import tensorflow as tf +import tensorflow_transform as tft +from tensorflow_transform import analyzers +from tensorflow_transform import common_types +from tensorflow_transform import tf_utils + +__all__ = [ + 'ComputeAndApplyVocabulary', + 'Scale_To_ZScore', + 'Scale_To_0_1', + 'ApplyBuckets', + 'Bucketize' +] + + +class TFTOperation(BaseOperation): + def __init__(self, columns: List[str], **kwargs): + """ + Base Opertation class for all the TFT operations. + """ + self.columns = columns + self._kwargs = kwargs + + if not columns: + raise RuntimeError( + "Columns are not specified. 
Please specify the column for the " + " op %s" % self) + + def validate_args(self): + raise NotImplementedError + + def get_artifacts(self, data: common_types.TensorType, + col_name) -> Optional[Dict[str, tf.Tensor]]: + return None + + +class ComputeAndApplyVocabulary(TFTOperation): + def __init__( + self, + columns: List[str], + *, + default_value: Any = -1, + top_k: Optional[int] = None, + frequency_threshold: Optional[int] = None, + num_oov_buckets: int = 0, + vocab_filename: Optional[str] = None, + name: Optional[str] = None, + **kwargs): + """ + This function computes the vocabulary for the given columns of incoming + data. The transformation converts the input values to indices of the + vocabulary. + + Args: + columns: List of column names to apply the transformation. + default_value: (Optional) The value to use for out-of-vocabulary values. + top_k: (Optional) The number of most frequent tokens to keep. + frequency_threshold: (Optional) Limit the generated vocabulary only to + elements whose absolute frequency is >= to the supplied threshold. + If set to None, the full vocabulary is generated. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + vocab_filename: The file name for the vocabulary file. If None, + a name based on the scope name in the context of this graph will + be used as the file name. If not None, should be unique within + a given preprocessing function. + NOTE in order to make your pipelines resilient to implementation + details please set `vocab_filename` when you are using + the vocab_filename on a downstream component. + """ + super().__init__(columns, **kwargs) + self._default_value = default_value + self._top_k = top_k + self._frequency_threshold = frequency_threshold + self._num_oov_buckets = num_oov_buckets + self._vocab_filename = vocab_filename + self._name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + # TODO: Pending outputting artifact. + return { + output_column_name: tft.compute_and_apply_vocabulary( + x=data, **self._kwargs) + } + + def __str__(self): + return "compute_and_apply_vocabulary" + + +class Scale_To_ZScore(TFTOperation): + def __init__( + self, + columns: List[str], + *, + elementwise: bool = False, + name: Optional[str] = None, + **kwargs): + """ + This function performs a scaling transformation on the specified columns of + the incoming data. It processes the input data such that it's normalized + to have a mean of 0 and a variance of 1. The transformation achieves this + by subtracting the mean from the input data and then dividing it by the + square root of the variance. + + Args: + columns: A list of column names to apply the transformation on. + elementwise: If True, the transformation is applied elementwise. + Otherwise, the transformation is applied on the entire column. + name: A name for the operation (optional). + + scale_to_z_score also outputs additional artifacts. The artifacts are + mean, which is the mean value in the column, and var, which is the + variance in the column. The artifacts are stored in the column + named with the suffix <original_col_name>_mean and <original_col_name>_var + respectively. 
+ """ + super().__init__(columns, **kwargs) + self.elementwise = elementwise + self.name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + artifacts = self.get_artifacts(data, output_column_name) + output = {output_column_name: tft.scale_to_z_score(x=data, **self._kwargs)} + output_dict = {output_column_name: output} + if artifacts is not None: + output_dict.update(artifacts) + return output_dict + + def get_artifacts(self, data: common_types.TensorType, + col_name: str) -> Dict[str, tf.Tensor]: + mean_var = tft.analyzers._mean_and_var(data) + shape = [tf.shape(data)[0], 1] + return { + col_name + '_mean': tf.broadcast_to(mean_var[0], shape), + col_name + '_var': tf.broadcast_to(mean_var[1], shape), + } + + def __str__(self): + return "scale_to_z_score" + + +class Scale_To_0_1(TFTOperation): + def __init__( + self, + columns: List[str], + elementwise: bool = False, + name: Optional[str] = None, + **kwargs): + """ + This function applies a scaling transformation on the given columns + of incoming data. The transformation scales the input values to the + range [0, 1] by dividing each value by the maximum value in the + column. + + Args: + columns: A list of column names to apply the transformation on. + elementwise: If True, the transformation is applied elementwise. Review Comment: What does "elementwise" mean? ########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransformOutput +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import _TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerDict', +] + +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. 
+_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} + +tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_dict_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.asdict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +# TODO: Add metrics namespace. +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + input_types: Optional[Dict[str, type]] = None, + output_record_batches=False, + transforms: List[_TFTOperation] = None, + namespace: str = 'TFTProcessHandler', + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + + Args: + input_types: A dictionary of column names and types. + output_record_batches: Whether to output RecordBatches instead of + dictionaries. + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. + namespace: A metrics namespace for the TFTProcessHandler. + """ + super().__init__() + self._input_types = input_types + self.transforms = transforms if transforms else [] + self._input_types = input_types + self._output_record_batches = output_record_batches + self._artifact_location = None + self._namespace = namespace + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. 
+ typ = native_type_compatibility.convert_builtin_to_typing(typ) + containers_type = (List._name, Tuple._name) + is_container = hasattr(typ, '_name') and typ._name in containers_type + + if is_container: + dtype = typing.get_args(typ)[0] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: + raise RuntimeError( + f"Incorrect type specifications in {typ} for column {col_name}. " + f"Please specify a single type.") + if dtype not in _default_type_to_tensor_type_map: + raise TypeError( + f"Unable to identify type: {dtype} specified on column: {col_name}" + f". Please specify a valid type.") + else: + dtype = typ + + is_container = is_container or issubclass(dtype, np.generic) + if is_container: + return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) + else: + return tf.io.FixedLenFeature([], _default_type_to_tensor_type_map[dtype]) + + def get_metadata(self, input_types: Dict[str, type]): + """ + Return metadata to be used with tft_beam.AnalyzeAndTransformDataset + Args: + input_types: A dictionary of column names and types. + """ + raise NotImplementedError + + def write_transform_artifacts(self, transform_fn, location): + """ + Write transform artifacts to the given location. + Args: + transform_fn: A transform_fn object. + location: A location to write the artifacts. + Returns: + A PCollection of WriteTransformFn writing a TF transform graph. + """ + return ( + transform_fn + | 'Write Transform Artifacts' >> + transform_fn_io.WriteTransformFn(location)) + + def infer_output_type(self, input_type): + if not isinstance(input_type, RowTypeConstraint): + row_type = RowTypeConstraint.from_user_type(input_type) + fields = row_type._inner_types() + return Dict[str, Union[tuple(fields)]] + + def _get_artifact_location(self, pipeline: beam.Pipeline): Review Comment: I don't think the staging/temp directory makes sense as a default--if there are artifacts to be produced/consumed, this should be a required argument. ########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,431 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import collections +import logging +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.transforms.window import GlobalWindows +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2 +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import beam_metadata_io +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerSchema', +] + +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. +_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} +_primitive_types_to_typing_container_type = { + int: List[int], float: List[float], str: List[str], bytes: List[bytes] +} + +tft_process_handler_schema_input_type = typing.Union[typing.NamedTuple, + beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_schema_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_schema_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.as_dict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + transforms: Optional[List[TFTOperation]] = None, + artifact_location: typing.Optional[str] = None, + preprocessing_fn: typing.Optional[typing.Callable] = None, + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + + Args: + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. + artifact_location: A location to store the artifacts, which includes + the tensorflow graph produced by analyzers such as scale_to_0_1, + sclaed_to_z_score, etc. + Note: If not specified, the artifacts will be stored + in a temporary directory for DirectRunner and staging location for + DataflowRunner. 
+ """ + self.transforms = transforms if transforms else [] + self.transformed_schema = None + self.artifact_location = artifact_location + self.preprocessing_fn = preprocessing_fn + + def append_transform(self, transform): + self.transforms.append(transform) + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + primitive_containers_type = ( + list, + collections.abc.Sequence, + ) + is_primitive_container = ( + typing.get_origin(typ) in primitive_containers_type) + + if is_primitive_container: + dtype = typing.get_args(typ)[0] # type: ignore[attr-defined] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: # type: ignore[attr-defined] + raise RuntimeError( + f"Union type is not supported for column: {col_name}. " + f"Please pass a PCollection with valid schema for column " + f"{col_name} by passing a single type " + "in container. For example, List[int].") + elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: + dtype = typ + else: + raise TypeError( + f"Unable to identify type: {typ} specified on column: {col_name}. " + f"Please provide a valid type from the following: " + f"{_default_type_to_tensor_type_map.keys()}") + return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) + + def get_raw_data_metadata(self, input_types: Dict[str, type]): + """ + Return metadata to be used with tft_beam.AnalyzeAndTransformDataset + Args: + input_types: A dictionary of column names and types. + """ + raise NotImplementedError + + def write_transform_artifacts(self, transform_fn, location): + """ + Write transform artifacts to the given location. + Args: + transform_fn: A transform_fn object. + location: A location to write the artifacts. + Returns: + A PCollection of WriteTransformFn writing a TF transform graph. + """ + return ( + transform_fn + | 'Write Transform Artifacts' >> + transform_fn_io.WriteTransformFn(location)) + + def _get_artifact_location(self, pipeline: beam.Pipeline): + """ + Return the artifact location. If the pipeline options has staging location + set, then we will use that as the artifact location. Otherwise, we will + create a temporary directory and use that as the artifact location. + Args: + pipeline: A beam pipeline object. + Returns: + A location to write the artifacts. + """ + # let us get the staging location from the pipeline options + # and initialize it as the artifact location. 
+ staging_location = pipeline.options.view_as( + GoogleCloudOptions).staging_location + if not staging_location: + return tempfile.mkdtemp() + else: + return staging_location + + def process_data_fn( + self, inputs: Dict[str, common_types.ConsistentTensorType] + ) -> Dict[str, common_types.ConsistentTensorType]: + """ + A preprocessing_fn which should be implemented by subclasses + of TFTProcessHandlers. In this method, tft data transforms + such as scale_0_to_1 functions are called. + Args: + inputs: A dictionary of column names and associated data. + """ + raise NotImplementedError + + def _fail_on_non_gloabl_window(self, pcoll): + window_fn = pcoll.windowing.windowfn + if not isinstance(window_fn, GlobalWindows): Review Comment: Do we support non-trivial triggering? If not (e.g. we want to do a global combine) we should use `pcoll.windowing.is_default()`. ########## sdks/python/apache_beam/ml/transforms/tft_transforms.py: ########## @@ -0,0 +1,416 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module defines a set of data processing transforms that can be used +to perform common data transformations on a dataset. These transforms are +implemented using the TensorFlow Transform (TFT) library. The transforms +in this module are intended to be used in conjunction with the +beam.ml.MLTransform class, which provides a convenient interface for +applying a sequence of data processing transforms to a dataset with the +help of the ProcessHandler class. + +See the documentation for beam.ml.MLTransform for more details. + +Since the transforms in this module are implemented using TFT, they +should be wrapped inside a TFTProcessHandler object before being passed +to the beam.ml.MLTransform class. The ProcessHandler will let MLTransform +know which type of input is expected and infers the relevant schema required +for the TFT library. + +Note: The data processing transforms defined in this module don't +perform the transformation immediately. Instead, it returns a +configured operation object, which encapsulates the details of the +transformation. The actual computation takes place later in the Apache Beam +pipeline, after all transformations are set up and the pipeline is run. 
+""" + +import logging +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +from apache_beam.ml.transforms.base import BaseOperation +import tensorflow as tf +import tensorflow_transform as tft +from tensorflow_transform import analyzers +from tensorflow_transform import common_types +from tensorflow_transform import tf_utils + +__all__ = [ + 'ComputeAndApplyVocabulary', + 'Scale_To_ZScore', + 'Scale_To_0_1', + 'ApplyBuckets', + 'Bucketize' +] + + +class TFTOperation(BaseOperation): + def __init__(self, columns: List[str], **kwargs): + """ + Base Opertation class for all the TFT operations. + """ + self.columns = columns + self._kwargs = kwargs + + if not columns: + raise RuntimeError( + "Columns are not specified. Please specify the column for the " + " op %s" % self) + + def validate_args(self): + raise NotImplementedError + + def get_artifacts(self, data: common_types.TensorType, + col_name) -> Optional[Dict[str, tf.Tensor]]: + return None + + +class ComputeAndApplyVocabulary(TFTOperation): + def __init__( + self, + columns: List[str], + *, + default_value: Any = -1, + top_k: Optional[int] = None, + frequency_threshold: Optional[int] = None, + num_oov_buckets: int = 0, + vocab_filename: Optional[str] = None, + name: Optional[str] = None, + **kwargs): + """ + This function computes the vocabulary for the given columns of incoming + data. The transformation converts the input values to indices of the + vocabulary. + + Args: + columns: List of column names to apply the transformation. + default_value: (Optional) The value to use for out-of-vocabulary values. + top_k: (Optional) The number of most frequent tokens to keep. + frequency_threshold: (Optional) Limit the generated vocabulary only to + elements whose absolute frequency is >= to the supplied threshold. + If set to None, the full vocabulary is generated. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + vocab_filename: The file name for the vocabulary file. If None, + a name based on the scope name in the context of this graph will + be used as the file name. If not None, should be unique within + a given preprocessing function. + NOTE in order to make your pipelines resilient to implementation + details please set `vocab_filename` when you are using + the vocab_filename on a downstream component. + """ + super().__init__(columns, **kwargs) + self._default_value = default_value + self._top_k = top_k + self._frequency_threshold = frequency_threshold + self._num_oov_buckets = num_oov_buckets + self._vocab_filename = vocab_filename + self._name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + # TODO: Pending outputting artifact. + return { + output_column_name: tft.compute_and_apply_vocabulary( + x=data, **self._kwargs) + } + + def __str__(self): + return "compute_and_apply_vocabulary" + + +class Scale_To_ZScore(TFTOperation): Review Comment: Class names should generally be CamelCase, not Camel_And_SnakeCase. ########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,431 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import collections +import logging +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.transforms.window import GlobalWindows +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2 +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import beam_metadata_io +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerSchema', +] + +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. +_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} +_primitive_types_to_typing_container_type = { + int: List[int], float: List[float], str: List[str], bytes: List[bytes] +} + +tft_process_handler_schema_input_type = typing.Union[typing.NamedTuple, + beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_schema_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_schema_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.as_dict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + transforms: Optional[List[TFTOperation]] = None, + artifact_location: typing.Optional[str] = None, + preprocessing_fn: typing.Optional[typing.Callable] = None, + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. 
+ + Args: + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. + artifact_location: A location to store the artifacts, which includes + the tensorflow graph produced by analyzers such as scale_to_0_1, + sclaed_to_z_score, etc. + Note: If not specified, the artifacts will be stored + in a temporary directory for DirectRunner and staging location for + DataflowRunner. + """ + self.transforms = transforms if transforms else [] + self.transformed_schema = None + self.artifact_location = artifact_location + self.preprocessing_fn = preprocessing_fn + + def append_transform(self, transform): + self.transforms.append(transform) + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + primitive_containers_type = ( + list, + collections.abc.Sequence, + ) + is_primitive_container = ( + typing.get_origin(typ) in primitive_containers_type) + + if is_primitive_container: + dtype = typing.get_args(typ)[0] # type: ignore[attr-defined] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: # type: ignore[attr-defined] + raise RuntimeError( + f"Union type is not supported for column: {col_name}. " + f"Please pass a PCollection with valid schema for column " + f"{col_name} by passing a single type " + "in container. For example, List[int].") + elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: + dtype = typ + else: + raise TypeError( + f"Unable to identify type: {typ} specified on column: {col_name}. " + f"Please provide a valid type from the following: " + f"{_default_type_to_tensor_type_map.keys()}") + return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) + + def get_raw_data_metadata(self, input_types: Dict[str, type]): + """ + Return metadata to be used with tft_beam.AnalyzeAndTransformDataset + Args: + input_types: A dictionary of column names and types. + """ + raise NotImplementedError + + def write_transform_artifacts(self, transform_fn, location): + """ + Write transform artifacts to the given location. + Args: + transform_fn: A transform_fn object. + location: A location to write the artifacts. + Returns: + A PCollection of WriteTransformFn writing a TF transform graph. + """ + return ( + transform_fn + | 'Write Transform Artifacts' >> + transform_fn_io.WriteTransformFn(location)) + + def _get_artifact_location(self, pipeline: beam.Pipeline): + """ + Return the artifact location. 
If the pipeline options has staging location + set, then we will use that as the artifact location. Otherwise, we will + create a temporary directory and use that as the artifact location. + Args: + pipeline: A beam pipeline object. + Returns: + A location to write the artifacts. + """ + # let us get the staging location from the pipeline options + # and initialize it as the artifact location. + staging_location = pipeline.options.view_as( + GoogleCloudOptions).staging_location + if not staging_location: + return tempfile.mkdtemp() + else: + return staging_location + + def process_data_fn( + self, inputs: Dict[str, common_types.ConsistentTensorType] + ) -> Dict[str, common_types.ConsistentTensorType]: + """ + A preprocessing_fn which should be implemented by subclasses + of TFTProcessHandlers. In this method, tft data transforms + such as scale_0_to_1 functions are called. + Args: + inputs: A dictionary of column names and associated data. + """ + raise NotImplementedError + + def _fail_on_non_gloabl_window(self, pcoll): + window_fn = pcoll.windowing.windowfn + if not isinstance(window_fn, GlobalWindows): + raise RuntimeError( + "TFTProcessHandler only supports GlobalWindows. " + "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) " + "to convert your PCollection to GlobalWindow.") + + +class TFTProcessHandlerSchema( + TFTProcessHandler[tft_process_handler_schema_input_type, beam.Row]): + """ + A subclass of TFTProcessHandler specifically for handling + data in beam.Row or NamedTuple format. + TFTProcessHandlerSchema creates a beam graph that applies + TensorFlow Transform (TFT) operations to the input data and + outputs a beam.Row object containing the transformed data as numpy arrays. + + This only works on the Schema'd PCollection. Please refer to + https://beam.apache.org/documentation/programming-guide/#schemas + for more information on Schema'd PCollection. + + Currently, there are two ways to define a schema for a PCollection: + + 1) Register a `typing.NamedTuple` type to use RowCoder, and specify it as + the output type. For example:: + + Purchase = typing.NamedTuple('Purchase', + [('item_name', unicode), ('price', float)]) + coders.registry.register_coder(Purchase, coders.RowCoder) + with Pipeline() as p: + purchases = (p | beam.Create(...) + | beam.Map(..).with_output_types(Purchase)) + + 2) Produce `beam.Row` instances. Note this option will fail if Beam is + unable to infer data types for any of the fields. For example:: + + with Pipeline() as p: + purchases = (p | beam.Create(...) + | beam.Map(lambda x: beam.Row(item_name=unicode(..), + price=float(..)))) + In the schema, TFTProcessHandlerSchema accepts the following types: + 1. Primitive types: int, float, str, bytes + 2. List of the primitive types. + 3. Numpy arrays. + + For any other types, TFTProcessHandler will raise a TypeError. + """ + def _map_column_names_to_types(self, element_type): + """ + Return a dictionary of column names and types. + Args: + element_type: A type of the element. This could be a NamedTuple or a Row. + Returns: + A dictionary of column names and types. 
+ """ + + if not isinstance(element_type, RowTypeConstraint): + row_type = RowTypeConstraint.from_user_type(element_type) + if not row_type: + raise TypeError( + "Element type must be compatible with Beam Schemas (" + "https://beam.apache.org/documentation/programming-guide/#schemas)" + " for to use with MLTransform and TFTProcessHandlerSchema.") + else: + row_type = element_type + inferred_types = {name: typ for name, typ in row_type._fields} + + for k, t in inferred_types.items(): + if t in _primitive_types_to_typing_container_type: + inferred_types[k] = _primitive_types_to_typing_container_type[t] + + # sometimes a numpy type can be provided as np.dtype('int64'). + # convert numpy.dtype to numpy type since both are same. + for name, typ in inferred_types.items(): + if isinstance(typ, np.dtype): + inferred_types[name] = typ.type + + return inferred_types + + def process_data_fn( + self, inputs: Dict[str, common_types.ConsistentTensorType] + ) -> Dict[str, common_types.ConsistentTensorType]: + """ + This method is used in the AnalyzeAndTransformDataset step. It applies + the transforms to the `inputs` in sequential order on the columns + provided for a given transform. + Args: + inputs: A dictionary of column names and data. + Returns: + A dictionary of column names and transformed data. + """ + outputs = inputs.copy() + for transform in self.transforms: + columns = transform.columns + for col in columns: + intermediate_result = transform.apply( + outputs[col], output_column_name=col) + for key, value in intermediate_result.items(): + outputs[key] = value + return outputs + + def _get_transformed_data_schema( + self, + metadata: dataset_metadata.DatasetMetadata, + ) -> Dict[str, typing.Sequence[typing.Union[np.float32, np.int64, bytes]]]: + schema = metadata._schema + transformed_types = {} + logging.info("Schema: %s", schema) + for feature in schema.feature: + name = feature.name + feature_type = feature.type + if feature_type == schema_pb2.FeatureType.FLOAT: + transformed_types[name] = typing.Sequence[np.float32] + elif feature_type == schema_pb2.FeatureType.INT: + transformed_types[name] = typing.Sequence[np.int64] + elif feature_type == schema_pb2.FeatureType.BYTES: + transformed_types[name] = typing.Sequence[bytes] + else: + # TODO: This else condition won't be hit since TFT doesn't output + # other than float, int and bytes. Refactor the code here. + raise RuntimeError( + 'Unsupported feature type: %s encountered' % feature_type) + logging.info(transformed_types) + return transformed_types + + def process_data( + self, pcoll: beam.PCollection[tft_process_handler_schema_input_type] + ) -> beam.PCollection[beam.Row]: + + self._fail_on_non_gloabl_window(pcoll) Review Comment: This should only be the case for artifact production. We want to support artifact consumption operations in streaming mode. ########## sdks/python/apache_beam/ml/transforms/tft_transforms_test.py: ########## @@ -0,0 +1,393 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List +from typing import NamedTuple + +import unittest +import numpy as np +from parameterized import parameterized + +import apache_beam as beam +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=wrong-import-order, wrong-import-position +try: + from apache_beam.ml.transforms import base + from apache_beam.ml.transforms import tft_transforms + from apache_beam.ml.transforms import handlers +except ImportError: + tft_transforms = None + +skip_if_tft_not_available = unittest.skipIf( + tft_transforms is None, 'tensorflow_transform is not installed.') + + +class MyTypesUnbatched(NamedTuple): + x: List[int] + + +class MyTypesBatched(NamedTuple): + x: List[int] + + +z_score_expected = {'x_mean': 3.5, 'x_var': 2.9166666666666665} + + +def assert_z_score_artifacts(element): + element = element.as_dict() + assert 'x_mean' in element + assert 'x_var' in element + assert element['x_mean'] == z_score_expected['x_mean'] + assert element['x_var'] == z_score_expected['x_var'] + + +def assert_scale_to_0_1_artifacts(element): + element = element.as_dict() + assert 'x_min' in element + assert 'x_max' in element + assert element['x_min'] == 1 + assert element['x_max'] == 6 + + +def assert_bucketize_artifacts(element): + element = element.as_dict() + assert 'x_quantiles' in element + assert np.array_equal( + element['x_quantiles'], np.array([3, 5], dtype=np.float32)) + + +@skip_if_tft_not_available +class ScaleZScoreTest(unittest.TestCase): + def test_z_score_unbatched(self): + unbatched_data = [{ + 'x': 1 + }, { + 'x': 2 + }, { + 'x': 3 + }, { + 'x': 4 + }, { + 'x': 5 + }, { + 'x': 6 + }] + + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched_data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesUnbatched) + | "unbatchedMLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Scale_To_ZScore(columns=['x']))) + _ = (unbatched_result | beam.Map(assert_z_score_artifacts)) + + def test_z_score_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched_data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesBatched) + | "batchedMLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Scale_To_ZScore(columns=['x']))) + _ = (batched_result | beam.Map(assert_z_score_artifacts)) + + +@skip_if_tft_not_available +class ScaleTo01Test(unittest.TestCase): + def test_scale_to_0_1_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched_data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesBatched) + | "batchedMLTransform" >> + 
base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Scale_To_0_1(columns=['x']))) + _ = (batched_result | beam.Map(assert_scale_to_0_1_artifacts)) + + expected_output = [ + np.array([0, 0.2, 0.4], dtype=np.float32), + np.array([0.6, 0.8, 1], dtype=np.float32) + ] + actual_output = (batched_result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + def test_scale_to_0_1_unbatched(self): + unbatched_data = [{ + 'x': 1 + }, { + 'x': 2 + }, { + 'x': 3 + }, { + 'x': 4 + }, { + 'x': 5 + }, { + 'x': 6 + }] + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched_data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesUnbatched) + | "unbatchedMLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Scale_To_0_1(columns=['x']))) + + _ = (unbatched_result | beam.Map(assert_scale_to_0_1_artifacts)) + expected_output = ( + np.array([0], dtype=np.float32), + np.array([0.2], dtype=np.float32), + np.array([0.4], dtype=np.float32), + np.array([0.6], dtype=np.float32), + np.array([0.8], dtype=np.float32), + np.array([1], dtype=np.float32)) + actual_output = (unbatched_result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + +@skip_if_tft_not_available +class BucketizeTest(unittest.TestCase): + def test_bucketize_unbatched(self): + unbatched = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesUnbatched) + | "unbatchedMLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Bucketize(columns=['x'], num_buckets=3))) + _ = (unbatched_result | beam.Map(assert_bucketize_artifacts)) + + transformed_data = (unbatched_result | beam.Map(lambda x: x.x)) + expected_data = [ + np.array([0]), + np.array([0]), + np.array([1]), + np.array([1]), + np.array([2]), + np.array([2]) + ] + assert_that( + transformed_data, equal_to(expected_data, equals_fn=np.array_equal)) + + def test_bucketize_batched(self): + batched = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesBatched) + | "batchedMLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Bucketize(columns=['x'], num_buckets=3))) + _ = (batched_result | beam.Map(assert_bucketize_artifacts)) + + transformed_data = ( + batched_result + | "TransformedColumnX" >> beam.Map(lambda ele: ele.x)) + expected_data = [ + np.array([0, 0, 1], dtype=np.int64), + np.array([1, 2, 2], dtype=np.int64) + ] + assert_that( + transformed_data, equal_to(expected_data, equals_fn=np.array_equal)) + + @parameterized.expand([ + (range(1, 10), [4, 7]), + (range(9, 0, -1), [4, 7]), + (range(19, 0, -1), [10]), + (range(1, 100), [25, 50, 75]), + # similar to the above but with odd number of elements + (range(1, 100, 2), [25, 51, 75]), + (range(99, 0, -1), range(10, 100, 10)) + ]) + def test_bucketize_boundaries(self, test_input, 
expected_boundaries): + # boundaries are outputted as artifacts for the Bucketize transform. + data = [{'x': [i]} for i in test_input] + num_buckets = len(expected_boundaries) + 1 + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + result = ( + p + | "Create" >> beam.Create(data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesUnbatched) + | "MLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.Bucketize(columns=['x'], num_buckets=num_buckets))) + actual_boundaries = ( + result + | beam.Map(lambda x: x.as_dict()) + | beam.Map(lambda x: x['x_quantiles'])) + + def assert_boundaries(actual_boundaries): + assert np.array_equal(actual_boundaries, expected_boundaries) + + _ = (actual_boundaries | beam.Map(assert_boundaries)) + + +@skip_if_tft_not_available +class ApplyBucketsTest(unittest.TestCase): + @parameterized.expand([ + (range(1, 100), [25, 50, 75]), + (range(1, 100, 2), [25, 51, 75]), + ]) + def test_apply_buckets(self, test_inputs, bucket_boundaries): + with beam.Pipeline() as p: + data = [{'x': [i]} for i in test_inputs] + process_handler = handlers.TFTProcessHandlerSchema() + result = ( + p + | "Create" >> beam.Create(data) + | beam.Map(lambda x: MyTypesBatched(**x)).with_output_types( + MyTypesUnbatched) + | "MLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.ApplyBuckets( + columns=['x'], bucket_boundaries=bucket_boundaries))) + expected_output = [] + bucket = 0 + for x in sorted(test_inputs): + # Increment the bucket number when crossing the boundary + if (bucket < len(bucket_boundaries) and x >= bucket_boundaries[bucket]): + bucket += 1 + expected_output.append(np.array([bucket])) + + actual_output = (result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + +class ComputeAndVocabUnbatchedInputType(NamedTuple): + x: str + + +class ComputeAndVocabBatchedInputType(NamedTuple): + x: List[str] + + +@skip_if_tft_not_available +class ComputeAndApplyVocabTest(unittest.TestCase): + def test_compute_and_apply_vocabulary_unbatched_inputs(self): + batch_size = 100 + num_instances = batch_size + 1 + input_data = [{ + 'x': '%.10i' % i, # Front-padded to facilitate lexicographic sorting. + } for i in range(num_instances)] + + expected_data = [{ + 'x': (len(input_data) - 1) - i, # Due to reverse lexicographic sorting. + } for i in range(len(input_data))] + + with beam.Pipeline() as p: + process_handler = handlers.TFTProcessHandlerSchema() + actual_data = ( + p + | "Create" >> beam.Create(input_data) + | beam.Map(lambda x: ComputeAndVocabUnbatchedInputType(**x) + ).with_output_types(ComputeAndVocabUnbatchedInputType) + | "MLTransform" >> + base.MLTransform(process_handler=process_handler).with_transform( + tft_transforms.ComputeAndApplyVocabulary(columns=['x']))) + actual_data |= beam.Map(lambda x: x.as_dict()) + + assert_that(actual_data, equal_to(expected_data)) + + def test_compute_and_apply_vocabulary_batched(self): + batch_size = 100 + num_instances = batch_size + 1 + input_data = [ + { + 'x': ['%.10i' % i, '%.10i' % (i + 1), '%.10i' % (i + 2)], + # Front-padded to facilitate lexicographic sorting. + } for i in range(0, num_instances, 3) + ] + + # since we have 3 elements in a single batch, multiply with 3 for + # each iteration i on the expected output. 
+    excepted_data = [
+        np.array([(len(input_data) * 3 - 1) - i,
+                  (len(input_data) * 3 - 1) - i - 1,
+                  (len(input_data) * 3 - 1) - i - 2],
+                 dtype=np.int64)  # Front-padded to facilitate lexicographic
+                 # sorting.
+        for i in range(0, len(input_data) * 3, 3)
+    ]
+
+    with beam.Pipeline() as p:
+      process_handler = handlers.TFTProcessHandlerSchema()
+      result = (
+          p
+          | "Create" >> beam.Create(input_data)
+          | beam.Map(lambda x: ComputeAndVocabBatchedInputType(**x)
+                     ).with_output_types(ComputeAndVocabBatchedInputType)
+          | "MLTransform" >>
+          base.MLTransform(process_handler=process_handler).with_transform(

Review Comment:
   What does "process_handler" mean? Is it an abstraction we need to expose to the user?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
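A minimal sketch of the windowing guard suggested in the review above, assuming the check is only needed when artifacts are produced; the helper name `_fail_on_non_default_windowing` and the error wording are hypothetical, while `Windowing.is_default()`, `beam.WindowInto`, and `GlobalWindows` are existing Beam APIs already referenced in the thread:

```python
import apache_beam as beam


def _fail_on_non_default_windowing(pcoll: beam.PCollection):
  # Windowing.is_default() is False both for custom window fns and for
  # non-trivial triggering, neither of which a global analysis
  # (artifact-producing) pass can handle. Artifact consumption in
  # streaming mode would skip this check entirely.
  if not pcoll.windowing.is_default():
    raise RuntimeError(
        "Artifact production requires a globally windowed, "
        "default-triggered PCollection. Apply "
        "beam.WindowInto(beam.transforms.window.GlobalWindows()) before "
        "MLTransform, or run in artifact-consumption mode with "
        "previously computed artifacts.")
```

Scoping the check to artifact production this way would let streaming pipelines apply previously computed transforms without forcing a global window.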
