gemini-code-assist[bot] commented on code in PR #35756:
URL: https://github.com/apache/beam/pull/35756#discussion_r2246028438
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
# Get all columns for this item
for col in columns:
- result.append(item[col])
+ if isinstance(item, dict):
+ result.append(item[col])
return result
def _dict_output_fn(
columns: Sequence[str],
- batch: Sequence[Dict[str, Any]],
- embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+ batch: Sequence[Union[Dict[str, Any], beam.Row]],
+ embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
"""Map embeddings back to columns in batch."""
+ is_beam_row = False
+ if batch and hasattr(batch[0], '_asdict'):
+ is_beam_row = True
+ batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
Review Comment:

Similar to `_dict_input_fn`, the `if hasattr(row, '_asdict')` filter in this
list comprehension silently drops elements from the batch. This can lead to
data loss and misalign the remaining items with their embeddings. Assuming the
batch is homogeneous, this condition should be removed.
```python
batch = [row._asdict() for row in batch]
```
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -183,20 +183,23 @@ def append_transform(self, transform: BaseOperation):
"""
-def _dict_input_fn(columns: Sequence[str],
- batch: Sequence[Dict[str, Any]]) -> List[str]:
+def _dict_input_fn(
+ columns: Sequence[str], batch: Sequence[Union[Dict[str, Any],
+ beam.Row]]) -> List[str]:
"""Extract text from specified columns in batch."""
+ if batch and hasattr(batch[0], '_asdict'):
+ batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
Review Comment:

The `if hasattr(row, '_asdict')` filter in this list comprehension drops
elements that lack an `_asdict` method, which can lead to silent data loss if
the batch is heterogeneous. Assuming the batch is homogeneous (all `beam.Row`s
when the first element is one), this condition is unnecessary and potentially
harmful. It's better to raise an error on a heterogeneous batch than to
silently drop elements.
```python
batch = [row._asdict() for row in batch]
```
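As a sketch of the stricter behavior described above (the helper name
`_rows_to_dicts` is illustrative, not part of the existing code), the
conversion could validate the batch and fail loudly instead of filtering:
```python
from typing import Any, Dict, List, Sequence


def _rows_to_dicts(batch: Sequence[Any]) -> List[Dict[str, Any]]:
  """Convert a batch of beam.Row objects to dicts, rejecting mixed batches."""
  converted = []
  for row in batch:
    if not hasattr(row, '_asdict'):
      raise TypeError(
          'Expected a homogeneous batch of beam.Row objects, got '
          f'{type(row)} instead.')
    converted.append(row._asdict())
  return converted
```
`_dict_input_fn` and `_dict_output_fn` could then call
`batch = _rows_to_dicts(batch)` whenever the first element has an `_asdict`
method.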
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -29,14 +32,37 @@
from apache_beam.yaml import options
from apache_beam.yaml.yaml_utils import SafeLineLoader
+
+def _list_submodules(package):
+ """
+ Lists all submodules within a given package.
+ """
Review Comment:

The docstring uses a multi-line format for a one-line summary, with extra
indentation. It can be collapsed to a single line for conciseness.
```python
"""Lists all submodules within a given package."""
```
##########
sdks/python/apache_beam/yaml/yaml_ml_test.py:
##########
@@ -86,6 +86,38 @@ def test_ml_transform(self):
equal_to([5]),
label='CheckVocab')
+ def test_sentence_transformer_embedding(self):
+ SENTENCE_EMBEDDING_DIMENSION = 384
+ DATA = [{
+ 'id': 1, 'log_message': "Error in module A"
+ }, {
+ 'id': 2, 'log_message': "Warning in module B"
+ }, {
+ 'id': 3, 'log_message': "Info in module C"
+ }]
Review Comment:

The changes in `base.py` add support for `beam.Row` inputs to embedding
transforms. To ensure this new capability is tested, consider using `beam.Row`
objects for the test data instead of dictionaries.
```python
DATA = [
beam.Row(id=1, log_message="Error in module A"),
beam.Row(id=2, log_message="Warning in module B"),
beam.Row(id=3, log_message="Info in module C"),
]
```
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
# Get all columns for this item
for col in columns:
- result.append(item[col])
+ if isinstance(item, dict):
+ result.append(item[col])
return result
def _dict_output_fn(
columns: Sequence[str],
- batch: Sequence[Dict[str, Any]],
- embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+ batch: Sequence[Union[Dict[str, Any], beam.Row]],
+ embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
"""Map embeddings back to columns in batch."""
+ is_beam_row = False
+ if batch and hasattr(batch[0], '_asdict'):
+ is_beam_row = True
+ batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
result = []
for batch_idx, item in enumerate(batch):
for col_idx, col in enumerate(columns):
embedding_idx = batch_idx * len(columns) + col_idx
- item[col] = embeddings[embedding_idx]
+ if isinstance(item, dict):
+ item[col] = embeddings[embedding_idx]
Review Comment:

After converting `beam.Row` objects to `dict`s, all items in the batch
should be dictionaries. This defensive check adds complexity. If the conversion
at line 231 is made robust, this check is not needed.
```python
item[col] = embeddings[embedding_idx]
```
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -183,20 +183,23 @@
"""
-def _dict_input_fn(columns: Sequence[str],
- batch: Sequence[Dict[str, Any]]) -> List[str]:
+def _dict_input_fn(
+ columns: Sequence[str], batch: Sequence[Union[Dict[str, Any],
+ beam.Row]]) -> List[str]:
"""Extract text from specified columns in batch."""
+ if batch and hasattr(batch[0], '_asdict'):
+ batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
if not batch or not isinstance(batch[0], dict):
raise TypeError(
'Expected data to be dicts, got '
f'{type(batch[0])} instead.')
-
result = []
expected_keys = set(batch[0].keys())
expected_columns = set(columns)
# Process one batch item at a time
for item in batch:
- item_keys = item.keys()
+ item_keys = item.keys() if isinstance(item, dict) else set()
Review Comment:

After `beam.Row` objects are converted to `dict`s, all items in the batch
should be dictionaries. This defensive check makes the code less clear. If the
conversion at line 191 is made robust (i.e., it doesn't filter), this check
becomes unnecessary. Consider simplifying this to `item_keys = item.keys()`.
```python
item_keys = item.keys()
```
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -212,21 +215,31 @@
# Get all columns for this item
for col in columns:
- result.append(item[col])
+ if isinstance(item, dict):
+ result.append(item[col])
return result
def _dict_output_fn(
columns: Sequence[str],
- batch: Sequence[Dict[str, Any]],
- embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+ batch: Sequence[Union[Dict[str, Any], beam.Row]],
+ embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
"""Map embeddings back to columns in batch."""
+ is_beam_row = False
+ if batch and hasattr(batch[0], '_asdict'):
+ is_beam_row = True
+ batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
result = []
for batch_idx, item in enumerate(batch):
for col_idx, col in enumerate(columns):
embedding_idx = batch_idx * len(columns) + col_idx
- item[col] = embeddings[embedding_idx]
+ if isinstance(item, dict):
+ item[col] = embeddings[embedding_idx]
result.append(item)
+
+ if is_beam_row:
+ result = [beam.Row(**item) for item in result if isinstance(item, dict)]
Review Comment:

Similar to the check on line 237, this `isinstance` check is defensive. If
items in `result` are guaranteed to be dicts (which they should be after the
processing loop), this check is redundant.
```python
result = [beam.Row(**item) for item in result]
```
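Taken together, the suggestions on this function amount to something like the
sketch below (assuming a homogeneous batch; this is illustrative, not the
exact patch being proposed):
```python
from typing import Any, Dict, List, Sequence, Union

import apache_beam as beam


def _dict_output_fn(
    columns: Sequence[str],
    batch: Sequence[Union[Dict[str, Any], beam.Row]],
    embeddings: Sequence[Any]) -> List[Union[Dict[str, Any], beam.Row]]:
  """Map embeddings back to columns in batch."""
  is_beam_row = bool(batch) and hasattr(batch[0], '_asdict')
  if is_beam_row:
    # Convert every element; a heterogeneous batch fails loudly here rather
    # than being silently filtered.
    batch = [row._asdict() for row in batch]

  result = []
  for batch_idx, item in enumerate(batch):
    for col_idx, col in enumerate(columns):
      embedding_idx = batch_idx * len(columns) + col_idx
      item[col] = embeddings[embedding_idx]
    result.append(item)

  if is_beam_row:
    result = [beam.Row(**item) for item in result]
  return result
```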
##########
sdks/python/apache_beam/yaml/yaml_ml_test.py:
##########
@@ -86,6 +86,38 @@
equal_to([5]),
label='CheckVocab')
+ def test_sentence_transformer_embedding(self):
+ SENTENCE_EMBEDDING_DIMENSION = 384
+ DATA = [{
+ 'id': 1, 'log_message': "Error in module A"
+ }, {
+ 'id': 2, 'log_message': "Warning in module B"
+ }, {
+ 'id': 3, 'log_message': "Info in module C"
+ }]
+ ml_opts = beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle', yaml_experimental_features=['ML'])
+ with tempfile.TemporaryDirectory() as tempdir:
+ with beam.Pipeline(options=ml_opts) as p:
+ elements = p | beam.Create(DATA)
+ result = elements | YamlTransform(
+ f'''
+ type: MLTransform
+ config:
+ write_artifact_location: {tempdir}
+ transforms:
+ - type: SentenceTransformerEmbeddings
+ config:
+ model_name: all-MiniLM-L6-v2
+ columns: [log_message]
+ ''')
+
+ # Perform a basic check to ensure that embeddings are generated
+ # and that the dimension of those embeddings is correct.
+ actual_output = result | beam.Map(lambda x: len(x['log_message']))
Review Comment:

If the input data is changed to `beam.Row` objects as suggested previously,
this should be updated to use attribute access (`x.log_message`) instead of
dictionary key access.
```python
actual_output = result | beam.Map(lambda x: len(x.log_message))
```
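For completeness, here is how the tail of the test might look with both
suggestions applied; the `assert_that` check is illustrative, since the
original assertion is not shown in this excerpt:
```python
# Assumes the beam.Row test data suggested earlier; 'CheckEmbeddingDimension'
# is an illustrative label.
actual_output = result | beam.Map(lambda x: len(x.log_message))
assert_that(
    actual_output,
    equal_to([SENTENCE_EMBEDDING_DIMENSION] * len(DATA)),
    label='CheckEmbeddingDimension')
```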