gemini-code-assist[bot] commented on code in PR #35756:
URL: https://github.com/apache/beam/pull/35756#discussion_r2246028438


##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
 
     # Get all columns for this item
     for col in columns:
-      result.append(item[col])
+      if isinstance(item, dict):
+        result.append(item[col])
   return result
 
 
 def _dict_output_fn(
     columns: Sequence[str],
-    batch: Sequence[Dict[str, Any]],
-    embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+    batch: Sequence[Union[Dict[str, Any], beam.Row]],
+    embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
   """Map embeddings back to columns in batch."""
+  is_beam_row = False
+  if batch and hasattr(batch[0], '_asdict'):
+    is_beam_row = True
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   Similar to `_dict_input_fn`, the `if hasattr(row, '_asdict')` condition in 
this list comprehension drops every element of the batch that lacks an 
`_asdict` method. This can lead to silent data loss and incorrect results. 
Assuming the batch is homogeneous, this condition should be removed.
   
   ```python
   batch = [row._asdict() for row in batch]
   ```
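To make the failure mode concrete, here is a small plain-Python sketch. It uses `collections.namedtuple` as a stand-in for `beam.Row` (both expose `_asdict()`) so it runs without Beam installed; the batch contents are made up for illustration. A dict that slips into a Row batch is dropped by the filtering comprehension, while the unfiltered version fails loudly.

```python
from collections import namedtuple

# Stand-in for beam.Row: namedtuple instances also provide _asdict().
Row = namedtuple('Row', ['id', 'text'])

# Hypothetical heterogeneous batch: a Row followed by a plain dict.
batch = [Row(id=1, text='a'), {'id': 2, 'text': 'b'}]

# The filtering comprehension from the diff: the dict is silently dropped.
filtered = [row._asdict() for row in batch if hasattr(row, '_asdict')]
print(len(batch), len(filtered))  # 2 1

# The unfiltered comprehension raises on the same batch instead.
try:
  [row._asdict() for row in batch]
except AttributeError as e:
  print('raised:', e)
```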



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -183,20 +183,23 @@ def append_transform(self, transform: BaseOperation):
     """
 
 
-def _dict_input_fn(columns: Sequence[str],
-                   batch: Sequence[Dict[str, Any]]) -> List[str]:
+def _dict_input_fn(
+    columns: Sequence[str], batch: Sequence[Union[Dict[str, Any],
+                                                  beam.Row]]) -> List[str]:
   """Extract text from specified columns in batch."""
+  if batch and hasattr(batch[0], '_asdict'):
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The `if hasattr(row, '_asdict')` condition in this list comprehension drops 
every element of the batch that lacks an `_asdict` method. This can lead to 
silent data loss if the batch is heterogeneous. Assuming the batch is 
homogeneous (all `beam.Row`s if the first element is one), this condition is 
unnecessary and potentially harmful. It's better to raise an error on a 
heterogeneous batch than to silently drop elements.
   
   ```python
   batch = [row._asdict() for row in batch]
   ```



##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -29,14 +32,37 @@
 from apache_beam.yaml import options
 from apache_beam.yaml.yaml_utils import SafeLineLoader
 
+
+def _list_submodules(package):
+  """
+    Lists all submodules within a given package.
+    """

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The docstring body and closing quotes are indented one level deeper than 
the opening quotes. Since the summary fits on one line, a one-line docstring 
(as PEP 257 recommends for cases like this) is cleaner.
   
   ```python
       """Lists all submodules within a given package."""
   ```



##########
sdks/python/apache_beam/yaml/yaml_ml_test.py:
##########
@@ -86,6 +86,38 @@ def test_ml_transform(self):
             equal_to([5]),
             label='CheckVocab')
 
+  def test_sentence_transformer_embedding(self):
+    SENTENCE_EMBEDDING_DIMENSION = 384
+    DATA = [{
+        'id': 1, 'log_message': "Error in module A"
+    }, {
+        'id': 2, 'log_message': "Warning in module B"
+    }, {
+        'id': 3, 'log_message': "Info in module C"
+    }]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The changes in `base.py` add support for `beam.Row` inputs to embedding 
transforms. To ensure this new capability is tested, consider using `beam.Row` 
objects for the test data instead of dictionaries.
   
   ```python
       DATA = [
           beam.Row(id=1, log_message="Error in module A"),
           beam.Row(id=2, log_message="Warning in module B"),
           beam.Row(id=3, log_message="Info in module C"),
       ]
   ```



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
 
     # Get all columns for this item
     for col in columns:
-      result.append(item[col])
+      if isinstance(item, dict):
+        result.append(item[col])
   return result
 
 
 def _dict_output_fn(
     columns: Sequence[str],
-    batch: Sequence[Dict[str, Any]],
-    embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+    batch: Sequence[Union[Dict[str, Any], beam.Row]],
+    embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
   """Map embeddings back to columns in batch."""
+  is_beam_row = False
+  if batch and hasattr(batch[0], '_asdict'):
+    is_beam_row = True
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
   result = []
   for batch_idx, item in enumerate(batch):
     for col_idx, col in enumerate(columns):
       embedding_idx = batch_idx * len(columns) + col_idx
-      item[col] = embeddings[embedding_idx]
+      if isinstance(item, dict):
+        item[col] = embeddings[embedding_idx]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
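   After converting `beam.Row` objects to `dict`s, all items in the batch 
should be dictionaries. This defensive check adds complexity, and if a non-dict 
item ever does reach this loop, it silently leaves that item's column unset 
instead of failing. If the conversion at line 231 stops filtering elements, 
this check is not needed.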
   
   ```python
   item[col] = embeddings[embedding_idx]
   ```
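The index arithmetic in this loop assumes `embeddings` is flattened item by item, in the same order `_dict_input_fn` appends texts (every column of item 0, then every column of item 1, and so on). A small worked example with the `isinstance` guard removed, using made-up column names and placeholder strings in place of real embedding vectors:

```python
# Hypothetical two-column batch; after the Row-to-dict conversion every item
# is a plain dict.
columns = ['title', 'body']
batch = [{'title': 't0', 'body': 'b0'}, {'title': 't1', 'body': 'b1'}]

# Placeholder strings stand in for embedding vectors. They are flattened
# item by item, in the order _dict_input_fn appended the texts.
embeddings = ['E(t0)', 'E(b0)', 'E(t1)', 'E(b1)']

for batch_idx, item in enumerate(batch):
  for col_idx, col in enumerate(columns):
    embedding_idx = batch_idx * len(columns) + col_idx
    # No isinstance guard: a non-dict item would typically raise TypeError.
    item[col] = embeddings[embedding_idx]

print(batch)
# [{'title': 'E(t0)', 'body': 'E(b0)'}, {'title': 'E(t1)', 'body': 'E(b1)'}]
```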



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -183,20 +183,23 @@
     """
 
 
-def _dict_input_fn(columns: Sequence[str],
-                   batch: Sequence[Dict[str, Any]]) -> List[str]:
+def _dict_input_fn(
+    columns: Sequence[str], batch: Sequence[Union[Dict[str, Any],
+                                                  beam.Row]]) -> List[str]:
   """Extract text from specified columns in batch."""
+  if batch and hasattr(batch[0], '_asdict'):
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
   if not batch or not isinstance(batch[0], dict):
     raise TypeError(
         'Expected data to be dicts, got '
         f'{type(batch[0])} instead.')
-
   result = []
   expected_keys = set(batch[0].keys())
   expected_columns = set(columns)
   # Process one batch item at a time
   for item in batch:
-    item_keys = item.keys()
+    item_keys = item.keys() if isinstance(item, dict) else set()

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   After `beam.Row` objects are converted to `dict`s, all items in the batch 
should be dictionaries. This defensive check makes the code less clear. If the 
conversion at line 191 is made robust (i.e., it doesn't filter), this check 
becomes unnecessary. Consider simplifying this to `item_keys = item.keys()`.
   
   ```python
   item_keys = item.keys()
   ```



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
 
     # Get all columns for this item
     for col in columns:
-      result.append(item[col])
+      if isinstance(item, dict):
+        result.append(item[col])
   return result
 
 
 def _dict_output_fn(
     columns: Sequence[str],
-    batch: Sequence[Dict[str, Any]],
-    embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+    batch: Sequence[Union[Dict[str, Any], beam.Row]],
+    embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
   """Map embeddings back to columns in batch."""
+  is_beam_row = False
+  if batch and hasattr(batch[0], '_asdict'):
+    is_beam_row = True
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
   result = []
   for batch_idx, item in enumerate(batch):
     for col_idx, col in enumerate(columns):
       embedding_idx = batch_idx * len(columns) + col_idx
-      item[col] = embeddings[embedding_idx]
+      if isinstance(item, dict):
+        item[col] = embeddings[embedding_idx]
     result.append(item)
+
+  if is_beam_row:
+    result = [beam.Row(**item) for item in result if isinstance(item, dict)]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Similar to the check on line 237, this `isinstance` check is defensive. 
Once every row has been converted with `_asdict()`, every item in `result` is 
a dict, so the check is redundant; and if that assumption ever fails, the 
filter silently drops the item instead of raising.
   
   ```python
   result = [beam.Row(**item) for item in result]
   ```



##########
sdks/python/apache_beam/yaml/yaml_ml_test.py:
##########
@@ -86,6 +86,38 @@
             equal_to([5]),
             label='CheckVocab')
 
+  def test_sentence_transformer_embedding(self):
+    SENTENCE_EMBEDDING_DIMENSION = 384
+    DATA = [{
+        'id': 1, 'log_message': "Error in module A"
+    }, {
+        'id': 2, 'log_message': "Warning in module B"
+    }, {
+        'id': 3, 'log_message': "Info in module C"
+    }]
+    ml_opts = beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['ML'])
+    with tempfile.TemporaryDirectory() as tempdir:
+      with beam.Pipeline(options=ml_opts) as p:
+        elements = p | beam.Create(DATA)
+        result = elements | YamlTransform(
+            f'''
+            type: MLTransform
+            config:
+              write_artifact_location: {tempdir}
+              transforms:
+                - type: SentenceTransformerEmbeddings
+                  config:
+                    model_name: all-MiniLM-L6-v2
+                    columns: [log_message]
+            ''')
+
+        # Perform a basic check to ensure that embeddings are generated
+        # and that the dimension of those embeddings is correct.
+        actual_output = result | beam.Map(lambda x: len(x['log_message']))

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   If the input data is changed to `beam.Row` objects as suggested previously, 
this should be updated to use attribute access (`x.log_message`) instead of 
dictionary key access.
   
   ```python
           actual_output = result | beam.Map(lambda x: len(x.log_message))
   ```
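The difference between the two access styles can be checked in plain Python. This sketch uses `collections.namedtuple` as a hypothetical stand-in for the `beam.Row` elements the transform would emit for Row input, with a zero vector of the test's `SENTENCE_EMBEDDING_DIMENSION` (384) in place of a real embedding.

```python
from collections import namedtuple

# Stand-in for beam.Row: fields are read as attributes.
Row = namedtuple('Row', ['id', 'log_message'])

# Hypothetical output element whose log_message holds a 384-dim embedding.
row = Row(id=1, log_message=[0.0] * 384)

print(len(row.log_message))  # 384

# Dict-style key access does not work on the stand-in.
try:
  row['log_message']
except TypeError as e:
  print('key access failed:', e)
```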



-- 
This is an automated message from the Apache Git Service.