gemini-code-assist[bot] commented on code in PR #38215: URL: https://github.com/apache/beam/pull/38215#discussion_r3312714084
########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. 
Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. + return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed Review Comment:  If the input is a plain-text file containing lines that happen to be valid JSON primitives (such as `true`, `false`, `null`, or numbers like `123`), `json.loads` will successfully parse them. Since they are not dictionaries, the current implementation will raise a `ValueError` and crash the pipeline. We can make this more robust by falling back to treating the line as plain text if it does not decode to a dictionary. ```suggestion try: parsed = json.loads(line) if isinstance(parsed, dict): return parsed except json.JSONDecodeError: pass # Treat plain-text or non-object rows as values for the default "text" column. return {'text': line} ``` ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. 
+ return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path + + +def _read_vocab_tokens(vocab_asset_path: str) -> list[str]: + tokens = [] + with FileSystems.open(vocab_asset_path) as f: + for raw_line in f: + token = raw_line.decode('utf-8').rstrip('\n') + if token: + tokens.append(token) + return tokens + + +def _write_vocab_file(output_path: str, tokens: list[str]) -> None: + with FileSystems.create(output_path) as f: + for token in tokens: + f.write((token + '\n').encode('utf-8')) + + +def parse_known_args(argv): + parser = argparse.ArgumentParser( + description='Generate vocabulary from batch input with MLTransform.') + parser.add_argument('--input_file', help='Input JSONL file path.') + parser.add_argument( + 
'--input_table', + help='Input BigQuery table path in PROJECT:DATASET.TABLE format.') + parser.add_argument('--output_vocab', help='Output vocab file prefix/path.') + parser.add_argument( + '--columns', + help='Comma-separated source columns to include in vocabulary.') + parser.add_argument( + '--vocab_size', + type=int, + default=50000, + help='Maximum vocabulary size (top-K by frequency).') + parser.add_argument( + '--min_frequency', + type=int, + default=1, + help='Minimum token frequency required to keep token.') + parser.add_argument( + '--lowercase', + default='true', + help='Whether to lowercase text before vocabulary generation.') + parser.add_argument( + '--input_expand_factor', + type=int, + default=1, + help=( + 'Batch-only: repeat each input line this many times to scale volume ' + 'for load/perf testing.')) + parser.add_argument( + '--artifact_location', + default='', + help=( + 'Artifact directory for MLTransform output. If empty, a temporary ' + 'local directory is used.')) + return parser.parse_known_args(argv) + + +def validate_args(args) -> list[str]: + has_input_file = bool(args.input_file) + has_input_table = bool(args.input_table) + if not has_input_file and not has_input_table: + raise ValueError('One of --input_file or --input_table is required.') + if has_input_file and has_input_table: + raise ValueError('Use exactly one of --input_file or --input_table.') + if not args.output_vocab: + raise ValueError('--output_vocab is required.') + if not args.columns: + raise ValueError('--columns is required.') + if args.vocab_size is None or args.vocab_size <= 0: + raise ValueError('--vocab_size must be > 0.') + if args.min_frequency is None or args.min_frequency < 1: + raise ValueError('--min_frequency must be >= 1.') + if args.input_expand_factor is None or args.input_expand_factor < 1: + raise ValueError('--input_expand_factor must be >= 1.') + return [col.strip() for col in args.columns.split(',') if col.strip()] + + +def run(argv=None, 
test_pipeline=None): + known_args, pipeline_args = parse_known_args(argv) + columns = validate_args(known_args) + lowercase = parse_bool_flag(known_args.lowercase) + artifact_location = known_args.artifact_location or tempfile.mkdtemp( + prefix='mltransform_generate_vocab_artifacts_') + + options = PipelineOptions(pipeline_args) + pipeline = test_pipeline or beam.Pipeline(options=options) Review Comment:  When running on a distributed runner like `DataflowRunner`, a local temporary directory created via `tempfile.mkdtemp()` on the submission machine will not be accessible by the workers, leading to failures when writing or reading the vocabulary. If `--artifact_location` is not specified, we should fall back to a subdirectory under the pipeline's `--temp_location` if it is a remote path (e.g., GCS). ```python options = PipelineOptions(pipeline_args) artifact_location = known_args.artifact_location if not artifact_location: temp_location = options.get_all_options().get('temp_location') if temp_location and (temp_location.startswith('gs://') or temp_location.startswith('s3://')): artifact_location = f"{temp_location.rstrip('/')}/mltransform_vocab_artifacts" else: artifact_location = tempfile.mkdtemp(prefix='mltransform_generate_vocab_artifacts_') pipeline = test_pipeline or beam.Pipeline(options=options) ``` ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. 
+ return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path Review Comment:  `FileSystems.match` can raise a `BeamIOError` or other exceptions if the parent directory does not exist or is inaccessible, rather than returning an empty list. Wrapping this in a try-except block ensures a clean and informative error message is raised. 
```python try: match_results = FileSystems.match([pattern]) matches = match_results[0].metadata_list if match_results else [] except Exception as e: raise ValueError( f'Could not locate vocabulary artifact {asset_name!r} under ' f'{artifact_location!r} due to error: {e}') from e if not matches: raise ValueError( f'Could not locate vocabulary artifact {asset_name!r} under ' f'{artifact_location!r}.') return matches[0].path ``` ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. 
+""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. 
+ return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path + + +def _read_vocab_tokens(vocab_asset_path: str) -> list[str]: + tokens = [] + with FileSystems.open(vocab_asset_path) as f: + for raw_line in f: + token = raw_line.decode('utf-8').rstrip('\n') Review Comment:  Using `rstrip('\n')` on a decoded string can leave trailing carriage returns (`\r`) if the file was written with Windows-style line endings. It is safer to use `rstrip('\r\n')`. 
```suggestion token = raw_line.decode('utf-8').rstrip('\r\n') ``` ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. 
Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. + return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path + + +def _read_vocab_tokens(vocab_asset_path: str) -> list[str]: + tokens = [] + with FileSystems.open(vocab_asset_path) as f: + for raw_line in f: + token = raw_line.decode('utf-8').rstrip('\n') + if token: + tokens.append(token) + return tokens + + +def 
_write_vocab_file(output_path: str, tokens: list[str]) -> None: + with FileSystems.create(output_path) as f: + for token in tokens: + f.write((token + '\n').encode('utf-8')) + + +def parse_known_args(argv): + parser = argparse.ArgumentParser( + description='Generate vocabulary from batch input with MLTransform.') + parser.add_argument('--input_file', help='Input JSONL file path.') + parser.add_argument( + '--input_table', + help='Input BigQuery table path in PROJECT:DATASET.TABLE format.') + parser.add_argument('--output_vocab', help='Output vocab file prefix/path.') + parser.add_argument( + '--columns', + help='Comma-separated source columns to include in vocabulary.') + parser.add_argument( + '--vocab_size', + type=int, + default=50000, + help='Maximum vocabulary size (top-K by frequency).') + parser.add_argument( + '--min_frequency', + type=int, + default=1, + help='Minimum token frequency required to keep token.') + parser.add_argument( + '--lowercase', + default='true', + help='Whether to lowercase text before vocabulary generation.') + parser.add_argument( + '--input_expand_factor', + type=int, + default=1, + help=( + 'Batch-only: repeat each input line this many times to scale volume ' + 'for load/perf testing.')) + parser.add_argument( + '--artifact_location', + default='', + help=( + 'Artifact directory for MLTransform output. 
If empty, a temporary ' + 'local directory is used.')) + return parser.parse_known_args(argv) + + +def validate_args(args) -> list[str]: + has_input_file = bool(args.input_file) + has_input_table = bool(args.input_table) + if not has_input_file and not has_input_table: + raise ValueError('One of --input_file or --input_table is required.') + if has_input_file and has_input_table: + raise ValueError('Use exactly one of --input_file or --input_table.') + if not args.output_vocab: + raise ValueError('--output_vocab is required.') + if not args.columns: + raise ValueError('--columns is required.') + if args.vocab_size is None or args.vocab_size <= 0: + raise ValueError('--vocab_size must be > 0.') + if args.min_frequency is None or args.min_frequency < 1: + raise ValueError('--min_frequency must be >= 1.') + if args.input_expand_factor is None or args.input_expand_factor < 1: + raise ValueError('--input_expand_factor must be >= 1.') + return [col.strip() for col in args.columns.split(',') if col.strip()] Review Comment:  If duplicate columns are specified in the `--columns` argument (e.g., `--columns=text,text`), they will be processed multiple times. We can deduplicate the columns while preserving their original order. ```suggestion seen = set() deduped_cols = [] for col in (c.strip() for c in args.columns.split(',')): if col and col not in seen: seen.add(col) deduped_cols.append(col) return deduped_cols ``` ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. 
+ return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path + + +def _read_vocab_tokens(vocab_asset_path: str) -> list[str]: + tokens = [] + with FileSystems.open(vocab_asset_path) as f: + for raw_line in f: + token = raw_line.decode('utf-8').rstrip('\n') + if token: + tokens.append(token) + return tokens + + +def _write_vocab_file(output_path: str, tokens: list[str]) -> None: + with FileSystems.create(output_path) as f: + for token in tokens: + f.write((token + '\n').encode('utf-8')) + + +def parse_known_args(argv): + parser = argparse.ArgumentParser( + description='Generate vocabulary from batch input with MLTransform.') + parser.add_argument('--input_file', help='Input JSONL file path.') + parser.add_argument( + 
'--input_table', + help='Input BigQuery table path in PROJECT:DATASET.TABLE format.') + parser.add_argument('--output_vocab', help='Output vocab file prefix/path.') + parser.add_argument( + '--columns', + help='Comma-separated source columns to include in vocabulary.') + parser.add_argument( + '--vocab_size', + type=int, + default=50000, + help='Maximum vocabulary size (top-K by frequency).') + parser.add_argument( + '--min_frequency', + type=int, + default=1, + help='Minimum token frequency required to keep token.') + parser.add_argument( + '--lowercase', + default='true', + help='Whether to lowercase text before vocabulary generation.') + parser.add_argument( + '--input_expand_factor', + type=int, + default=1, + help=( + 'Batch-only: repeat each input line this many times to scale volume ' + 'for load/perf testing.')) + parser.add_argument( + '--artifact_location', + default='', + help=( + 'Artifact directory for MLTransform output. If empty, a temporary ' + 'local directory is used.')) + return parser.parse_known_args(argv) + + +def validate_args(args) -> list[str]: + has_input_file = bool(args.input_file) + has_input_table = bool(args.input_table) + if not has_input_file and not has_input_table: + raise ValueError('One of --input_file or --input_table is required.') + if has_input_file and has_input_table: + raise ValueError('Use exactly one of --input_file or --input_table.') + if not args.output_vocab: + raise ValueError('--output_vocab is required.') + if not args.columns: + raise ValueError('--columns is required.') + if args.vocab_size is None or args.vocab_size <= 0: + raise ValueError('--vocab_size must be > 0.') + if args.min_frequency is None or args.min_frequency < 1: + raise ValueError('--min_frequency must be >= 1.') + if args.input_expand_factor is None or args.input_expand_factor < 1: + raise ValueError('--input_expand_factor must be >= 1.') + return [col.strip() for col in args.columns.split(',') if col.strip()] + + +def run(argv=None, 
test_pipeline=None): + known_args, pipeline_args = parse_known_args(argv) + columns = validate_args(known_args) + lowercase = parse_bool_flag(known_args.lowercase) + artifact_location = known_args.artifact_location or tempfile.mkdtemp( + prefix='mltransform_generate_vocab_artifacts_') + + options = PipelineOptions(pipeline_args) + pipeline = test_pipeline or beam.Pipeline(options=options) + + if known_args.input_file: + lines = ( + pipeline + | 'ReadInputFile' >> beam.io.ReadFromText(known_args.input_file)) + if known_args.input_expand_factor > 1: + lines = ( + lines + | 'ExpandInputForPerf' >> beam.FlatMap( + lambda line, n: [line] * n, known_args.input_expand_factor)) + rows = lines | 'ParseJSON' >> beam.Map(_parse_json_line) + else: + rows = pipeline | 'ReadInputTable' >> beam.io.ReadFromBigQuery( + table=known_args.input_table) + + vocab_text = ( + rows + | 'BuildVocabText' >> + beam.Map(lambda row: _build_vocab_text(row, columns, lowercase)) + | 'DropEmptyText' >> beam.Filter(bool)) + + _ = ( + vocab_text + | 'MLTransformInput' >> beam.Map(lambda text: {'text': text}) + | 'ApplyMLTransform' >> + MLTransform(write_artifact_location=artifact_location).with_transform( + ComputeAndApplyVocabulary( + columns=['text'], + top_k=known_args.vocab_size, + frequency_threshold=known_args.min_frequency, + split_string_by_delimiter=' ', + vocab_filename='vocab')) + | 'ExtractTransformedTokens' >> beam.Map(lambda row: row.text) + | 'FlattenTokens' >> beam.FlatMap(list)) Review Comment:  The downstream steps `ExtractTransformedTokens` and `FlattenTokens` are completely redundant because the output of the `MLTransform` pipeline is ignored (assigned to `_`), and the vocabulary is read directly from the written artifact. Removing these steps avoids unnecessary processing overhead on the entire dataset. 
```python | 'ApplyMLTransform' >> MLTransform(write_artifact_location=artifact_location).with_transform( ComputeAndApplyVocabulary( columns=['text'], top_k=known_args.vocab_size, frequency_threshold=known_args.min_frequency, split_string_by_delimiter=' ', vocab_filename='vocab'))) ``` ########## sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py: ########## @@ -193,6 +193,29 @@ def _get_throughput_metrics( """Query Cloud Monitoring for per-PCollection throughput.""" name = ( pcollection_name if pcollection_name is not None else self.pcollection) + + def _point_numeric_value(point) -> float: + value = point.value + # point.value is proto-plus, so use the underlying protobuf oneof. + raw_value = getattr(value, '_pb', None) + if raw_value is not None: + active_field = raw_value.WhichOneof('value') + if active_field == 'double_value': + return float(value.double_value) + if active_field == 'int64_value': + return float(value.int64_value) + if active_field == 'distribution_value': + # Use aligned mean for distribution-valued points. + distribution = value.distribution_value + if distribution.count > 0: + return float(distribution.mean) + return 0.0 + if active_field == 'money_value': + money = value.money_value + nanos = getattr(money, 'nanos', 0) + return float(money.units) + (float(nanos) / 1_000_000_000.0) + return 0.0 Review Comment:  The helper `_point_numeric_value` assumes `point.value` is always a proto-plus wrapper: it looks up `_pb` with `getattr(value, '_pb', None)` and inspects the oneof only when that lookup succeeds. If `point.value` is a standard protobuf message (e.g., in unit tests or if the client library behavior changes), `_pb` will not be present. In that case the entire `WhichOneof` dispatch is skipped, and the point's value is silently ignored instead of being read. We can make this more robust by falling back to the object itself when `_pb` is absent. Additionally, `money_value` is not a valid field of `google.monitoring.v3.TypedValue`: its `value` oneof contains only `bool_value`, `int64_value`, `double_value`, `string_value`, and `distribution_value`. That branch is therefore unreachable and can be removed. 
```python def _point_numeric_value(point) -> float: value = point.value # point.value is proto-plus, so use the underlying protobuf oneof if present, # otherwise fall back to the value object itself. msg = getattr(value, '_pb', value) if msg is not None: try: active_field = msg.WhichOneof('value') except AttributeError: active_field = None if active_field == 'double_value': return float(value.double_value) if active_field == 'int64_value': return float(value.int64_value) if active_field == 'distribution_value': # Use aligned mean for distribution-valued points. distribution = value.distribution_value if distribution.count > 0: return float(distribution.mean) return 0.0 return 0.0 ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
