AnandInguva commented on code in PR #27544:
URL: https://github.com/apache/beam/pull/27544#discussion_r1270570970
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -365,26 +446,38 @@ def process_data(
raw_data_metadata = metadata_io.read_metadata(
os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ keyed_raw_data = (raw_data | beam.ParDo(ComputeAndAttachHashKey()))
+
+ feature_set = [feature.name for feature in
raw_data_metadata.schema.feature]
+ columns_not_in_schema_with_hash = (
+ keyed_raw_data
+ | beam.ParDo(GetMissingColumnsPColl(feature_set)))
+
# To maintain consistency by outputting numpy array all the time,
# whether a scalar value or list or np array is passed as input,
# we will convert scalar values to list values and TFT will ouput
# numpy array all the time.
- raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
+ keyed_raw_data = keyed_raw_data | beam.ParDo(
+ ConvertScalarValuesToListValues())
+
+ raw_data_list = (keyed_raw_data | beam.ParDo(MakeHashKeyAsColumn()))
with tft_beam.Context(temp_dir=self.artifact_location):
- data = (raw_data, raw_data_metadata)
+ data = (raw_data_list, raw_data_metadata)
if self.artifact_mode == ArtifactMode.PRODUCE:
transform_fn = (
data
| "AnalyzeDataset" >>
tft_beam.AnalyzeDataset(self.process_data_fn))
+ # TODO: Remove the 'hash_key' column from the transformed
+ # dataset schema.
Review Comment:
We remove the hash key from the data, but not from the transformed schema that
TFT generated and wrote to disk. We also remove this key from the
PCollection schema.
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -119,6 +121,82 @@ def expand(
return pcoll | beam.Map(lambda x: x._asdict())
class ComputeAndAttachHashKey(beam.DoFn):
  """
  Computes and attaches a hash key to the element.

  The key is the SHA-256 hex digest of every (column name, value) pair of the
  input dict, taken in the dict's iteration order. Each name and value is
  length-prefixed before it is hashed, so field boundaries are unambiguous.
  Without that, rows such as {'a': '12', 'b': '3'} and {'a': '1', 'b': '23'}
  would hash to the same key.

  NOTE(review): rows with identical contents still receive identical keys.
  Callers that join on this key should confirm that duplicates are acceptable.

  Only for internal use. No backwards compatibility guarantees.
  """
  @staticmethod
  def _update_with_field(hash_object, text):
    """Feeds ``text`` into ``hash_object`` as a length-prefixed UTF-8 field."""
    data = text.encode('utf-8')
    hash_object.update(str(len(data)).encode('utf-8'))
    hash_object.update(b':')
    hash_object.update(data)

  def process(self, element):
    hash_object = hashlib.sha256()
    for key, value in element.items():
      # Include the column name so equal values under different columns
      # contribute differently to the digest.
      self._update_with_field(hash_object, str(key))
      if isinstance(value, np.ndarray):
        # tolist() yields plain (nested) Python values. str() of those is
        # never summarized with "...", unlike numpy's repr of large arrays.
        value = value.tolist()
      # Lists and other primitives are hashed via their str() form.
      self._update_with_field(hash_object, str(value))
    yield (hash_object.hexdigest(), element)
+
+
class GetMissingColumnsPColl(beam.DoFn):
  """
  Returns data containing only the columns that are not
  present in the schema. This is needed since TFT only outputs
  columns that are transformed by any of the data processing transforms.

  Input elements are ``(hash_key, row)`` tuples, where ``row`` is a dict of
  column name to value. The output is ``(hash_key, dict)`` containing only the
  row's columns whose names are not in ``existing_columns``.

  Only for internal use. No backwards compatibility guarantees.
  """
  def __init__(self, existing_columns):
    # Kept as given for any code that reads this attribute.
    self.existing_columns = existing_columns
    # Built once so each per-column membership test is O(1), instead of a
    # list scan for every column of every element.
    self._existing_columns_set = frozenset(existing_columns)

  def process(self, element):
    hash_key, row = element
    missing_columns = {
        key: value
        for key, value in row.items()
        if key not in self._existing_columns_set
    }
    yield (hash_key, missing_columns)
+
+
class MakeHashKeyAsColumn(beam.DoFn):
  """
  Extracts the hash key from the element and adds it as a column.

  Input elements are ``(hash_key, row)`` tuples, where ``row`` is a dict. The
  output is a new dict holding all of the row's columns plus a ``'hash_key'``
  column set to ``hash_key``.

  Only for internal use. No backwards compatibility guarantees.
  """
  def process(self, element):
    hash_key, row = element
    # Build a new dict instead of writing into ``row``. Beam requires that a
    # DoFn does not modify its input elements, because they may be shared
    # with other consumers of the same PCollection.
    # NOTE(review): an existing 'hash_key' column in the row is overwritten.
    yield {**row, 'hash_key': hash_key}
+
+
+class ExtractHashAndKeyPColl(beam.DoFn):
+ """
+ Extracts the hash key and return hashkey and element as a tuple.
+
+ Only for internal use. No backwards compatibility guarantees.
+ """
+ def process(self, element):
+ hashkey = element['hash_key']
+ if isinstance(hashkey, np.ndarray):
Review Comment:
Removed it
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]