This is an automated email from the ASF dual-hosted git repository.
derrickaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 65e8b655195 huggingface model handler for yaml - retry (#38451)
65e8b655195 is described below
commit 65e8b655195bf9886c402c5ca27c28fe783d5991
Author: Derrick Williams <[email protected]>
AuthorDate: Thu May 21 10:33:42 2026 -0400
huggingface model handler for yaml - retry (#38451)
* huggingface model handler for yaml
* add yaml huggingface test file
* update dependency logic
---
.../beam_PostCommit_Yaml_Xlang_Direct.json | 2 +-
.../beam_PostCommit_Yaml_Xlang_Direct.yml | 2 +-
.../workflows/beam_PreCommit_Yaml_Xlang_Direct.yml | 2 +-
.../yaml/tests/runinference_huggingface.yaml | 62 ++++++++++++++++++++++
...uninference.yaml => runinference_vertexai.yaml} | 0
sdks/python/apache_beam/yaml/yaml_ml.py | 49 +++++++++++++++++
sdks/python/build.gradle | 19 ++++++-
7 files changed, 131 insertions(+), 5 deletions(-)
diff --git a/.github/trigger_files/beam_PostCommit_Yaml_Xlang_Direct.json b/.github/trigger_files/beam_PostCommit_Yaml_Xlang_Direct.json
index 541dc4ea8e8..8ed972c9f57 100644
--- a/.github/trigger_files/beam_PostCommit_Yaml_Xlang_Direct.json
+++ b/.github/trigger_files/beam_PostCommit_Yaml_Xlang_Direct.json
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
- "revision": 2
+ "revision": 3
}
diff --git a/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml
index ea1c255f7cc..afa437b64de 100644
--- a/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml
+++ b/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml
@@ -80,7 +80,7 @@ jobs:
- name: run PostCommit Yaml Xlang Direct script
uses: ./.github/actions/gradle-command-self-hosted-action
with:
- gradle-command: :sdks:python:postCommitYamlIntegrationTests -PyamlTestSet=${{ matrix.test_set }} -PbeamPythonExtra=ml_test,yaml
+ gradle-command: :sdks:python:postCommitYamlIntegrationTests -PyamlTestSet=${{ matrix.test_set }}
- name: Archive Python Test Results
uses: actions/upload-artifact@v7
if: failure()
diff --git a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml
index 7d17fd2140c..0b8f4cd6393 100644
--- a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml
+++ b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml
@@ -91,7 +91,7 @@ jobs:
- name: run PreCommit Yaml Xlang Direct script
uses: ./.github/actions/gradle-command-self-hosted-action
with:
- gradle-command: :sdks:python:yamlIntegrationTests -PbeamPythonExtra=ml_test,yaml
+ gradle-command: :sdks:python:yamlIntegrationTests
- name: Archive Python Test Results
uses: actions/upload-artifact@v7
if: failure()
diff --git a/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml
new file mode 100644
index 00000000000..8728a6f544a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+pipelines:
+ - pipeline:
+ type: chain
+ transforms:
+ - type: Create
+ config:
+ elements:
+ - text: "I love Apache Beam!"
+ - text: "I hate this error."
+ - type: RunInference
+ config:
+ model_handler:
+ type: "HuggingFacePipeline"
+ config:
+ task: "text-classification"
+ inference_fn:
+ callable: |
+ def real_inference(batch, pipeline, inference_args):
+ predictions = pipeline(batch, **inference_args)
+
+ # If it's a single dictionary (batch size of 1), wrap it in a list
+ if isinstance(predictions, dict):
+ predictions = [predictions]
+
+ return {
+ 'label': [p['label'] for p in predictions],
+ 'score': [p['score'] for p in predictions]
+ }
+ preprocess:
+ callable: 'lambda x: x.text'
+ - type: MapToFields
+ config:
+ language: python
+ fields:
+ text: text
+ sentiment:
+ callable: 'lambda x: x.inference.inference["label"]'
+ - type: AssertEqual
+ config:
+ elements:
+ - text: "I love Apache Beam!"
+ sentiment: "POSITIVE"
+ - text: "I hate this error."
+ sentiment: "NEGATIVE"
+
+ options:
+ yaml_experimental_features: ['ML']
diff --git a/sdks/python/apache_beam/yaml/tests/runinference.yaml b/sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml
similarity index 100%
rename from sdks/python/apache_beam/yaml/tests/runinference.yaml
rename to sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml
diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py
index 51f18c73304..05cbed3bd45 100644
--- a/sdks/python/apache_beam/yaml/yaml_ml.py
+++ b/sdks/python/apache_beam/yaml/yaml_ml.py
@@ -282,6 +282,55 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
('model_id', Optional[str])])
[email protected]_handler_type('HuggingFacePipeline')
+class HuggingFacePipelineProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ task: Optional[str] = None,
+ model: Optional[str] = None,
+ preprocess: Optional[dict[str, str]] = None,
+ postprocess: Optional[dict[str, str]] = None,
+ device: Optional[Any] = None,
+ inference_fn: Optional[dict[str, str]] = None,
+ load_pipeline_args: Optional[dict[str, Any]] = None,
+ **kwargs):
+ try:
+ from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
+ except ImportError:
+ raise ValueError(
+ 'Unable to import HuggingFacePipelineModelHandler. Please '
+ 'install transformers dependencies.')
+
+ kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}
+
+ inference_fn_obj = self.parse_processing_transform(
+ inference_fn, 'inference_fn') if inference_fn else None
+
+ handler_kwargs = {}
+ if inference_fn_obj:
+ handler_kwargs['inference_fn'] = inference_fn_obj
+
+ _handler = HuggingFacePipelineModelHandler(
+ task=task,
+ model=model,
+ device=device,
+ load_pipeline_args=load_pipeline_args,
+ **handler_kwargs,
+ **kwargs)
+
+ super().__init__(_handler, preprocess, postprocess)
+
+ @staticmethod
+ def validate(config):
+ if not config.get('task') and not config.get('model'):
+ raise ValueError(
+ "HuggingFacePipeline requires either 'task' or "
+ "'model' to be specified.")
+
+ def inference_output_type(self):
+ return Any
+
+
@beam.ptransform.ptransform_fn
def run_inference(
pcoll,
diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle
index 5f09dff57e8..837631868b8 100644
--- a/sdks/python/build.gradle
+++ b/sdks/python/build.gradle
@@ -124,10 +124,25 @@ tasks.register("generateYamlDocs") {
outputs.file "${buildDir}/yaml-examples.html"
}
+tasks.register("installYamlIntegrationTestDeps") {
+ dependsOn installGcpTest
+ doLast {
+ exec {
+ executable 'sh'
+ args '-c', ". ${envdir}/bin/activate && " +
+ "py_ver=\$(python -c 'import sys; print(f\"{sys.version_info.major}{sys.version_info.minor}\")') && " +
+ "ml_extra=\"ml_test\" && " +
+ "if [ \"\$py_ver\" -ge 313 ]; then ml_extra=\"p\${py_ver}_ml_test\"; fi && " +
+ "echo \"Installing dependencies...\" && " +
+ "pip install --pre --retries 10 ${buildDir}/apache-beam.tar.gz[\$ml_extra,yaml,transformers]"
+ }
+ }
+}
+
tasks.register("yamlIntegrationTests") {
description "Runs precommit integration tests for yaml pipelines."
- dependsOn installGcpTest
+ dependsOn installYamlIntegrationTestDeps
// Need to build all expansion services referenced in apache_beam/yaml/*.*
// grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq
dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar"
@@ -146,7 +161,7 @@ tasks.register("yamlIntegrationTests") {
tasks.register("postCommitYamlIntegrationTests") {
description "Runs postcommit integration tests for yaml pipelines - parameterized by yamlTestSet."
- dependsOn installGcpTest
+ dependsOn installYamlIntegrationTestDeps
// Need to build all expansion services referenced in apache_beam/yaml/*.*
// grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq
dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar"