This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit d3ac57540969ed122de5bfb793d6af42ca9797ca Author: WenjinXie <[email protected]> AuthorDate: Thu Nov 13 15:57:42 2025 +0800 [test][java] Clean up and refactor java e2e tests. --- .../agents/integration/test/AgentWithAzureAI.java | 126 ------------- .../integration/test/AgentWithOllamaExample.java | 65 ------- .../agents/integration/test/AgentWithResource.java | 168 ----------------- .../integration/test/AgentWithResourceExample.java | 65 ------- .../agents/integration/test/DataStreamAgent.java | 96 ---------- .../test/DataStreamIntegrationExample.java | 102 ----------- .../test/DataStreamTableIntegrationExample.java | 93 ---------- .../flink/agents/integration/test/SimpleAgent.java | 82 --------- .../flink/agents/integration/test/TableAgent.java | 98 ---------- .../integration/test/TableIntegrationExample.java | 99 ---------- .../test/ChatModelIntegrationAgent.java} | 62 +++++-- .../test/ChatModelIntegrationTest.java} | 77 ++++++-- .../test/EmbeddingIntegrationAgent.java} | 33 ++-- .../test/EmbeddingIntegrationTest.java} | 58 +++++- .../integration/test/FlinkIntegrationAgent.java | 186 +++++++++++++++++++ .../integration/test/FlinkIntegrationTest.java | 204 +++++++++++++++++++++ .../agents/integration/test/MemoryObjectAgent.java | 0 .../agents/integration/test/MemoryObjectTest.java} | 22 ++- .../integration/test/OllamaPreparationUtils.java | 47 +++++ .../agents/integration/test/ReActAgentTest.java} | 57 ++++-- .../test_from_datastream_to_datastream.txt | 6 + .../ground-truth/test_from_datastream_to_table.txt | 6 + .../ground-truth/test_from_table_to_table.txt | 6 + .../resources/log4j2-test.properties} | 23 ++- .../resources/ollama_pull_model.sh} | 11 +- 25 files changed, 714 insertions(+), 1078 deletions(-) diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAI.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAI.java deleted file mode 100644 index 
e16f7e0..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAI.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.Agent; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.annotation.*; -import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.chat.messages.MessageRole; -import org.apache.flink.agents.api.context.RunnerContext; -import org.apache.flink.agents.api.event.ChatRequestEvent; -import org.apache.flink.agents.api.event.ChatResponseEvent; -import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelConnection; -import org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelSetup; - -import java.util.Collections; - -public class AgentWithAzureAI extends Agent { - - private static final String AZURE_ENDPOINT = ""; - private static final String AZURE_API_KEY = ""; - - public static boolean callingRealMode() { - 
if (AZURE_ENDPOINT != null - && !AZURE_ENDPOINT.isEmpty() - && AZURE_API_KEY != null - && !AZURE_API_KEY.isEmpty()) { - return true; - } else { - return false; - } - } - - @ChatModelConnection - public static ResourceDescriptor azureAIChatModelConnection() { - return ResourceDescriptor.Builder.newBuilder(AzureAIChatModelConnection.class.getName()) - .addInitialArgument("endpoint", AZURE_ENDPOINT) - .addInitialArgument("apiKey", AZURE_API_KEY) - .build(); - } - - @ChatModelSetup - public static ResourceDescriptor azureAIChatModel() { - System.out.println( - "Calling real Azure AI service. Make sure the endpoint and apiKey are correct."); - return ResourceDescriptor.Builder.newBuilder(AzureAIChatModelSetup.class.getName()) - .addInitialArgument("connection", "azureAIChatModelConnection") - .addInitialArgument("model", "gpt-4o") - .build(); - } - - @Tool(description = "Converts temperature between Celsius and Fahrenheit") - public static double convertTemperature( - @ToolParam(name = "value", description = "Temperature value to convert") Double value, - @ToolParam( - name = "fromUnit", - description = "Source unit ('C' for Celsius or 'F' for Fahrenheit)") - String fromUnit, - @ToolParam( - name = "toUnit", - description = "Target unit ('C' for Celsius or 'F' for Fahrenheit)") - String toUnit) { - - fromUnit = fromUnit.toUpperCase(); - toUnit = toUnit.toUpperCase(); - - if (fromUnit.equals(toUnit)) { - return value; - } - - if (fromUnit.equals("C") && toUnit.equals("F")) { - return (value * 9 / 5) + 32; - } else if (fromUnit.equals("F") && toUnit.equals("C")) { - return (value - 32) * 5 / 9; - } else { - throw new IllegalArgumentException("Invalid temperature units. 
Use 'C' or 'F'"); - } - } - - @Tool(description = "Calculates Body Mass Index (BMI)") - public static double calculateBMI( - @ToolParam(name = "weightKg", description = "Weight in kilograms") Double weightKg, - @ToolParam(name = "heightM", description = "Height in meters") Double heightM) { - - if (weightKg <= 0 || heightM <= 0) { - throw new IllegalArgumentException("Weight and height must be positive values"); - } - return weightKg / (heightM * heightM); - } - - @Tool(description = "Create a random number") - public static double createRandomNumber() { - return Math.random(); - } - - @Action(listenEvents = {InputEvent.class}) - public static void process(InputEvent event, RunnerContext ctx) throws Exception { - ctx.sendEvent( - new ChatRequestEvent( - "azureAIChatModel", - Collections.singletonList( - new ChatMessage(MessageRole.USER, (String) event.getInput())))); - } - - @Action(listenEvents = {ChatResponseEvent.class}) - public static void processChatResponse(ChatResponseEvent event, RunnerContext ctx) { - ctx.sendEvent(new OutputEvent(event.getResponse().getContent())); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaExample.java deleted file mode 100644 index c58f6b4..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.AgentsExecutionEnvironment; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; - -/** Example application that applies {@link AgentWithOllama} to a DataStream of user prompts. */ -public class AgentWithOllamaExample { - /** Runs the example pipeline. */ - public static void main(String[] args) throws Exception { - // Create the execution environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - - // Use prompts that trigger different tool calls in the agent - DataStream<String> inputStream = - env.fromData( - "Convert 25 degrees Celsius to Fahrenheit", - "What is 98.6 Fahrenheit in Celsius?", - "Change 32 degrees Celsius to Fahrenheit", - "If it's 75 degrees Fahrenheit, what would that be in Celsius?", - "Convert room temperature of 20C to F", - "Calculate BMI for someone who is 1.75 meters tall and weighs 70 kg", - "What's the BMI for a person weighing 85 kg with height 1.80 meters?", - "Can you tell me the BMI if I'm 1.65m tall and weigh 60kg?", - "Find BMI for 75kg weight and 1.78m height", - "Create me a random number please"); - - // Create agents execution environment - AgentsExecutionEnvironment agentsEnv = - AgentsExecutionEnvironment.getExecutionEnvironment(env); - - // Apply agent to the DataStream and use the prompt itself as the key 
- DataStream<Object> outputStream = - agentsEnv - .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) - .apply(new AgentWithOllama()) - .toDataStream(); - - // Print the results - outputStream.print(); - - // Execute the pipeline - agentsEnv.execute(); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResource.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResource.java deleted file mode 100644 index 101df78..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResource.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.Agent; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.annotation.Action; -import org.apache.flink.agents.api.annotation.ChatModelSetup; -import org.apache.flink.agents.api.annotation.Tool; -import org.apache.flink.agents.api.annotation.ToolParam; -import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.chat.messages.MessageRole; -import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.context.RunnerContext; -import org.apache.flink.agents.api.event.ChatRequestEvent; -import org.apache.flink.agents.api.event.ChatResponseEvent; -import org.apache.flink.agents.api.prompt.Prompt; -import org.apache.flink.agents.api.resource.Resource; -import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.StringJoiner; -import java.util.function.BiFunction; - -public class AgentWithResource extends Agent { - public static class MockChatModel extends BaseChatModelSetup { - private final String endpoint; - private final Integer topK; - private final Double topP; - - public MockChatModel( - ResourceDescriptor descriptor, - BiFunction<String, ResourceType, Resource> getResource) { - super(descriptor, getResource); - this.endpoint = descriptor.getArgument("endpoint"); - this.topP = descriptor.getArgument("topP"); - this.topK = descriptor.getArgument("topK"); - } - - @Override - public Map<String, Object> getParameters() { - Map<String, Object> parameters = new HashMap<>(); - parameters.put("endpoint", this.endpoint); - parameters.put("topP", this.topP); - parameters.put("topK", this.topK); - return parameters; - } - - @Override - public 
ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) { - if (messages.size() == 1) { - Map<String, Object> toolCall = new HashMap<>(); - toolCall.put("id", "1"); - toolCall.put( - "function", - new HashMap<String, Object>() { - { - put("name", tools.get(0)); - put("arguments", Map.of("a", 1, "b", 2, "operation", "add")); - } - }); - return new ChatMessage( - MessageRole.ASSISTANT, - String.format("I will call tool %s", tools.get(0)), - List.of(toolCall)); - } else { - StringJoiner content = new StringJoiner("\n"); - content.add( - String.format("endpoint: %s, topP: %s, topK: %s", endpoint, topP, topK)); - - Map<String, String> arguments = new HashMap<>(); - for (ChatMessage message : messages) { - for (Map.Entry<String, Object> entry : message.getExtraArgs().entrySet()) { - arguments.put(entry.getKey(), entry.getValue().toString()); - } - } - Prompt prompt = - (Prompt) getResource.apply((String) this.prompt, ResourceType.PROMPT); - List<ChatMessage> formatMessages = - prompt.formatMessages(MessageRole.USER, arguments); - content.add("Prompt: " + formatMessages.get(0).getContent()); - - for (ChatMessage message : messages) { - content.add(message.getContent()); - } - return new ChatMessage(MessageRole.ASSISTANT, content.toString()); - } - } - } - - @org.apache.flink.agents.api.annotation.Prompt - public static Prompt myPrompt() { - return new Prompt("What is {a} + {b}?"); - } - - @ChatModelSetup - public static ResourceDescriptor myChatModel() { - return ResourceDescriptor.Builder.newBuilder(MockChatModel.class.getName()) - .addInitialArgument("endpoint", "127.0.0.1") - .addInitialArgument("topK", 5) - .addInitialArgument("topP", 0.2) - .addInitialArgument("prompt", "myPrompt") - .addInitialArgument("tools", List.of("calculate")) - .build(); - } - - @Tool(description = "Performs basic arithmetic operations") - public static double calculate( - @ToolParam(name = "a") Double a, - @ToolParam(name = "b") Double b, - @ToolParam(name = "operation") 
String operation) { - switch (operation.toLowerCase()) { - case "add": - return a + b; - case "subtract": - return a - b; - case "multiply": - return a * b; - case "divide": - if (b == 0) throw new IllegalArgumentException("Division by zero"); - return a / b; - default: - throw new IllegalArgumentException("Unknown operation: " + operation); - } - } - - @Action(listenEvents = {InputEvent.class}) - public static void process(InputEvent event, RunnerContext ctx) throws Exception { - Map<String, Integer> input = (Map<String, Integer>) event.getInput(); - - ChatMessage message = - new ChatMessage( - MessageRole.USER, - String.format("What is %s + %s?", input.get("a"), input.get("b")), - Map.of("a", input.get("a"), "b", input.get("b"))); - - List<ChatMessage> messages = new ArrayList<>(); - messages.add(message); - ctx.sendEvent(new ChatRequestEvent("myChatModel", messages)); - } - - @Action(listenEvents = {ChatResponseEvent.class}) - public static void output(ChatResponseEvent event, RunnerContext ctx) throws Exception { - String output = event.getResponse().getContent(); - ctx.sendEvent(new OutputEvent(output)); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResourceExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResourceExample.java deleted file mode 100644 index c1828aa..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithResourceExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.AgentsExecutionEnvironment; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.datastream.DataStreamSource; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; - -import java.util.HashMap; -import java.util.Map; - -/** - * Example to test MemoryObject in a complete Java Flink execution environment. This job triggers - * the {@link MemoryObjectAgent} to test storing and retrieving complex data structures. 
- */ -public class AgentWithResourceExample { - - public static void main(String[] args) throws Exception { - // Create the execution environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - - Map<String, Integer> element = new HashMap<>(); - element.put("a", 1); - element.put("b", 2); - DataStreamSource<Map<String, Integer>> inputStream = env.fromElements(element); - - // Create agents execution environment - AgentsExecutionEnvironment agentsEnv = - AgentsExecutionEnvironment.getExecutionEnvironment(env); - - // Apply agent to the DataStream and use the integer itself as the key - DataStream<Object> outputStream = - agentsEnv - .fromDataStream( - inputStream, - (KeySelector<Map<String, Integer>, Integer>) - value -> value.get("a")) - .apply(new AgentWithResource()) - .toDataStream(); - - // Print the results - outputStream.print(); - - // Execute the pipeline - agentsEnv.execute(); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamAgent.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamAgent.java deleted file mode 100644 index 860412b..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamAgent.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.Agent; -import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.annotation.Action; -import org.apache.flink.agents.api.context.MemoryObject; -import org.apache.flink.agents.api.context.MemoryRef; -import org.apache.flink.agents.api.context.RunnerContext; -import org.apache.flink.agents.integration.test.DataStreamIntegrationExample.ItemData; - -/** - * A simple example agent used for explaining integrating agents with DataStream. - * - * <p>This agent processes input events by adding a prefix and a suffix to the input data, counting - * the number of visits, and emitting an output event. - */ -public class DataStreamAgent extends Agent { - - /** Custom event type for internal agent communication. */ - public static class ProcessedEvent extends Event { - private final MemoryRef itemRef; - - public ProcessedEvent(MemoryRef itemRef) { - this.itemRef = itemRef; - } - - public MemoryRef getItemRef() { - return itemRef; - } - } - - /** - * Action that processes incoming input events. 
- * - * @param event The input event to process - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {InputEvent.class}) - public static void processInput(Event event, RunnerContext ctx) throws Exception { - InputEvent inputEvent = (InputEvent) event; - ItemData item = (ItemData) inputEvent.getInput(); - - // Get short-term memory and update the visit counter for the current key. - MemoryObject stm = ctx.getShortTermMemory(); - int currentCount = 0; - if (stm.isExist("visit_count")) { - currentCount = (int) stm.get("visit_count").getValue(); - } - int newCount = currentCount + 1; - stm.set("visit_count", newCount); - - // Send a custom event for further processing - MemoryRef itemRef = stm.set("input_data", item); - ctx.sendEvent(new ProcessedEvent(itemRef)); - } - - /** - * Action that handles processed events and generates output. - * - * @param event The processed event - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {ProcessedEvent.class}) - public static void generateOutput(Event event, RunnerContext ctx) throws Exception { - ProcessedEvent processedEvent = (ProcessedEvent) event; - MemoryRef itemRef = processedEvent.getItemRef(); - - // Process the input data using short-term memory - MemoryObject stm = ctx.getShortTermMemory(); - ItemData originalData = (ItemData) stm.get(itemRef).getValue(); - originalData.visit_count = (int) stm.get("visit_count").getValue(); - - // Generate final output - String output = "Processed: " + originalData.toString() + " [Agent Complete]"; - ctx.sendEvent(new OutputEvent(output)); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamIntegrationExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamIntegrationExample.java deleted file mode 100644 index f87ec85..0000000 --- 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamIntegrationExample.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.AgentsExecutionEnvironment; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; - -/** - * Example demonstrating how to integrate Flink Agents with DataStream API. - * - * <p>This example shows how to: - * - * <ul> - * <li>Create a DataStream from a collection - * <li>Apply an agent to process the stream - * <li>Extract output as another DataStream - * <li>Execute the pipeline - * </ul> - */ -public class DataStreamIntegrationExample { - - /** Simple data class for the example. 
*/ - public static class ItemData { - public final int id; - public final String name; - public final double value; - public int visit_count; - - public ItemData(int id, String name, double value) { - this.id = id; - this.name = name; - this.value = value; - this.visit_count = 0; - } - - @Override - public String toString() { - return String.format( - "ItemData{id=%d, name='%s', value=%.2f,visit_count=%d}", - id, name, value, visit_count); - } - } - - /** Key selector for extracting keys from ItemData. */ - public static class ItemKeySelector implements KeySelector<ItemData, Integer> { - @Override - public Integer getKey(ItemData item) { - return item.id; - } - } - - public static void main(String[] args) throws Exception { - // Create the execution environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - - // Create input DataStream - DataStream<ItemData> inputStream = - env.fromElements( - new ItemData(1, "item1", 10.5), - new ItemData(2, "item2", 20.0), - new ItemData(3, "item3", 15.7), - new ItemData(1, "item1_updated", 12.3), - new ItemData(2, "item2_updated", 22.1), - new ItemData(1, "item1_updated_again", 15.3)); - - // Create agents execution environment - AgentsExecutionEnvironment agentsEnv = - AgentsExecutionEnvironment.getExecutionEnvironment(env); - - // Apply agent to the DataStream - DataStream<Object> outputStream = - agentsEnv - .fromDataStream(inputStream, new ItemKeySelector()) - .apply(new DataStreamAgent()) - .toDataStream(); - - // Print the results - outputStream.print(); - - // Execute the pipeline - agentsEnv.execute(); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamTableIntegrationExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamTableIntegrationExample.java deleted file mode 100644 index 625eb47..0000000 --- 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/DataStreamTableIntegrationExample.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.AgentsExecutionEnvironment; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.Schema; -import org.apache.flink.table.api.Table; - -public class DataStreamTableIntegrationExample { - /** Simple data class for the example. 
*/ - public static class ItemData { - public final int id; - public final String name; - public final double value; - public int visit_count; - - public ItemData(int id, String name, double value) { - this.id = id; - this.name = name; - this.value = value; - this.visit_count = 0; - } - - @Override - public String toString() { - return String.format( - "ItemData{id=%d, name='%s', value=%.2f,visit_count=%d}", - id, name, value, visit_count); - } - } - - /** Key selector for extracting keys from ItemData. */ - public static class ItemKeySelector - implements KeySelector<DataStreamIntegrationExample.ItemData, Integer> { - @Override - public Integer getKey(DataStreamIntegrationExample.ItemData item) { - return item.id; - } - } - - public static void main(String[] args) throws Exception { - // Create the execution environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - - // Create input DataStream - DataStream<DataStreamIntegrationExample.ItemData> inputStream = - env.fromElements( - new DataStreamIntegrationExample.ItemData(1, "item1", 10.5), - new DataStreamIntegrationExample.ItemData(2, "item2", 20.0), - new DataStreamIntegrationExample.ItemData(3, "item3", 15.7), - new DataStreamIntegrationExample.ItemData(1, "item1_updated", 12.3), - new DataStreamIntegrationExample.ItemData(2, "item2_updated", 22.1), - new DataStreamIntegrationExample.ItemData(1, "item1_updated_again", 15.3)); - - // Create agents execution environment - AgentsExecutionEnvironment agentsEnv = - AgentsExecutionEnvironment.getExecutionEnvironment(env); - - // Define output schema - Schema outputSchema = Schema.newBuilder().column("f0", DataTypes.STRING()).build(); - - // Apply agent to the Table - Table outputTable = - agentsEnv - .fromDataStream( - inputStream, new DataStreamIntegrationExample.ItemKeySelector()) - .apply(new DataStreamAgent()) - .toTable(outputSchema); - - outputTable.execute().print(); - } -} diff --git 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/SimpleAgent.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/SimpleAgent.java deleted file mode 100644 index 991162e..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/SimpleAgent.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.Agent; -import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.annotation.Action; -import org.apache.flink.agents.api.context.RunnerContext; - -/** - * A simple example agent that demonstrates basic agent functionality. - * - * <p>This agent processes input events by adding a prefix to the input data and emitting an output - * event. - */ -public class SimpleAgent extends Agent { - - /** Custom event type for internal agent communication. 
*/ - public static class ProcessedEvent extends Event { - private final String processedData; - - public ProcessedEvent(String processedData) { - this.processedData = processedData; - } - - public String getProcessedData() { - return processedData; - } - } - - /** - * Action that processes incoming input events. - * - * @param event The input event to process - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {InputEvent.class}) - public static void processInput(Event event, RunnerContext ctx) { - InputEvent inputEvent = (InputEvent) event; - Object input = inputEvent.getInput(); - - // Process the input data - String processedData = "Processed: " + input.toString(); - - // Send a custom event for further processing - ctx.sendEvent(new ProcessedEvent(processedData)); - } - - /** - * Action that handles processed events and generates output. - * - * @param event The processed event - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {ProcessedEvent.class}) - public static void generateOutput(Event event, RunnerContext ctx) { - ProcessedEvent processedEvent = (ProcessedEvent) event; - String processedData = processedEvent.getProcessedData(); - - // Generate final output - String output = processedData + " [Agent Complete]"; - ctx.sendEvent(new OutputEvent(output)); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableAgent.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableAgent.java deleted file mode 100644 index 8582b17..0000000 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableAgent.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.Agent; -import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.annotation.Action; -import org.apache.flink.agents.api.context.MemoryObject; -import org.apache.flink.agents.api.context.MemoryRef; -import org.apache.flink.agents.api.context.RunnerContext; - -/** - * A simple example agent used for explaining integrating agents with DataStream. - * - * <p>This agent processes input events by adding a prefix and a suffix to the input data, counting - * the number of visits, and emitting an output event. - */ -public class TableAgent extends Agent { - - /** Custom event type for internal agent communication. */ - public static class ProcessedEvent extends Event { - private final MemoryRef inputRef; - - public ProcessedEvent(MemoryRef inputRef) { - this.inputRef = inputRef; - } - - public MemoryRef getInputRef() { - return inputRef; - } - } - - /** - * Action that processes incoming input events. 
- * - * @param event The input event to process - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {InputEvent.class}) - public static void processInput(Event event, RunnerContext ctx) throws Exception { - InputEvent inputEvent = (InputEvent) event; - Object input = inputEvent.getInput(); - - // Get short-term memory and update the visit counter for the current key. - MemoryObject stm = ctx.getShortTermMemory(); - int currentCount = 0; - if (stm.isExist("visit_count")) { - currentCount = (int) stm.get("visit_count").getValue(); - } - int newCount = currentCount + 1; - stm.set("visit_count", newCount); - - // Send a custom event with the original input and the new count. - MemoryRef inputRef = stm.set("input_data", input); - ctx.sendEvent(new ProcessedEvent(inputRef)); - } - - /** - * Action that handles processed events and generates output. - * - * @param event The processed event - * @param ctx The runner context for sending events - */ - @Action(listenEvents = {ProcessedEvent.class}) - public static void generateOutput(Event event, RunnerContext ctx) throws Exception { - ProcessedEvent processedEvent = (ProcessedEvent) event; - MemoryRef inputRef = processedEvent.getInputRef(); - - // Get input data and visitCount using short-term memory - MemoryObject stm = ctx.getShortTermMemory(); - Object originalInput = stm.get(inputRef).getValue(); - int visitCount = (int) stm.get("visit_count").getValue(); - - // Generate final output - String output = - String.format( - "Processed: %s, visit_count=%d [Agent Complete]", - originalInput.toString(), visitCount); - ctx.sendEvent(new OutputEvent(output)); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableIntegrationExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableIntegrationExample.java deleted file mode 100644 index 89835d7..0000000 --- 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/TableIntegrationExample.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.agents.integration.test; - -import org.apache.flink.agents.api.AgentsExecutionEnvironment; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.Schema; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; -import org.apache.flink.types.Row; - -/** - * Example demonstrating how to integrate Flink Agents with Table API. - * - * <p>This example shows how to: - * - * <ul> - * <li>Create a Table from sample data - * <li>Apply an agent to process the table - * <li>Extract output as another Table - * <li>Execute the pipeline - * </ul> - */ -public class TableIntegrationExample { - - /** Key selector for extracting keys from Row objects. 
*/ - public static class RowKeySelector implements KeySelector<Object, Integer> { - @Override - public Integer getKey(Object value) { - if (value instanceof Row) { - Row row = (Row) value; - return (Integer) row.getField(0); // Assuming first field is the ID - } - return 0; - } - } - - public static void main(String[] args) throws Exception { - // Create the execution environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - - // Create the table environment - StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); - tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100"); - - // Create input table from sample data - Table inputTable = - tableEnv.fromValues( - DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.INT()), - DataTypes.FIELD("name", DataTypes.STRING()), - DataTypes.FIELD("score", DataTypes.DOUBLE())), - Row.of(1, "Alice", 85.5), - Row.of(2, "Bob", 92.0), - Row.of(3, "Charlie", 78.3), - Row.of(1, "Alice", 87.2), - Row.of(2, "Bob", 94.1), - Row.of(1, "Alice", 90.3)); - - // Create agents execution environment - AgentsExecutionEnvironment agentsEnv = - AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv); - - // Define output schema - Schema outputSchema = Schema.newBuilder().column("f0", DataTypes.STRING()).build(); - - // Apply agent to the Table - Table outputTable = - agentsEnv - .fromTable(inputTable, new RowKeySelector()) - .apply(new TableAgent()) - .toTable(outputSchema); - - // Print the results to fully display the data - tableEnv.toDataStream(outputTable).print(); - env.execute(); - // Print the results in table format - outputTable.execute().print(); - } -} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllama.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java similarity index 67% rename from 
e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllama.java rename to e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java index d651db4..254041c 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllama.java +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java @@ -28,12 +28,13 @@ import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; -import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ChatRequestEvent; import org.apache.flink.agents.api.event.ChatResponseEvent; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelConnection; +import org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelSetup; import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection; import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup; @@ -61,23 +62,52 @@ import java.util.List; * resource is configured with the connection name, the model name and the list of tool names that * the model is allowed to call. 
*/ -public class AgentWithOllama extends Agent { +public class ChatModelIntegrationAgent extends Agent { + public static final String OLLAMA_MODEL = "qwen3:0.6b"; + @ChatModelConnection - public static ResourceDescriptor ollamaChatModelConnection() { - return ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName()) - .addInitialArgument("endpoint", "http://localhost:11434") - .build(); + public static ResourceDescriptor chatModelConnection() { + String provider = System.getProperty("MODEL_PROVIDER", "OLLAMA"); + if (provider.equals("OLLAMA")) { + return ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName()) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build(); + } else if (provider.equals("AZURE")) { + String endpoint = System.getenv().get("AZURE_ENDPOINT"); + String apiKey = System.getenv().get("AZURE_API_KEY"); + return ResourceDescriptor.Builder.newBuilder(AzureAIChatModelConnection.class.getName()) + .addInitialArgument("endpoint", endpoint) + .addInitialArgument("apiKey", apiKey) + .build(); + } else { + throw new RuntimeException(String.format("Unknown model provider %s", provider)); + } } @ChatModelSetup - public static ResourceDescriptor ollamaChatModel() { - return ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) - .addInitialArgument("connection", "ollamaChatModelConnection") - .addInitialArgument("model", "qwen3:8b") - .addInitialArgument( - "tools", - List.of("calculateBMI", "convertTemperature", "createRandomNumber")) - .build(); + public static ResourceDescriptor chatModel() { + String provider = System.getProperty("MODEL_PROVIDER", "OLLAMA"); + + if (provider.equals("OLLAMA")) { + return ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "chatModelConnection") + .addInitialArgument("model", OLLAMA_MODEL) + .addInitialArgument( + "tools", + 
List.of("calculateBMI", "convertTemperature", "createRandomNumber")) + .build(); + } else if (provider.equals("AZURE")) { + return ResourceDescriptor.Builder.newBuilder(AzureAIChatModelSetup.class.getName()) + .addInitialArgument("connection", "chatModelConnection") + .addInitialArgument("model", "gpt-4o") + .addInitialArgument( + "tools", + List.of("calculateBMI", "convertTemperature", "createRandomNumber")) + .build(); + } else { + throw new RuntimeException(String.format("Unknown model provider %s", provider)); + } } @Tool(description = "Converts temperature between Celsius and Fahrenheit") @@ -126,11 +156,9 @@ public class AgentWithOllama extends Agent { @Action(listenEvents = {InputEvent.class}) public static void process(InputEvent event, RunnerContext ctx) throws Exception { - BaseChatModelSetup chatModel = - (BaseChatModelSetup) ctx.getResource("ollamaChatModel", ResourceType.CHAT_MODEL); ctx.sendEvent( new ChatRequestEvent( - "ollamaChatModel", + "chatModel", Collections.singletonList( new ChatMessage(MessageRole.USER, (String) event.getInput())))); } diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAIExample.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java similarity index 52% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAIExample.java rename to e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java index bb56a21..261e646 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithAzureAIExample.java +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java @@ -15,24 +15,53 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.agents.integration.test; import org.apache.flink.agents.api.AgentsExecutionEnvironment; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +import static org.apache.flink.agents.integration.test.ChatModelIntegrationAgent.OLLAMA_MODEL; + +/** + * Example application that applies {@link ChatModelIntegrationAgent} to a DataStream of user + * prompts. + */ +public class ChatModelIntegrationTest extends OllamaPreparationUtils { + private static final Logger LOG = LoggerFactory.getLogger(ChatModelIntegrationTest.class); + + private static final String API_KEY = "_API_KEY"; + private static final String OLLAMA = "OLLAMA"; + + private final boolean ollamaReady; + + public ChatModelIntegrationTest() throws IOException { + ollamaReady = pullModel(OLLAMA_MODEL); + } + + @ParameterizedTest() + @ValueSource(strings = {"OLLAMA", "AZURE"}) + public void testChatModeIntegration(String provider) throws Exception { + Assumptions.assumeTrue( + (OLLAMA.equals(provider) && ollamaReady) + || System.getenv().get(provider + API_KEY) != null, + String.format( + "Server or authentication information is not provided for %s", provider)); + + System.setProperty("MODEL_PROVIDER", provider); -public class AgentWithAzureAIExample { - /** Runs the example pipeline. 
*/ - public static void main(String[] args) throws Exception { - if (!AgentWithAzureAI.callingRealMode()) { - // print warning information - System.err.println( - "Please set the AZURE_ENDPOINT and AZURE_API_KEY in the AgentWithAzureAI class to run this example in real mode."); - System.err.println("Falling back to mock mode."); - AgentWithResourceExample.main(args); - return; - } // Create the execution environment StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); @@ -59,13 +88,33 @@ public class AgentWithAzureAIExample { DataStream<Object> outputStream = agentsEnv .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) - .apply(new AgentWithAzureAI()) + .apply(new ChatModelIntegrationAgent()) .toDataStream(); - // Print the results - outputStream.print(); + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); // Execute the pipeline agentsEnv.execute(); + + checkResult(results); + } + + public void checkResult(CloseableIterator<Object> results) { + List<String> expectedWords = + List.of(" 77", "37", "89", "23", "68", "22", "26", "22", "23", ""); + for (String expected : expectedWords) { + Assertions.assertTrue( + results.hasNext(), "Output messages count %s is less than expected."); + String res = (String) results.next(); + if (res.contains("error") || res.contains("parameters")) { + LOG.warn(res); + } else { + Assertions.assertTrue( + res.contains(expected), + String.format( + "Groud truth %s is not contained in answer {%s}", expected, res)); + } + } } } diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationAgent.java similarity index 87% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java rename to 
e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationAgent.java index a2a9c64..e20354b 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationAgent.java @@ -44,24 +44,35 @@ import java.util.Map; * * <p>Used for e2e testing of the embedding model subsystem. */ -public class AgentWithOllamaEmbedding extends Agent { - +public class EmbeddingIntegrationAgent extends Agent { + public static final String OLLAMA_MODEL = "nomic-embed-text"; private static final ObjectMapper MAPPER = new ObjectMapper(); @EmbeddingModelConnection - public static ResourceDescriptor ollamaEmbeddingConnection() { - return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName()) - .addInitialArgument("host", "http://localhost:11434") - .addInitialArgument("timeout", 60) - .build(); + public static ResourceDescriptor embeddingConnection() { + String provider = System.getProperty("MODEL_PROVIDER", "OLLAMA"); + if (provider.equals("OLLAMA")) { + return ResourceDescriptor.Builder.newBuilder( + OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 60) + .build(); + } else { + throw new RuntimeException(String.format("Unknown model provider %s", provider)); + } } @EmbeddingModelSetup public static ResourceDescriptor embeddingModel() { - return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) - .addInitialArgument("connection", "ollamaEmbeddingConnection") - .addInitialArgument("model", "nomic-embed-text") - .build(); + String provider = System.getProperty("MODEL_PROVIDER", "OLLAMA"); + if (provider.equals("OLLAMA")) { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + 
.addInitialArgument("connection", "embeddingConnection") + .addInitialArgument("model", OLLAMA_MODEL) + .build(); + } else { + throw new RuntimeException(String.format("Unknown model provider %s", provider)); + } } /** Test tool for validating embedding storage operations. */ diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationTest.java similarity index 56% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java rename to e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationTest.java index 3c12c8c..baaf2f6 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/EmbeddingIntegrationTest.java @@ -22,11 +22,42 @@ import org.apache.flink.agents.api.AgentsExecutionEnvironment; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.IOException; +import java.util.Map; + +import static org.apache.flink.agents.integration.test.EmbeddingIntegrationAgent.OLLAMA_MODEL; +import static org.apache.flink.agents.integration.test.OllamaPreparationUtils.pullModel; + +/** + * Example application that applies {@link EmbeddingIntegrationAgent} to a DataStream of prompts. 
+ */ +public class EmbeddingIntegrationTest { + private static final String API_KEY = "_API_KEY"; + private static final String OLLAMA = "OLLAMA"; + + private final boolean ollamaReady; + + public EmbeddingIntegrationTest() throws IOException { + ollamaReady = pullModel(OLLAMA_MODEL); + } + + @ParameterizedTest() + @ValueSource(strings = {"OLLAMA"}) + public void testEmbeddingIntegration(String provider) throws Exception { + Assumptions.assumeTrue( + (OLLAMA.equals(provider) && ollamaReady) + || System.getenv().get(provider + API_KEY) != null, + String.format( + "Server or authentication information is not provided for %s", provider)); + + System.setProperty("MODEL_PROVIDER", provider); -/** Example application that applies {@link AgentWithOllamaEmbedding} to a DataStream of prompts. */ -public class AgentWithOllamaEmbeddingExample { - /** Runs the example pipeline. */ - public static void main(String[] args) throws Exception { // Create the execution environment StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); @@ -53,13 +84,26 @@ public class AgentWithOllamaEmbeddingExample { DataStream<Object> outputStream = agentsEnv .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) - .apply(new AgentWithOllamaEmbedding()) + .apply(new EmbeddingIntegrationAgent()) .toDataStream(); - // Print the results - outputStream.print(); + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); // Execute the pipeline agentsEnv.execute(); + + checkResult(results); + } + + @SuppressWarnings("unchecked") + private void checkResult(CloseableIterator<Object> results) { + for (int i = 1; i <= 10; i++) { + Assertions.assertTrue( + results.hasNext(), + String.format("Output messages count %s is less than expected 10.", i)); + Map<String, Object> res = (Map<String, Object>) results.next(); + Assertions.assertEquals("PASSED", res.get("test_status")); + } } } diff --git 
a/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationAgent.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationAgent.java new file mode 100644 index 0000000..9aaf135 --- /dev/null +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationAgent.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.Agent; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.MemoryRef; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.api.java.functions.KeySelector; + +/** Agent definition for {@link FlinkIntegrationTest} */ +public class FlinkIntegrationAgent { + /** Simple data class for the example. 
*/ + public static class ItemData { + public final int id; + public final String name; + public final double value; + public int visit_count; + + public ItemData(int id, String name, double value) { + this.id = id; + this.name = name; + this.value = value; + this.visit_count = 0; + } + + @Override + public String toString() { + return String.format( + "ItemData{id=%d, name='%s', value=%.2f,visit_count=%d}", + id, name, value, visit_count); + } + } + + /** Key selector for extracting keys from ItemData. */ + public static class ItemKeySelector implements KeySelector<ItemData, Integer> { + @Override + public Integer getKey(ItemData item) { + return item.id; + } + } + + /** Custom event type for internal agent communication. */ + public static class ProcessedEvent extends Event { + private final MemoryRef itemRef; + + public ProcessedEvent(MemoryRef itemRef) { + this.itemRef = itemRef; + } + + public MemoryRef getItemRef() { + return itemRef; + } + } + + /** + * A simple example agent used for explaining integrating agents with DataStream. + * + * <p>This agent processes input events by adding a prefix and a suffix to the input data, + * counting the number of visits, and emitting an output event. + */ + public static class DataStreamAgent extends Agent { + + /** + * Action that processes incoming input events. + * + * @param event The input event to process + * @param ctx The runner context for sending events + */ + @Action(listenEvents = {InputEvent.class}) + public static void processInput(Event event, RunnerContext ctx) throws Exception { + InputEvent inputEvent = (InputEvent) event; + ItemData item = (ItemData) inputEvent.getInput(); + + // Get short-term memory and update the visit counter for the current key. 
+ MemoryObject stm = ctx.getShortTermMemory(); + int currentCount = 0; + if (stm.isExist("visit_count")) { + currentCount = (int) stm.get("visit_count").getValue(); + } + int newCount = currentCount + 1; + stm.set("visit_count", newCount); + + // Send a custom event for further processing + MemoryRef itemRef = stm.set("input_data", item); + ctx.sendEvent(new ProcessedEvent(itemRef)); + } + + /** + * Action that handles processed events and generates output. + * + * @param event The processed event + * @param ctx The runner context for sending events + */ + @Action(listenEvents = {ProcessedEvent.class}) + public static void generateOutput(Event event, RunnerContext ctx) throws Exception { + ProcessedEvent processedEvent = (ProcessedEvent) event; + MemoryRef itemRef = processedEvent.getItemRef(); + + // Process the input data using short-term memory + MemoryObject stm = ctx.getShortTermMemory(); + ItemData originalData = (ItemData) stm.get(itemRef).getValue(); + originalData.visit_count = (int) stm.get("visit_count").getValue(); + + // Generate final output + String output = "Processed: " + originalData.toString() + " [Agent Complete]"; + ctx.sendEvent(new OutputEvent(output)); + } + } + + /** + * A simple example agent used for explaining integrating agents with DataStream. + * + * <p>This agent processes input events by adding a prefix and a suffix to the input data, + * counting the number of visits, and emitting an output event. + */ + public static class TableAgent extends Agent { + /** + * Action that processes incoming input events. + * + * @param event The input event to process + * @param ctx The runner context for sending events + */ + @Action(listenEvents = {InputEvent.class}) + public static void processInput(Event event, RunnerContext ctx) throws Exception { + InputEvent inputEvent = (InputEvent) event; + Object input = inputEvent.getInput(); + + // Get short-term memory and update the visit counter for the current key. 
+ MemoryObject stm = ctx.getShortTermMemory(); + int currentCount = 0; + if (stm.isExist("visit_count")) { + currentCount = (int) stm.get("visit_count").getValue(); + } + int newCount = currentCount + 1; + stm.set("visit_count", newCount); + + // Send a custom event with the original input and the new count. + MemoryRef inputRef = stm.set("input_data", input); + ctx.sendEvent(new ProcessedEvent(inputRef)); + } + + /** + * Action that handles processed events and generates output. + * + * @param event The processed event + * @param ctx The runner context for sending events + */ + @Action(listenEvents = {ProcessedEvent.class}) + public static void generateOutput(Event event, RunnerContext ctx) throws Exception { + ProcessedEvent processedEvent = (ProcessedEvent) event; + MemoryRef inputRef = processedEvent.getItemRef(); + + // Get input data and visitCount using short-term memory + MemoryObject stm = ctx.getShortTermMemory(); + Object originalInput = stm.get(inputRef).getValue(); + int visitCount = (int) stm.get("visit_count").getValue(); + + // Generate final output + String output = + String.format( + "Processed: %s, visit_count=%d [Agent Complete]", + originalInput.toString(), visitCount); + ctx.sendEvent(new OutputEvent(output)); + } + } +} diff --git a/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationTest.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationTest.java new file mode 100644 index 0000000..11c677a --- /dev/null +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/FlinkIntegrationTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.integration.test;
+
+import org.apache.flink.agents.api.AgentsExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * End-to-end tests that run agents from {@link FlinkIntegrationAgent} on DataStream and Table
+ * pipelines.
+ *
+ * <p>Each test feeds a fixed set of records through an agent, collects the output and compares
+ * it, ignoring order, with a file under {@code ground-truth/} on the test classpath.
+ */
+public class FlinkIntegrationTest {
+
+    /**
+     * Keys {@link Row} records by their first field, which is expected to be an {@link Integer}
+     * id. Any record that is not a {@link Row} is mapped to key {@code 0}.
+     */
+    public static class RowKeySelector implements KeySelector<Object, Integer> {
+        @Override
+        public Integer getKey(Object value) {
+            if (value instanceof Row) {
+                Row row = (Row) value;
+                return (Integer) row.getField(0); // Assuming first field is the ID
+            }
+            return 0;
+        }
+    }
+
+    @Test
+    public void testFromDataStreamToDataStream() throws Exception {
+        // Create the execution environment
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        // Create input DataStream
+        DataStream<FlinkIntegrationAgent.ItemData> inputStream = createItemStream(env);
+
+        // Create agents execution environment
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(env);
+
+        // Apply agent to the DataStream
+        DataStream<Object> outputStream =
+                agentsEnv
+                        .fromDataStream(inputStream, new FlinkIntegrationAgent.ItemKeySelector())
+                        .apply(new FlinkIntegrationAgent.DataStreamAgent())
+                        .toDataStream();
+
+        // Collect the results
+        CloseableIterator<Object> results = outputStream.collectAsync();
+
+        // Execute the pipeline
+        agentsEnv.execute();
+
+        checkResult(results, "test_from_datastream_to_datastream.txt");
+    }
+
+    @Test
+    public void testFromTableToTable() throws Exception {
+        // Create the execution environment
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        // Create the table environment
+        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
+        tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100");
+
+        // Create input table from sample data
+        Table inputTable =
+                tableEnv.fromValues(
+                        DataTypes.ROW(
+                                DataTypes.FIELD("id", DataTypes.INT()),
+                                DataTypes.FIELD("name", DataTypes.STRING()),
+                                DataTypes.FIELD("score", DataTypes.DOUBLE())),
+                        Row.of(1, "Alice", 85.5),
+                        Row.of(2, "Bob", 92.0),
+                        Row.of(3, "Charlie", 78.3),
+                        Row.of(1, "Alice", 87.2),
+                        Row.of(2, "Bob", 94.1),
+                        Row.of(1, "Alice", 90.3));
+
+        // Create agents execution environment
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv);
+
+        // Define output schema
+        Schema outputSchema = Schema.newBuilder().column("f0", DataTypes.STRING()).build();
+
+        // Apply agent to the Table
+        Table outputTable =
+                agentsEnv
+                        .fromTable(inputTable, new RowKeySelector())
+                        .apply(new FlinkIntegrationAgent.TableAgent())
+                        .toTable(outputSchema);
+
+        // Collect the results in table format
+        CloseableIterator<Row> results = outputTable.execute().collect();
+
+        checkResult(results, "test_from_table_to_table.txt");
+    }
+
+    @Test
+    public void testFromDataStreamToTable() throws Exception {
+        // Create the execution environment
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        // Create input DataStream
+        DataStream<FlinkIntegrationAgent.ItemData> inputStream = createItemStream(env);
+
+        // Create agents execution environment
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(env);
+
+        // Define output schema
+        Schema outputSchema = Schema.newBuilder().column("f0", DataTypes.STRING()).build();
+
+        // Apply agent to the DataStream and expose its output as a Table
+        Table outputTable =
+                agentsEnv
+                        .fromDataStream(inputStream, new FlinkIntegrationAgent.ItemKeySelector())
+                        .apply(new FlinkIntegrationAgent.DataStreamAgent())
+                        .toTable(outputSchema);
+
+        // Collect the results in table format
+        CloseableIterator<Row> results = outputTable.execute().collect();
+
+        checkResult(results, "test_from_datastream_to_table.txt");
+    }
+
+    /**
+     * Creates the sample input shared by the DataStream-based tests. Ids 1 and 2 repeat so that
+     * the agent sees several records for the same key.
+     */
+    private static DataStream<FlinkIntegrationAgent.ItemData> createItemStream(
+            StreamExecutionEnvironment env) {
+        return env.fromElements(
+                new FlinkIntegrationAgent.ItemData(1, "item1", 10.5),
+                new FlinkIntegrationAgent.ItemData(2, "item2", 20.0),
+                new FlinkIntegrationAgent.ItemData(3, "item3", 15.7),
+                new FlinkIntegrationAgent.ItemData(1, "item1_updated", 12.3),
+                new FlinkIntegrationAgent.ItemData(2, "item2_updated", 22.1),
+                new FlinkIntegrationAgent.ItemData(1, "item1_updated_again", 15.3));
+    }
+
+    /**
+     * Drains and closes {@code results}, then asserts that the string form of its elements equals
+     * the lines of {@code ground-truth/<fileName>}, ignoring order.
+     *
+     * @param results iterator over the job output; always closed by this method
+     * @param fileName name of the ground-truth file on the test classpath
+     */
+    private void checkResult(CloseableIterator<?> results, String fileName) throws Exception {
+        URL resource =
+                Objects.requireNonNull(
+                        getClass().getClassLoader().getResource("ground-truth/" + fileName),
+                        "Missing ground-truth file: " + fileName);
+        // toURI() decodes escaped characters (e.g. spaces) that URL#getPath() leaves encoded.
+        List<String> expected = new ArrayList<>(Files.readAllLines(Path.of(resource.toURI())));
+        expected.sort(Comparator.naturalOrder());
+
+        List<String> actual = new ArrayList<>();
+        // Close the collect iterator so the job's result-fetching resources are released.
+        try (results) {
+            results.forEachRemaining(result -> actual.add(result.toString()));
+        }
+        actual.sort(Comparator.naturalOrder());
+
+        Assertions.assertEquals(
+                expected, actual, "Output does not match ground truth in " + fileName);
+    }
+}
diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java similarity index 100% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java rename to e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectExample.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/MemoryObjectTest.java similarity index 75% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectExample.java rename to
e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/MemoryObjectTest.java
index a0e2304..4a6e1d1 100644
--- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectExample.java
+++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/MemoryObjectTest.java
@@ -21,14 +21,18 @@ import org.apache.flink.agents.api.AgentsExecutionEnvironment;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.util.CloseableIterator;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
 
 /**
  * Example to test MemoryObject in a complete Java Flink execution environment. This job triggers
  * the {@link MemoryObjectAgent} to test storing and retrieving complex data structures.
  */
-public class MemoryObjectExample {
+public class MemoryObjectTest {
 
-    public static void main(String[] args) throws Exception {
+    @Test
+    public void testMemoryObject() throws Exception {
         // Create the execution environment
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         env.setParallelism(1);
@@ -47,10 +51,24 @@ public class MemoryObjectExample {
                         .apply(new MemoryObjectAgent())
                         .toDataStream();
 
-        // Print the results
-        outputStream.print();
+        // Collect the results
+        CloseableIterator<Object> results = outputStream.collectAsync();
 
         // Execute the pipeline
         agentsEnv.execute();
+
+        checkResult(results);
     }
+
+    /** Asserts that the first four agent outputs all report that their assertions passed. */
+    private void checkResult(CloseableIterator<Object> results) {
+        for (int i = 0; i < 4; i++) {
+            Assertions.assertTrue(
+                    results.hasNext(),
+                    String.format("Output messages count %s is less than expected 4.", i));
+            String res = (String) results.next();
+            Assertions.assertTrue(
+                    res.contains("All assertions passed"), "Unexpected agent output: " + res);
+        }
+    }
 }
diff --git
a/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/OllamaPreparationUtils.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/OllamaPreparationUtils.java
new file mode 100644
index 0000000..c236605
--- /dev/null
+++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/OllamaPreparationUtils.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integration.test;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.nio.file.Path;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+/** Helpers that prepare a local Ollama server for the integration tests. */
+public final class OllamaPreparationUtils {
+    private static final Logger LOG = LoggerFactory.getLogger(OllamaPreparationUtils.class);
+
+    /** Upper bound on how long {@link #pullModel(String)} waits for the pull script. */
+    private static final long PULL_TIMEOUT_SECONDS = 120;
+
+    private OllamaPreparationUtils() {}
+
+    /**
+     * Pulls {@code model} into the local Ollama server via the {@code ollama_pull_model.sh} test
+     * resource.
+     *
+     * @param model the Ollama model name, e.g. {@code qwen3:1.7b}
+     * @return {@code true} if the script exited with code 0 within the timeout; {@code false} if
+     *     it failed, timed out, or the waiting thread was interrupted
+     * @throws IOException if the script cannot be located or started
+     */
+    public static boolean pullModel(String model) throws IOException {
+        URL script =
+                Objects.requireNonNull(
+                        OllamaPreparationUtils.class
+                                .getClassLoader()
+                                .getResource("ollama_pull_model.sh"),
+                        "ollama_pull_model.sh is not on the test classpath");
+        String path;
+        try {
+            // toURI() decodes escaped characters (e.g. spaces) that URL#getPath() leaves encoded.
+            path = Path.of(script.toURI()).toString();
+        } catch (URISyntaxException e) {
+            throw new IOException("Invalid location of ollama_pull_model.sh: " + script, e);
+        }
+        // Inherit the test's stdout/stderr: an unread pipe can fill up with the pull progress
+        // output and block the script until the timeout expires.
+        Process process = new ProcessBuilder("bash", path, model).inheritIO().start();
+        try {
+            if (!process.waitFor(PULL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+                LOG.warn("Pull {} timed out, will skip test", model);
+                process.destroyForcibly();
+                return false;
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            process.destroyForcibly();
+            LOG.warn("Interrupted while pulling {}, will skip test", model, e);
+            return false;
+        }
+        int exitCode = process.exitValue();
+        if (exitCode != 0) {
+            LOG.warn("Pull {} failed with exit code {}, will skip test", model, exitCode);
+            return false;
+        }
+        return true;
+    }
+}
diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java similarity index 77% rename from e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java rename to e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java index 09bd4dd..3967dfd 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java +++ b/e2e-test/integration-test/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java @@ -42,10 +42,19 @@ import org.apache.flink.table.api.Schema; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import java.io.IOException; import
java.util.List; -public class ReActAgentExample { +import static org.apache.flink.agents.integration.test.OllamaPreparationUtils.pullModel; + +public class ReActAgentTest { + public static final String OLLAMA_MODEL = "qwen3:1.7b"; + @org.apache.flink.agents.api.annotation.Tool( description = "Useful function to add two numbers.") public static double add(@ToolParam(name = "a") Double a, @ToolParam(name = "b") Double b) { @@ -59,8 +68,15 @@ public class ReActAgentExample { return a * b; } - /** Runs the example pipeline. */ - public static void main(String[] args) throws Exception { + private final boolean ollamaReady; + + public ReActAgentTest() throws IOException { + ollamaReady = pullModel(OLLAMA_MODEL); + } + + @Test + public void testReActAgent() throws Exception { + Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL)); StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); @@ -80,18 +96,18 @@ public class ReActAgentExample { ResourceDescriptor.Builder.newBuilder( OllamaChatModelConnection.class.getName()) .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) .build()) .addResource( "add", ResourceType.TOOL, Tool.fromMethod( - ReActAgentExample.class.getMethod( - "add", Double.class, Double.class))) + ReActAgentTest.class.getMethod("add", Double.class, Double.class))) .addResource( "multiply", ResourceType.TOOL, Tool.fromMethod( - ReActAgentExample.class.getMethod( + ReActAgentTest.class.getMethod( "multiply", Double.class, Double.class))); agentsEnv @@ -110,7 +126,7 @@ public class ReActAgentExample { DataTypes.FIELD("a", DataTypes.DOUBLE()), DataTypes.FIELD("b", DataTypes.DOUBLE()), DataTypes.FIELD("c", DataTypes.DOUBLE())), - Row.of(1, 2, 3)); + Row.of(2131, 29847, 3)); // Define output schema Schema outputSchema = @@ -128,11 +144,15 @@ public class ReActAgentExample { .apply(agent) .toTable(outputSchema); - // Print the 
results to fully display the data - tableEnv.toDataStream(outputTable) - .map((MapFunction<Row, Row>) x -> (Row) x.getField("f0")) - .print(); + // Collect the results to fully display the data + CloseableIterator<Row> results = + tableEnv.toDataStream(outputTable) + .map((MapFunction<Row, Row>) x -> (Row) x.getField("f0")) + .collectAsync(); + env.execute(); + + checkResult(results); } // create ReAct agent. @@ -140,7 +160,7 @@ public class ReActAgentExample { ResourceDescriptor chatModelDescriptor = ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) .addInitialArgument("connection", "ollama") - .addInitialArgument("model", "qwen3:8b") + .addInitialArgument("model", OLLAMA_MODEL) .addInitialArgument("tools", List.of("add", "multiply")) .addInitialArgument("extract_reasoning", "true") .build(); @@ -148,14 +168,25 @@ public class ReActAgentExample { Prompt prompt = new Prompt( List.of( + new ChatMessage( + MessageRole.SYSTEM, + "Must call function tool to do the calculate."), new ChatMessage( MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}"), - new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}"))); + new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."))); RowTypeInfo outputTypeInfo = new RowTypeInfo( new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO}, new String[] {"result"}); return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo); } + + private void checkResult(CloseableIterator<?> results) { + Assertions.assertTrue( + results.hasNext(), + "This may be caused by the LLM response does not match the output schema, you can rerun this case."); + Row res = (Row) results.next(); + Assertions.assertEquals("+I[95934.0]", res.toString()); + } } diff --git a/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_datastream.txt b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_datastream.txt new file mode 100644 index 0000000..fe5c029 
--- /dev/null +++ b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_datastream.txt @@ -0,0 +1,6 @@ +Processed: ItemData{id=1, name='item1', value=10.50,visit_count=1} [Agent Complete] +Processed: ItemData{id=2, name='item2', value=20.00,visit_count=1} [Agent Complete] +Processed: ItemData{id=3, name='item3', value=15.70,visit_count=1} [Agent Complete] +Processed: ItemData{id=1, name='item1_updated', value=12.30,visit_count=2} [Agent Complete] +Processed: ItemData{id=2, name='item2_updated', value=22.10,visit_count=2} [Agent Complete] +Processed: ItemData{id=1, name='item1_updated_again', value=15.30,visit_count=3} [Agent Complete] \ No newline at end of file diff --git a/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_table.txt b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_table.txt new file mode 100644 index 0000000..ffc985a --- /dev/null +++ b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_datastream_to_table.txt @@ -0,0 +1,6 @@ ++I[Processed: ItemData{id=1, name='item1', value=10.50,visit_count=1} [Agent Complete]] ++I[Processed: ItemData{id=2, name='item2', value=20.00,visit_count=1} [Agent Complete]] ++I[Processed: ItemData{id=3, name='item3', value=15.70,visit_count=1} [Agent Complete]] ++I[Processed: ItemData{id=1, name='item1_updated', value=12.30,visit_count=2} [Agent Complete]] ++I[Processed: ItemData{id=2, name='item2_updated', value=22.10,visit_count=2} [Agent Complete]] ++I[Processed: ItemData{id=1, name='item1_updated_again', value=15.30,visit_count=3} [Agent Complete]] \ No newline at end of file diff --git a/e2e-test/integration-test/src/test/resources/ground-truth/test_from_table_to_table.txt b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_table_to_table.txt new file mode 100644 index 0000000..229733f --- /dev/null +++ 
b/e2e-test/integration-test/src/test/resources/ground-truth/test_from_table_to_table.txt @@ -0,0 +1,6 @@ ++I[Processed: +I[1, Alice, 85.5], visit_count=1 [Agent Complete]] ++I[Processed: +I[2, Bob, 92.0], visit_count=1 [Agent Complete]] ++I[Processed: +I[3, Charlie, 78.3], visit_count=1 [Agent Complete]] ++I[Processed: +I[1, Alice, 87.2], visit_count=2 [Agent Complete]] ++I[Processed: +I[2, Bob, 94.1], visit_count=2 [Agent Complete]] ++I[Processed: +I[1, Alice, 90.3], visit_count=3 [Agent Complete]] \ No newline at end of file diff --git a/e2e-test/integration-test/src/main/resources/log4j2.properties b/e2e-test/integration-test/src/test/resources/log4j2-test.properties similarity index 57% copy from e2e-test/integration-test/src/main/resources/log4j2.properties copy to e2e-test/integration-test/src/test/resources/log4j2-test.properties index 9206863..487984f 100644 --- a/e2e-test/integration-test/src/main/resources/log4j2.properties +++ b/e2e-test/integration-test/src/test/resources/log4j2-test.properties @@ -1,4 +1,4 @@ -################################################################################ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -7,19 +7,22 @@ # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ +# limitations under the License. 
+# -rootLogger.level = INFO -rootLogger.appenderRef.console.ref = ConsoleAppender +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level = OFF +rootLogger.appenderRef.test.ref = TestLogger -appender.console.name = ConsoleAppender -appender.console.type = CONSOLE -appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p %-60c %x - %m%n +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r %d{yyyy-MM-dd HH:mm:ss.SSS} [%t] %-5p %c %x - %m%n diff --git a/e2e-test/integration-test/src/main/resources/log4j2.properties b/e2e-test/integration-test/src/test/resources/ollama_pull_model.sh similarity index 77% rename from e2e-test/integration-test/src/main/resources/log4j2.properties rename to e2e-test/integration-test/src/test/resources/ollama_pull_model.sh index 9206863..7277e12 100644 --- a/e2e-test/integration-test/src/main/resources/log4j2.properties +++ b/e2e-test/integration-test/src/test/resources/ollama_pull_model.sh @@ -1,3 +1,4 @@ +#!/usr/bin/env bash ################################################################################ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -15,11 +16,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################
-
-rootLogger.level = INFO
-rootLogger.appenderRef.console.ref = ConsoleAppender
-
-appender.console.name = ConsoleAppender
-appender.console.type = CONSOLE
-appender.console.layout.type = PatternLayout
-appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p %-60c %x - %m%n
+echo "ollama pull $1"
+ollama pull "$1"
\ No newline at end of file
