gemini-code-assist[bot] commented on code in PR #36623: URL: https://github.com/apache/beam/pull/36623#discussion_r2523782294
########## sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java: ########## @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.ml.remoteinference.RemoteInference; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; + +public class OpenAIModelHandlerIT { + private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + private String apiKey; + private static final String API_KEY_ENV = "OPENAI_API_KEY"; + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + + @Before + public void setUp() { + // Get API key + apiKey = System.getenv(API_KEY_ENV); + + // Skip tests if API key is not provided + assumeNotNull( + "OpenAI API key not found. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + apiKey); + assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + !apiKey.trim().isEmpty()); + } + + @Test + public void testSentimentAnalysisWithSingleInput() { + String input = "This product is absolutely amazing! 
I love it!"; + + PCollection<OpenAIModelInput> inputs = pipeline + .apply("CreateSingleInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs + .apply("SentimentInference", + RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze the sentiment as 'positive' or 'negative'. Return only one word.") + .build())); + + // Verify results + PAssert.that(results).satisfies(batches -> { + int count = 0; + for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) { + for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) { + count++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertNotNull("Output text should not be null", + result.getOutput().getModelResponse()); + + String sentiment = result.getOutput().getModelResponse().toLowerCase(); + assertTrue("Sentiment should be positive or negative, got: " + sentiment, + sentiment.contains("positive") + || sentiment.contains("negative")); + } + } + assertEquals("Should have exactly 1 result", 1, count); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSentimentAnalysisWithMultipleInputs() { + List<String> inputs = Arrays.asList( + "An excellent B2B SaaS solution that streamlines business processes efficiently.", + "The customer support is terrible. I've been waiting for days without any response.", + "The application works as expected. Installation was straightforward.", + "Really impressed with the innovative features! The AI capabilities are groundbreaking!", + "Mediocre product with occasional glitches. 
Documentation could be better."); + + PCollection<OpenAIModelInput> inputCollection = pipeline + .apply("CreateMultipleInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection + .apply("SentimentInference", + RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze sentiment as positive or negative") + .build())); + + // Verify we get results for all inputs + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) { + for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) { + totalCount++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + assertEquals("Should have results for all 5 inputs", 5, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testTextClassification() { + List<String> inputs = Arrays.asList( + "How do I reset my password?", + "Your product is broken and I want a refund!", + "Thank you for the excellent service!"); + + PCollection<OpenAIModelInput> inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection + .apply("ClassificationInference", + RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Classify each text into one category: 'question', 'complaint', or 'praise'. 
Return only the category.") + .build())); + + PAssert.that(results).satisfies(batches -> { + List<String> categories = new ArrayList<>(); + for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) { + for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) { + String category = result.getOutput().getModelResponse().toLowerCase(); + categories.add(category); + } + } + + assertEquals("Should have 3 categories", 3, categories.size()); + + // Verify expected categories + boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); + boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); + boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); + + assertTrue("Should have at least one recognized category", + hasQuestion || hasComplaint || hasPraise); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInputOutputMapping() { + List<String> inputs = Arrays.asList("apple", "banana", "cherry"); + + PCollection<OpenAIModelInput> inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection + .apply("MappingInference", + RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Return the input word in uppercase") + .build())); + + // Verify input-output pairing is preserved + PAssert.that(results).satisfies(batches -> { + for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) { + for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) { + String input = result.getInput().getModelInput(); + String output = result.getOutput().getModelResponse().toLowerCase(); + + // Verify the output relates to the input + assertTrue("Output should relate to input '" + input + "', got: " + output, + output.contains(input.toLowerCase())); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithDifferentModel() { + // Test with a different model + String input = "Explain quantum computing in one sentence."; + + PCollection<OpenAIModelInput> inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs + .apply("DifferentModelInference", + RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("gpt-5") Review Comment:  The integration test `testWithDifferentModel` uses the model name "gpt-5". This model is not currently available, and using it will cause the integration test to fail with an "invalid model" error from the OpenAI API. Please use a valid and available model name to ensure the test can run successfully. For example, you could use another real model like `gpt-4-turbo`. 
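More generally, hard-coding a single model name makes integration tests brittle as providers add and retire models. One optional hardening, sketched below under the assumption that `OPENAI_IT_MODEL` is a new, hypothetical environment variable (it is not defined anywhere in this PR), is to let CI override the model under test without a code change:

```java
// Sketch only: resolve the model under test from the environment, falling
// back to a model assumed to be broadly available. Pure JDK, so it can be
// dropped into this JUnit 4 test class as a helper; OPENAI_IT_MODEL is a
// hypothetical variable name, not part of this PR.
private static String resolveModelName() {
  String override = System.getenv("OPENAI_IT_MODEL");
  return (override == null || override.trim().isEmpty())
      ? "gpt-4-turbo" // assumed-available fallback
      : override.trim();
}
```

The test would then call `.modelName(resolveModelName())` instead of embedding a literal. The minimal in-place fix, though, is simply to switch to an available model: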
```suggestion .modelName("gpt-4-turbo") ``` ########## sdks/java/ml/remoteinference/build.gradle.kts: ########## @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("org.apache.beam.module") + id("java-library") +} + +description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" + +dependencies { + // Core Beam SDK + implementation(project(":sdks:java:core")) + + implementation("com.openai:openai-java:4.3.0") + compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") + compileOnly("org.checkerframework:checker-qual:3.42.0") + annotationProcessor("com.google.auto.value:auto-value:1.11.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") Review Comment:  The specified version `2.20.0` for `com.fasterxml.jackson.core:jackson-core` does not appear to be a valid published version and will likely cause build failures. Please use a valid version. For example, `2.17.1` is a recent stable version. ```kotlin implementation("com.fasterxml.jackson.core:jackson-core:2.17.1") ``` ########## sdks/java/ml/remoteinference/build.gradle.kts: ########## @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("org.apache.beam.module") + id("java-library") +} + +description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" + +dependencies { + // Core Beam SDK + implementation(project(":sdks:java:core")) + + implementation("com.openai:openai-java:4.3.0") + compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") + compileOnly("org.checkerframework:checker-qual:3.42.0") + annotationProcessor("com.google.auto.value:auto-value:1.11.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") + implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") + implementation("org.slf4j:slf4j-api:2.0.9") + implementation("org.slf4j:slf4j-simple:2.0.9") Review Comment:  The `slf4j-simple` dependency should be scoped to `testImplementation` instead of `implementation`. 
Including a concrete logging implementation like `slf4j-simple` in the `implementation` scope can lead to classpath conflicts in downstream projects that use this module, as they may have their own preferred logging framework. It's a best practice for libraries to only depend on `slf4j-api`. ```kotlin testImplementation("org.slf4j:slf4j-simple:2.0.9") ``` ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Model handler for OpenAI API inference requests. + * + * <p>This handler manages communication with OpenAI's API, including client initialization, + * request formatting, and response parsing. It uses OpenAI's structured output feature to + * ensure reliable input-output pairing. + * + * <h3>Usage</h3> + * <pre>{@code + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("sk-...") + * .modelName("gpt-4") + * .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}") + * .build(); + * + * PCollection<OpenAIModelInput> inputs = ...; + * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = + * inputs.apply( + * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + * .handler(OpenAIModelHandler.class) + * .withParameters(params) + * ); + * }</pre> + * + */ +public class OpenAIModelHandler + implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> { + + private transient OpenAIClient client; + private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams; Review Comment:  The field `clientParams` is declared as a transient instance variable but it is only assigned and used within the `request` method. This makes the code harder to reason about, as it suggests `clientParams` holds state across method calls, which it doesn't. 
It should be a local variable within the `request` method. ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Model handler for OpenAI API inference requests. + * + * <p>This handler manages communication with OpenAI's API, including client initialization, + * request formatting, and response parsing. It uses OpenAI's structured output feature to + * ensure reliable input-output pairing. + * + * <h3>Usage</h3> + * <pre>{@code + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("sk-...") + * .modelName("gpt-4") + * .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}") + * .build(); + * + * PCollection<OpenAIModelInput> inputs = ...; + * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = + * inputs.apply( + * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + * .handler(OpenAIModelHandler.class) + * .withParameters(params) + * ); + * }</pre> + * + */ +public class OpenAIModelHandler + implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> { + + private transient OpenAIClient client; + private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams; + private OpenAIModelParameters modelParameters; + + /** + * Initializes the OpenAI client with the provided parameters. + * + * <p>This method is called once during setup. It creates an authenticated + * OpenAI client using the API key from the parameters. 
+ * + * @param parameters the configuration parameters including API key and model name + */ + @Override + public void createClient(OpenAIModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + /** + * Performs inference on a batch of inputs using the OpenAI Client. + * + * <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured + * output requirements, and parses the response into {@link PredictionResult} objects + * that pair each input with its corresponding output. + * + * @param input the list of inputs to process + * @return an iterable of model results and input pairs + */ + @Override + public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) { + + try { + // Convert input list to JSON string + String inputBatch = new ObjectMapper() + .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); + + // Build structured response parameters + this.clientParams = ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); Review Comment:  The `clientParams` is being assigned to an instance field, but it's only used locally within the `request` method. It should be a local variable. This also requires removing the `clientParams` field from the class. ```suggestion StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder() .model(modelParameters.getModelName()) .input(inputBatch) .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) .instructions(modelParameters.getInstructionPrompt()) .build(); ``` ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java: ########## @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; + +/** + * Configuration parameters required for OpenAI model inference. + * + * <p>This class encapsulates all configuration needed to initialize and communicate with + * OpenAI's API, including authentication credentials, model selection, and inference instructions. 
+ * + * <h3>Example Usage</h3> + * <pre>{@code + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("sk-...") + * .modelName("gpt-4") + * .instructionPrompt("Translate the following text to French:") + * .build(); + * }</pre> + * + * @see OpenAIModelHandler + */ +public class OpenAIModelParameters implements BaseModelParameters { Review Comment:  The classes `OpenAIModelParameters`, `OpenAIModelInput`, and `OpenAIModelResponse` do not override `equals()` and `hashCode()`. This can lead to unexpected behavior when these objects are used in collections (like `Set` or as keys in a `Map`) or in tests that rely on object equality. The test classes you've written for `RemoteInferenceTest` correctly implement these methods, and the production classes should as well. For `OpenAIModelParameters`, you can add the following: ```java @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OpenAIModelParameters that = (OpenAIModelParameters) o; return java.util.Objects.equals(apiKey, that.apiKey) && java.util.Objects.equals(modelName, that.modelName) && java.util.Objects.equals(instructionPrompt, that.instructionPrompt); } @Override public int hashCode() { return java.util.Objects.hash(apiKey, modelName, instructionPrompt); } ``` Similar implementations should be added to `OpenAIModelInput` and `OpenAIModelResponse`. ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.transforms.*; +import org.checkerframework.checker.nullness.qual.Nullable; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import org.apache.beam.sdk.values.PCollection; +import com.google.auto.value.AutoValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@link PTransform} for making remote inference calls to external machine learning services. + * + * <p>{@code RemoteInference} provides a framework for integrating remote ML model + * inference into Apache Beam pipelines and handles the communication between pipelines + * and external inference APIs. 
+ * + * <h3>Example: OpenAI Model Inference</h3> + * + * <pre>{@code + * // Create model parameters + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("your-api-key") + * .modelName("gpt-4") + * .instructionPrompt("Analyse sentiment as positive or negative") + * .build(); + * + * // Apply remote inference transform + * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of( + * OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."), + * OpenAIModelInput.create("Really impressed with the innovative features!") + * )); + * + * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = + * inputs.apply( + * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + * .handler(OpenAIModelHandler.class) + * .withParameters(params) + * ); + * }</pre> + * + */ +@SuppressWarnings({ "rawtypes", "unchecked" }) Review Comment:  The class is suppressed with `@SuppressWarnings({ "rawtypes", "unchecked" })` because `BaseModelHandler` is used as a raw type. This reduces type safety and can hide potential class cast exceptions at runtime. For example, `Invoke.handler()` returns `Class<? extends BaseModelHandler>`, and `RemoteInferenceFn.handler` is a raw `BaseModelHandler`. This means the compiler cannot verify that the `BaseModelParameters` and `BaseInput`/`BaseResponse` types are compatible between the `Invoke` transform and the `BaseModelHandler` implementation. While fixing this might require some significant refactoring of the generics, it would make the framework more robust. A potential direction could be to include the `BaseModelParameters` type in the `Invoke` transform's generics, like `Invoke<InputT, OutputT, ParamT extends BaseModelParameters>`. This would allow for stronger type checking throughout the implementation. ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.transforms.*; +import org.checkerframework.checker.nullness.qual.Nullable; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import org.apache.beam.sdk.values.PCollection; +import com.google.auto.value.AutoValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@link PTransform} for making remote inference calls to external machine learning services. 
+ * + * <p>{@code RemoteInference} provides a framework for integrating remote ML model + * inference into Apache Beam pipelines and handles the communication between pipelines + * and external inference APIs. + * + * <h3>Example: OpenAI Model Inference</h3> + * + * <pre>{@code + * // Create model parameters + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("your-api-key") + * .modelName("gpt-4") + * .instructionPrompt("Analyse sentiment as positive or negative") + * .build(); + * + * // Apply remote inference transform + * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of( + * OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."), + * OpenAIModelInput.create("Really impressed with the innovative features!") + * )); + * + * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = + * inputs.apply( + * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + * .handler(OpenAIModelHandler.class) + * .withParameters(params) + * ); + * }</pre> + * + */ +@SuppressWarnings({ "rawtypes", "unchecked" }) +public class RemoteInference { + + /** Invoke the model handler with model parameters */ + public static <InputT extends BaseInput, OutputT extends BaseResponse> Invoke<InputT, OutputT> invoke() { + return new AutoValue_RemoteInference_Invoke.Builder<InputT, OutputT>().setParameters(null) + .build(); + } + + private RemoteInference() { + } + + @AutoValue + public abstract static class Invoke<InputT extends BaseInput, OutputT extends BaseResponse> + extends PTransform<PCollection<InputT>, PCollection<Iterable<PredictionResult<InputT, OutputT>>>> { + + abstract @Nullable Class<? extends BaseModelHandler> handler(); + + abstract @Nullable BaseModelParameters parameters(); + + + abstract Builder<InputT, OutputT> builder(); + + @AutoValue.Builder + abstract static class Builder<InputT extends BaseInput, OutputT extends BaseResponse> { + + abstract Builder<InputT, OutputT> setHandler(Class<? extends BaseModelHandler> modelHandler); + + abstract Builder<InputT, OutputT> setParameters(BaseModelParameters modelParameters); + + + abstract Invoke<InputT, OutputT> build(); + } + + /** + * Model handler class for inference. + */ + public Invoke<InputT, OutputT> handler(Class<? extends BaseModelHandler> modelHandler) { + return builder().setHandler(modelHandler).build(); + } + + /** + * Configures the parameters for model initialization. + */ + public Invoke<InputT, OutputT> withParameters(BaseModelParameters modelParameters) { + return builder().setParameters(modelParameters).build(); + } + + + @Override + public PCollection<Iterable<PredictionResult<InputT, OutputT>>> expand(PCollection<InputT> input) { + checkArgument(handler() != null, "handler() is required"); + checkArgument(parameters() != null, "withParameters() is required"); + return input + .apply("WrapInputInList", MapElements.via(new SimpleFunction<InputT, List<InputT>>() { + @Override + public List<InputT> apply(InputT element) { + return Collections.singletonList(element); + } + })) + // Pass the list to the inference function + .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this))); Review Comment:  The current implementation processes each input element individually by wrapping it in a singleton list. This results in a separate remote inference call for every element, which is highly inefficient and will lead to poor performance, especially for large datasets. 
The `BaseModelHandler#request` method already accepts a `List<InputT>`, suggesting that batching is intended. To improve performance, you should introduce batching before the `ParDo` transform. You can use `org.apache.beam.sdk.transforms.GroupIntoBatches` to group elements into batches of a configurable size. This will significantly reduce the number of remote calls. For example, you could add a `batchSize` parameter to the `Invoke` transform and then use it like this: ```java // First, add a key to each element input.apply(WithKeys.of("key")) // Then, group elements into batches .apply(GroupIntoBatches.ofSize(batchSize)) // Then, get the values (the batches) .apply(Values.create()) // Finally, perform the remote inference on each batch .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this))); ``` ########## sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Model handler for OpenAI API inference requests. + * + * <p>This handler manages communication with OpenAI's API, including client initialization, + * request formatting, and response parsing. It uses OpenAI's structured output feature to + * ensure reliable input-output pairing. 
+ * + * <h3>Usage</h3> + * <pre>{@code + * OpenAIModelParameters params = OpenAIModelParameters.builder() + * .apiKey("sk-...") + * .modelName("gpt-4") + * .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}") + * .build(); + * + * PCollection<OpenAIModelInput> inputs = ...; + * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = + * inputs.apply( + * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() + * .handler(OpenAIModelHandler.class) + * .withParameters(params) + * ); + * }</pre> + * + */ +public class OpenAIModelHandler + implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> { + + private transient OpenAIClient client; + private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams; + private OpenAIModelParameters modelParameters; + + /** + * Initializes the OpenAI client with the provided parameters. + * + * <p>This method is called once during setup. It creates an authenticated + * OpenAI client using the API key from the parameters. + * + * @param parameters the configuration parameters including API key and model name + */ + @Override + public void createClient(OpenAIModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + /** + * Performs inference on a batch of inputs using the OpenAI Client. + * + * <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured + * output requirements, and parses the response into {@link PredictionResult} objects + * that pair each input with its corresponding output. + * + * @param input the list of inputs to process + * @return an iterable of model results and input pairs + */ + @Override + public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) { + + try { + // Convert input list to JSON string + String inputBatch = new ObjectMapper() + .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); Review Comment:  A new `ObjectMapper` instance is created for every call to the `request` method. `ObjectMapper` is a heavy object to create, and it is thread-safe. Creating it repeatedly in a hot path like this can negatively impact performance. You should create the `ObjectMapper` instance once and reuse it. A good place to initialize it would be in the `createClient` method and store it in a `transient` field. 1. Add a field to `OpenAIModelHandler`: ```java private transient ObjectMapper objectMapper; ``` 2. Initialize it in `createClient`: ```java @Override public void createClient(OpenAIModelParameters parameters) { this.modelParameters = parameters; this.client = OpenAIOkHttpClient.builder() .apiKey(this.modelParameters.getApiKey()) .build(); this.objectMapper = new ObjectMapper(); } ``` 3. Use the field in `request`: ```suggestion String inputBatch = objectMapper .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]