This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 90cd7c6e62236f6ed783c3e0da9def680cb7e71f Author: Marcelo Colomer <[email protected]> AuthorDate: Sat Oct 25 21:16:53 2025 +0800 [Feature] Support ollama embedding models in java. --- .../api/annotation/EmbeddingModelConnection.java | 35 ++++ .../agents/api/annotation/EmbeddingModelSetup.java | 35 ++++ .../model/BaseEmbeddingModelConnection.java | 83 ++++++++ .../embedding/model/BaseEmbeddingModelSetup.java | 113 +++++++++++ e2e-test/integration-test/pom.xml | 5 + .../integration/test/AgentWithOllamaEmbedding.java | 217 +++++++++++++++++++++ .../test/AgentWithOllamaEmbeddingExample.java | 65 ++++++ integrations/chat-models/ollama/pom.xml | 2 +- .../ollama/pom.xml | 15 +- .../ollama/OllamaEmbeddingModelConnection.java | 126 ++++++++++++ .../ollama/OllamaEmbeddingModelSetup.java | 54 +++++ .../ollama/OllamaEmbeddingModelConnectionTest.java | 89 +++++++++ integrations/{ => embedding-models}/pom.xml | 10 +- integrations/pom.xml | 5 + .../org/apache/flink/agents/plan/AgentPlan.java | 6 + 15 files changed, 844 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java new file mode 100644 index 0000000..2b41d37 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark a method that provides an embedding model connection resource descriptor. + * + * <p>Methods annotated with this annotation should return a {@link + * org.apache.flink.agents.api.resource.ResourceDescriptor} that describes how to configure and + * create an embedding model connection. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface EmbeddingModelConnection {} diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java new file mode 100644 index 0000000..b5a5b9a --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark a method that provides an embedding model setup resource descriptor. + * + * <p>Methods annotated with this annotation should return a {@link + * org.apache.flink.agents.api.resource.ResourceDescriptor} that describes how to configure and + * create an embedding model setup. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface EmbeddingModelSetup {} diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java new file mode 100644 index 0000000..3774490 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.embedding.model; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Abstraction of embedding model connection. + * + * <p>Responsible for managing embedding model service connection configurations, such as service + * address, API key, connection timeout, model name, authentication information, etc. + * + * <p>This class follows the parameter pattern where additional configuration options can be passed + * through a {@code Map<String, Object>} parameters argument. Common parameters include: + * + * <ul> + * <li>model - The model name to use for embeddings + * <li>encoding_format - The format for encoding (e.g., "float", "base64") + * <li>timeout - Request timeout in milliseconds + * <li>batch_size - Maximum number of texts to process in a single request + * </ul> + */ +public abstract class BaseEmbeddingModelConnection extends Resource { + + public BaseEmbeddingModelConnection( + ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { + super(descriptor, getResource); + } + + @Override + public ResourceType getResourceType() { + return ResourceType.EMBEDDING_MODEL_CONNECTION; + } + + /** + * Generate embeddings for a single text input. 
+ * + * @param text The input text to generate embeddings for + * @param parameters Additional parameters to configure the embedding request. Common parameters + * include: - "model" (String): The specific model variant to use - "encoding_format" + * (String): The format for encoding (e.g., "float", "base64") - "timeout" (Integer): + * Request timeout in milliseconds + * @return An array of floating-point values representing the text embeddings. The length of the + * array is determined by the model itself. + */ + public abstract float[] embed(String text, Map<String, Object> parameters); + + /** + * Generate embeddings for multiple text inputs. + * + * @param texts The list of input texts to generate embeddings for + * @param parameters Additional parameters to configure the embedding request. Common parameters + * include: - "model" (String): The specific model variant to use - "encoding_format" + * (String): The format for encoding (e.g., "float", "base64") - "batch_size" (Integer): + * Maximum number of texts to process in a single request - "timeout" (Integer): Request + * timeout in milliseconds + * @return A list of arrays, each containing floating-point values representing the text + * embeddings. The length of each array is determined by the model itself. + */ + public abstract List<float[]> embed(List<String> texts, Map<String, Object> parameters); +} diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java new file mode 100644 index 0000000..205cfed --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.embedding.model; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Base class for embedding model setup configurations. + * + * <p>This class provides common setup functionality for embedding models, including connection + * management and model configuration. + */ +public abstract class BaseEmbeddingModelSetup extends Resource { + protected final String connection; + protected String model; + + public BaseEmbeddingModelSetup( + ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { + super(descriptor, getResource); + this.connection = descriptor.getArgument("connection"); + this.model = descriptor.getArgument("model"); + } + + public abstract Map<String, Object> getParameters(); + + @Override + public ResourceType getResourceType() { + return ResourceType.EMBEDDING_MODEL; + } + + /** + * Get the embedding model connection. 
+ * + * @return The embedding model connection instance + */ + public BaseEmbeddingModelConnection getConnection() { + return (BaseEmbeddingModelConnection) + getResource.apply(connection, ResourceType.EMBEDDING_MODEL_CONNECTION); + } + + /** + * Get the model name. + * + * @return The model name + */ + public String getModel() { + return model; + } + + /** + * Generate embeddings for the given text. + * + * @param text The input text to generate embeddings for + * @return An array of floating-point values representing the text embeddings + */ + public float[] embed(String text) { + return this.embed(text, Collections.emptyMap()); + } + + public float[] embed(String text, Map<String, Object> parameters) { + BaseEmbeddingModelConnection connection = getConnection(); + + Map<String, Object> params = this.getParameters(); + params.putAll(parameters); + + // params are propagated to the connection + return connection.embed(text, params); + } + + /** + * Generate embeddings for multiple texts. + * + * @param texts The list of input texts to generate embeddings for + * @return A list of arrays, each containing floating-point values representing the text + * embeddings + */ + public List<float[]> embed(List<String> texts) { + return this.embed(texts, Collections.emptyMap()); + } + + public List<float[]> embed(List<String> texts, Map<String, Object> parameters) { + BaseEmbeddingModelConnection connection = getConnection(); + + Map<String, Object> params = this.getParameters(); + params.putAll(parameters); + + // params are propagated to the connection + return connection.embed(texts, params); + } +} diff --git a/e2e-test/integration-test/pom.xml b/e2e-test/integration-test/pom.xml index a428cc9..c912248 100644 --- a/e2e-test/integration-test/pom.xml +++ b/e2e-test/integration-test/pom.xml @@ -64,6 +64,11 @@ under the License. 
<artifactId>flink-agents-integrations-chat-models-ollama</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-agents-integrations-embedding-models-ollama</artifactId> + <version>${project.version}</version> + </dependency> </dependencies> </project> \ No newline at end of file diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java new file mode 100644 index 0000000..a2a9c64 --- /dev/null +++ b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbedding.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.integration.test; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.Agent; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; +import org.apache.flink.agents.api.annotation.Tool; +import org.apache.flink.agents.api.annotation.ToolParam; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup; + +import java.util.HashMap; +import java.util.Map; + +/** + * Integration test agent for verifying embedding functionality with Ollama models. + * + * <p>This test agent validates: - Ollama embedding model integration - Vector generation and + * processing - Embedding dimension consistency - Tool integration for embedding operations - Error + * handling in embedding generation + * + * <p>Used for e2e testing of the embedding model subsystem. 
+ */ +public class AgentWithOllamaEmbedding extends Agent { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @EmbeddingModelConnection + public static ResourceDescriptor ollamaEmbeddingConnection() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 60) + .build(); + } + + @EmbeddingModelSetup + public static ResourceDescriptor embeddingModel() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "ollamaEmbeddingConnection") + .addInitialArgument("model", "nomic-embed-text") + .build(); + } + + /** Test tool for validating embedding storage operations. */ + @Tool(description = "Validate embedding storage for integration testing") + public static void validateEmbeddingStorage( + @ToolParam(name = "id") String id, + @ToolParam(name = "text") String text, + @ToolParam(name = "embedding") float[] embedding) { + + // Validation assertions for testing + if (embedding == null || embedding.length == 0) { + throw new AssertionError("Embedding cannot be null or empty"); + } + + System.out.printf( + "[TEST] Validated embedding: ID=%s, Dimension=%d, Text='%s...'%n", + id, embedding.length, text.substring(0, Math.min(30, text.length()))); + } + + /** Test tool for validating similarity calculations. 
*/ + @Tool(description = "Validate similarity calculation for integration testing") + public static float validateSimilarityCalculation( + @ToolParam(name = "embedding1") float[] embedding1, + @ToolParam(name = "embedding2") float[] embedding2) { + + if (embedding1.length != embedding2.length) { + throw new AssertionError("Embedding dimensions must match for similarity calculation"); + } + + float dotProduct = 0.0f; + float normA = 0.0f; + float normB = 0.0f; + + for (int i = 0; i < embedding1.length; i++) { + dotProduct += embedding1[i] * embedding2[i]; + normA += embedding1[i] * embedding1[i]; + normB += embedding2[i] * embedding2[i]; + } + + if (normA == 0.0f || normB == 0.0f) { + return 0.0f; + } + + float similarity = (float) (dotProduct / (Math.sqrt(normA) * Math.sqrt(normB))); + + // Validate similarity is in expected range + if (similarity < -1.0f || similarity > 1.0f) { + throw new AssertionError( + String.format("Similarity score out of range [-1,1]: %.4f", similarity)); + } + + System.out.printf("[TEST] Validated similarity calculation: %.4f%n", similarity); + return similarity; + } + + /** Main test action that processes input and validates embedding generation. 
*/ + @Action(listenEvents = {InputEvent.class}) + public static void testEmbeddingGeneration(InputEvent event, RunnerContext ctx) + throws Exception { + String input = (String) event.getInput(); + MAPPER.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + // Parse test input + Map<String, Object> inputData; + try { + inputData = MAPPER.readValue(input, Map.class); + } catch (Exception e) { + inputData = new HashMap<>(); + inputData.put("text", input); + inputData.put("id", "test_doc_" + System.currentTimeMillis()); + } + + String text = (String) inputData.get("text"); + String id = (String) inputData.getOrDefault("id", "test_doc_" + System.currentTimeMillis()); + + if (text == null || text.trim().isEmpty()) { + throw new AssertionError("Test input must contain valid text"); + } + + // Store test data in memory + ctx.getShortTermMemory().set("test_id", id); + ctx.getShortTermMemory().set("test_text", text); + + try { + // Generate embedding using Ollama + float[] embedding = generateEmbeddingForTest(text, ctx); + + // Test the validation tool + validateEmbeddingStorage(id, text, embedding); + + // Test similarity calculation with itself (should be 1.0) + float selfSimilarity = validateSimilarityCalculation(embedding, embedding); + if (Math.abs(selfSimilarity - 1.0f) > 0.001f) { + throw new AssertionError( + String.format("Self-similarity should be 1.0, got %.4f", selfSimilarity)); + } + + // Create a minimal test result to avoid serialization issues + Map<String, Object> testResult = new HashMap<>(); + testResult.put("test_status", "PASSED"); + testResult.put("id", id); + testResult.put("dimension", Integer.valueOf(embedding.length)); + testResult.put("self_similarity", Float.valueOf(selfSimilarity)); + + ctx.sendEvent(new OutputEvent(testResult)); + + System.out.printf( + "[TEST] Embedding generation test PASSED for: '%s'%n", + text.substring(0, Math.min(50, text.length()))); + + } catch (Exception e) { + // Create minimal error result + 
Map<String, Object> testResult = new HashMap<>(); + testResult.put("test_status", "FAILED"); + testResult.put("error", e.getMessage()); + testResult.put("id", id); + + ctx.sendEvent(new OutputEvent(testResult)); + + System.err.printf("[TEST] Embedding generation test FAILED: %s%n", e.getMessage()); + throw e; // Re-throw for test failure reporting + } + } + + /** Generate embedding using the framework's resource system for testing. */ + private static float[] generateEmbeddingForTest(String text, RunnerContext ctx) { + try { + OllamaEmbeddingModelSetup embeddingModel = + (OllamaEmbeddingModelSetup) + ctx.getResource( + "embeddingModel", + org.apache.flink.agents.api.resource.ResourceType + .EMBEDDING_MODEL); + + float[] embedding = embeddingModel.embed(text); + System.out.printf("[TEST] Generated embedding with dimension: %d%n", embedding.length); + return embedding; + + } catch (Exception e) { + System.err.printf("[TEST] Failed to generate embedding: %s%n", e.getMessage()); + throw new RuntimeException("Embedding generation test failed", e); + } + } + + /** Calculate the L2 norm of an embedding vector. */ + private static float calculateNorm(float[] embedding) { + float norm = 0.0f; + for (float value : embedding) { + norm += value * value; + } + return (float) Math.sqrt(norm); + } +} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java new file mode 100644 index 0000000..3c12c8c --- /dev/null +++ b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/AgentWithOllamaEmbeddingExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +/** Example application that applies {@link AgentWithOllamaEmbedding} to a DataStream of prompts. */ +public class AgentWithOllamaEmbeddingExample { + /** Runs the example pipeline. 
*/ + public static void main(String[] args) throws Exception { + // Create the execution environment + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Use prompts that exercise embedding generation and similarity checks + DataStream<String> inputStream = + env.fromData( + "Generate embedding for: 'Machine learning'", + "Generate embedding for: 'Deep learning techniques'", + "Find texts similar to: 'neural networks'", + "Produce embedding and return top-3 similar items for: 'natural language processing'", + "Generate embedding for: 'hello world'", + "Compare similarity between 'cat' and 'dog'", + "Create embedding for: 'space exploration'", + "Find nearest neighbors for: 'artificial intelligence'", + "Generate embedding for: 'data science'", + "Random embedding test"); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Apply agent to the DataStream and use the prompt itself as the key + DataStream<Object> outputStream = + agentsEnv + .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) + .apply(new AgentWithOllamaEmbedding()) + .toDataStream(); + + // Print the results + outputStream.print(); + + // Execute the pipeline + agentsEnv.execute(); + } +} diff --git a/integrations/chat-models/ollama/pom.xml b/integrations/chat-models/ollama/pom.xml index 0477962..942a42d 100644 --- a/integrations/chat-models/ollama/pom.xml +++ b/integrations/chat-models/ollama/pom.xml @@ -46,7 +46,7 @@ under the License. 
<dependency> <groupId>io.github.ollama4j</groupId> <artifactId>ollama4j</artifactId> - <version>1.1.2</version> + <version>${ollama4j.version}</version> </dependency> </dependencies> diff --git a/integrations/chat-models/ollama/pom.xml b/integrations/embedding-models/ollama/pom.xml similarity index 78% copy from integrations/chat-models/ollama/pom.xml copy to integrations/embedding-models/ollama/pom.xml index 0477962..a6ea9ba 100644 --- a/integrations/chat-models/ollama/pom.xml +++ b/integrations/embedding-models/ollama/pom.xml @@ -22,13 +22,13 @@ under the License. <parent> <groupId>org.apache.flink</groupId> - <artifactId>flink-agents-integrations-chat-models</artifactId> + <artifactId>flink-agents-integrations-embedding-models</artifactId> <version>0.2-SNAPSHOT</version> <relativePath>../pom.xml</relativePath> </parent> - <artifactId>flink-agents-integrations-chat-models-ollama</artifactId> - <name>Flink Agents : Integrations: Chat Models: Ollama</name> + <artifactId>flink-agents-integrations-embedding-models-ollama</artifactId> + <name>Flink Agents : Integrations: Embedding Models: Ollama</name> <packaging>jar</packaging> <dependencies> @@ -37,17 +37,12 @@ under the License. 
<artifactId>flink-agents-api</artifactId> <version>${project.version}</version> </dependency> - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-agents-plan</artifactId> - <version>${project.version}</version> - </dependency> <dependency> <groupId>io.github.ollama4j</groupId> <artifactId>ollama4j</artifactId> - <version>1.1.2</version> + <version>${ollama4j.version}</version> </dependency> </dependencies> -</project> \ No newline at end of file +</project> diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java new file mode 100644 index 0000000..cc96e17 --- /dev/null +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.integrations.embeddingmodels.ollama; + +import io.github.ollama4j.Ollama; +import io.github.ollama4j.exceptions.OllamaException; +import io.github.ollama4j.models.embed.OllamaEmbedRequest; +import io.github.ollama4j.models.embed.OllamaEmbedResult; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.*; +import java.util.function.BiFunction; + +/** An embedding model integration for Ollama powered by the ollama4j client. */ +public class OllamaEmbeddingModelConnection extends BaseEmbeddingModelConnection { + + private final Ollama ollamaAPI; + private final String defaultModel; + + public OllamaEmbeddingModelConnection( + ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { + super(descriptor, getResource); + String host = + descriptor.getArgument("host") != null + ? descriptor.getArgument("host") + : "http://localhost:11434"; + this.defaultModel = + descriptor.getArgument("model") != null + ? 
descriptor.getArgument("model") + : "nomic-embed-text"; + + this.ollamaAPI = new Ollama(host); + } + + @Override + public float[] embed(String text, Map<String, Object> parameters) { + String model = (String) parameters.getOrDefault("model", defaultModel); + + try { + // Pull the model if needed + ollamaAPI.pullModel(model); + + // Create embedding request with the input text + OllamaEmbedRequest requestModel = + new OllamaEmbedRequest(model, Collections.singletonList(text)); + + // Get embeddings from Ollama + OllamaEmbedResult response = ollamaAPI.embed(requestModel); + + // Extract the first (and only) embedding from the response + List<List<Double>> embeddings = response.getEmbeddings(); + if (embeddings == null || embeddings.isEmpty()) { + throw new RuntimeException("No embeddings returned from Ollama for text: " + text); + } + + List<Double> embedding = embeddings.get(0); + + // Convert to float array + float[] result = new float[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + result[i] = embedding.get(i).floatValue(); + } + + return result; + + } catch (OllamaException e) { + throw new RuntimeException("Error generating embeddings for text: " + text, e); + } + } + + @Override + public List<float[]> embed(List<String> texts, Map<String, Object> parameters) { + String model = (String) parameters.getOrDefault("model", defaultModel); + + try { + // Pull the model if needed + ollamaAPI.pullModel(model); + + // Create embedding request with all input texts + OllamaEmbedRequest requestModel = new OllamaEmbedRequest(model, texts); + + // Get embeddings from Ollama + OllamaEmbedResult response = ollamaAPI.embed(requestModel); + + // Extract embeddings from the response + List<List<Double>> embeddings = response.getEmbeddings(); + if (embeddings == null || embeddings.size() != texts.size()) { + throw new RuntimeException("Mismatch between input texts and returned embeddings"); + } + + // Convert to float arrays + List<float[]> results = new 
ArrayList<>(); + for (List<Double> embedding : embeddings) { + float[] result = new float[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + result[i] = embedding.get(i).floatValue(); + } + results.add(result); + } + + return results; + + } catch (OllamaException e) { + throw new RuntimeException("Error generating embeddings for texts", e); + } + } +} diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java new file mode 100644 index 0000000..39f38e0 --- /dev/null +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.agents.integrations.embeddingmodels.ollama;
+
+import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+/**
+ * Embedding model setup for Ollama, powered by the ollama4j client. Its connection is exposed as
+ * an {@link OllamaEmbeddingModelConnection}.
+ */
+public class OllamaEmbeddingModelSetup extends BaseEmbeddingModelSetup {
+
+    /**
+     * Creates the setup by delegating both arguments to the base class.
+     *
+     * @param descriptor the resource descriptor for this setup
+     * @param getResource lookup function from resource name and type to a resource
+     */
+    public OllamaEmbeddingModelSetup(
+            ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) {
+        super(descriptor, getResource);
+    }
+
+    /**
+     * Builds the parameter map for a call.
+     *
+     * @return a new mutable map holding a {@code "model"} entry when {@code model} is non-null,
+     *     and no entries otherwise
+     */
+    @Override
+    public Map<String, Object> getParameters() {
+        final Map<String, Object> callParameters = new HashMap<>();
+        if (model == null) {
+            // No model configured: hand back the empty map
+            return callParameters;
+        }
+        callParameters.put("model", model);
+        return callParameters;
+    }
+
+    /**
+     * Returns the connection of the base class, narrowed to the Ollama connection type.
+     *
+     * @throws ClassCastException if the connection is not an {@link
+     *     OllamaEmbeddingModelConnection}
+     */
+    @Override
+    public OllamaEmbeddingModelConnection getConnection() {
+        return OllamaEmbeddingModelConnection.class.cast(super.getConnection());
+    }
+}
diff --git a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java
new file mode 100644
index 0000000..47c9dc6
--- /dev/null
+++ b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.integrations.embeddingmodels.ollama;
+
+import org.apache.flink.agents.api.annotation.EmbeddingModelConnection;
+import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection;
+import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/** Construction-level tests for {@link OllamaEmbeddingModelConnection}; embed is never called. */
+class OllamaEmbeddingModelConnectionTest {
+
+    /** Resource lookup stub that returns null for every request. */
+    private static final BiFunction<String, ResourceType, Resource> DUMMY_RESOURCE =
+            (a, b) -> null;
+
+    /** Builds a descriptor with the "host" and "model" arguments used by every test. */
+    private static ResourceDescriptor buildDescriptor() {
+        return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName())
+                .addInitialArgument("host", "http://localhost:11434")
+                .addInitialArgument("model", "nomic-embed-text")
+                .build();
+    }
+
+    @Test
+    @DisplayName("Create OllamaEmbeddingModelConnection without calling embed")
+    void testCreateAndEmbed() {
+        OllamaEmbeddingModelConnection conn =
+                new OllamaEmbeddingModelConnection(buildDescriptor(), DUMMY_RESOURCE);
+        assertNotNull(conn);
+        // embed is not called because it requires a real Ollama server
+    }
+
+    @Test
+    @DisplayName("OllamaEmbeddingModelConnection class does not carry @EmbeddingModelConnection")
+    void testAnnotationPresence() {
+        // The annotation is declared with a METHOD target, so a class-level lookup yields null
+        assertNull(
+                OllamaEmbeddingModelConnection.class.getAnnotation(EmbeddingModelConnection.class));
+    }
+
+    @Test
+    @DisplayName("BaseEmbeddingModelSetup subclass returns a non-null connection")
+    void testSetupAnnotationPresence() {
+        // Minimal setup whose connection is created directly instead of being resolved
+        class DummySetup extends BaseEmbeddingModelSetup {
+            public DummySetup(
+                    ResourceDescriptor descriptor,
+                    BiFunction<String, ResourceType, Resource> getResource) {
+                super(descriptor, getResource);
+            }
+
+            @Override
+            public Map<String, Object> getParameters() {
+                Map<String, Object> parameters = new HashMap<>();
+                if (model != null) {
+                    parameters.put("model", model);
+                }
+                return parameters;
+            }
+
+            @Override
+            public BaseEmbeddingModelConnection getConnection() {
+                return new OllamaEmbeddingModelConnection(buildDescriptor(), DUMMY_RESOURCE);
+            }
+        }
+        DummySetup setup = new DummySetup(buildDescriptor(), DUMMY_RESOURCE);
+        assertNotNull(setup.getConnection());
+    }
+}
diff --git a/integrations/pom.xml b/integrations/embedding-models/pom.xml
similarity index 83%
copy from integrations/pom.xml
copy to integrations/embedding-models/pom.xml
index 657cb1f..473929b 100644
--- a/integrations/pom.xml
+++ b/integrations/embedding-models/pom.xml
@@ -22,16 +22,16 @@ under the License.
<parent> <groupId>org.apache.flink</groupId> - <artifactId>flink-agents</artifactId> + <artifactId>flink-agents-integrations</artifactId> <version>0.2-SNAPSHOT</version> </parent> - <artifactId>flink-agents-integrations</artifactId> - <name>Flink Agents : Integrations:</name> + <artifactId>flink-agents-integrations-embedding-models</artifactId> + <name>Flink Agents : Integrations: Embedding Models</name> <packaging>pom</packaging> <modules> - <module>chat-models</module> + <module>ollama</module> </modules> -</project> \ No newline at end of file +</project> diff --git a/integrations/pom.xml b/integrations/pom.xml index 657cb1f..6cbb428 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -30,8 +30,13 @@ under the License. <name>Flink Agents : Integrations:</name> <packaging>pom</packaging> + <properties> + <ollama4j.version>1.1.2</ollama4j.version> + </properties> + <modules> <module>chat-models</module> + <module>embedding-models</module> </modules> </project> \ No newline at end of file diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index c7d388e..7cb8138 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -22,6 +22,8 @@ import org.apache.flink.agents.api.Agent; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.annotation.ChatModelConnection; import org.apache.flink.agents.api.annotation.ChatModelSetup; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; import org.apache.flink.agents.api.annotation.Prompt; import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.resource.Resource; @@ -372,6 +374,10 @@ public class AgentPlan implements Serializable { extractResource(ResourceType.CHAT_MODEL, method); } else if 
(method.isAnnotationPresent(ChatModelConnection.class)) { extractResource(ResourceType.CHAT_MODEL_CONNECTION, method); + } else if (method.isAnnotationPresent(EmbeddingModelSetup.class)) { + extractResource(ResourceType.EMBEDDING_MODEL, method); + } else if (method.isAnnotationPresent(EmbeddingModelConnection.class)) { + extractResource(ResourceType.EMBEDDING_MODEL_CONNECTION, method); } }
