This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 8c7ae45 [Feature][runtime] Support the use of Python EmbeddingModel
in Java (#417)
8c7ae45 is described below
commit 8c7ae45d3cf7b87a8157e5c9408085e425c65046
Author: Eugene <[email protected]>
AuthorDate: Wed Jan 7 18:56:07 2026 +0800
[Feature][runtime] Support the use of Python EmbeddingModel in Java (#417)
---
.../api/embedding/model/EmbeddingModelUtils.java | 37 +++
.../python/PythonEmbeddingModelConnection.java | 127 ++++++++++
.../model/python/PythonEmbeddingModelSetup.java | 131 ++++++++++
.../embedding/model/EmbeddingModelUtilsTest.java | 122 +++++++++
.../python/PythonEmbeddingModelConnectionTest.java | 268 ++++++++++++++++++++
.../python/PythonEmbeddingModelSetupTest.java | 272 +++++++++++++++++++++
.../resource/test/ChatModelCrossLanguageTest.java | 21 +-
.../resource/test/EmbeddingCrossLanguageAgent.java | 174 +++++++++++++
.../resource/test/EmbeddingCrossLanguageTest.java | 97 ++++++++
.../resource/test/OllamaPreparationUtils.java | 47 ++++
.../ollama/OllamaEmbeddingModelConnection.java | 14 +-
.../api/embedding_models/embedding_model.py | 4 +-
.../local/ollama_embedding_model.py | 8 +-
.../embedding_models/openai_embedding_model.py | 7 +-
python/flink_agents/plan/tests/test_agent_plan.py | 6 +-
15 files changed, 1294 insertions(+), 41 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtils.java
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtils.java
new file mode 100644
index 0000000..c1c71e5
--- /dev/null
+++
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtils.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model;
+
+import java.util.List;
+
+/** Helper methods shared by embedding model implementations. */
+public class EmbeddingModelUtils {
+    /**
+     * Converts a list of numeric values into a primitive {@code float} array.
+     *
+     * <p>Each element is narrowed with {@link Number#floatValue()}, so values that do not fit a
+     * {@code float} exactly are rounded.
+     *
+     * @param list the values to convert; every element must be a non-null {@link Number}
+     * @return a new array holding the converted elements of {@code list} in iteration order
+     * @throws IllegalArgumentException if an element is {@code null} or is not a {@link Number}
+     */
+    public static float[] toFloatArray(List<?> list) {
+        float[] array = new float[list.size()];
+        int index = 0;
+        for (Object element : list) {
+            if (!(element instanceof Number)) {
+                // A null element has no class to report; name it explicitly rather than
+                // failing with a NullPointerException while building the error message.
+                throw new IllegalArgumentException(
+                        "Expected numeric value in embedding result, but got: "
+                                + (element == null ? "null" : element.getClass().getName()));
+            }
+            array[index++] = ((Number) element).floatValue();
+        }
+        return array;
+    }
+}
diff --git
a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java
new file mode 100644
index 0000000..b4ff933
--- /dev/null
+++
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model.python;
+
+import
org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection;
+import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.resource.python.PythonResourceAdapter;
+import org.apache.flink.agents.api.resource.python.PythonResourceWrapper;
+import pemja.core.object.PyObject;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Python-based implementation of EmbeddingModelConnection that bridges Java and Python embedding
+ * model functionality. This class wraps a Python embedding model connection object and provides
+ * Java interface compatibility while delegating actual embed operations to the underlying Python
+ * implementation.
+ */
+public class PythonEmbeddingModelConnection extends BaseEmbeddingModelConnection
+        implements PythonResourceWrapper {
+
+    private final PyObject embeddingModel;
+    private final PythonResourceAdapter adapter;
+
+    /**
+     * Creates a new PythonEmbeddingModelConnection.
+     *
+     * @param adapter The Python resource adapter; used to invoke {@code embed} on the wrapped
+     *     Python object
+     * @param embeddingModel The Python embedding model object
+     * @param descriptor The resource descriptor
+     * @param getResource Function to retrieve resources by name and type
+     */
+    public PythonEmbeddingModelConnection(
+            PythonResourceAdapter adapter,
+            PyObject embeddingModel,
+            ResourceDescriptor descriptor,
+            BiFunction<String, ResourceType, Resource> getResource) {
+        super(descriptor, getResource);
+        this.embeddingModel = embeddingModel;
+        this.adapter = adapter;
+    }
+
+    /**
+     * Embeds a single text by calling {@code embed} on the wrapped Python object.
+     *
+     * @param text the text to embed; passed to Python as the keyword argument {@code text}
+     * @param parameters extra keyword arguments for the Python call; {@code null} is treated as
+     *     empty, and a {@code "text"} entry is overridden by {@code text}
+     * @return the embedding converted to a {@code float} array
+     * @throws IllegalStateException if the wrapped Python object is {@code null}
+     * @throws IllegalArgumentException if the Python call does not return a list of numbers
+     */
+    @Override
+    public float[] embed(String text, Map<String, Object> parameters) {
+        Object result = callEmbed(text, parameters);
+
+        if (result instanceof List) {
+            return EmbeddingModelUtils.toFloatArray((List<?>) result);
+        }
+
+        throw new IllegalArgumentException(
+                "Expected List from Python embed method, but got: " + describeType(result));
+    }
+
+    /**
+     * Embeds several texts with one call to {@code embed} on the wrapped Python object.
+     *
+     * @param texts the texts to embed; passed to Python as the keyword argument {@code text}
+     * @param parameters extra keyword arguments for the Python call; {@code null} is treated as
+     *     empty, and a {@code "text"} entry is overridden by {@code texts}
+     * @return one {@code float} array per element of the list returned by Python, in order
+     * @throws IllegalStateException if the wrapped Python object is {@code null}
+     * @throws IllegalArgumentException if the Python call does not return a list whose elements
+     *     are all lists of numbers
+     */
+    @Override
+    public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+        Object results = callEmbed(texts, parameters);
+
+        if (!(results instanceof List)) {
+            throw new IllegalArgumentException(
+                    "Expected List from Python embed method, but got: " + describeType(results));
+        }
+
+        List<?> rows = (List<?>) results;
+        List<float[]> embeddings = new ArrayList<>(rows.size());
+        for (Object row : rows) {
+            if (!(row instanceof List)) {
+                // describeType keeps a null row from turning into a NullPointerException here
+                throw new IllegalArgumentException(
+                        "Expected List value in embedding results, but got: "
+                                + describeType(row));
+            }
+            embeddings.add(EmbeddingModelUtils.toFloatArray((List<?>) row));
+        }
+        return embeddings;
+    }
+
+    @Override
+    public Object getPythonResource() {
+        return embeddingModel;
+    }
+
+    /** Invokes the Python {@code embed} method with {@code textArg} bound to {@code text}. */
+    private Object callEmbed(Object textArg, Map<String, Object> parameters) {
+        // The message text is asserted by PythonEmbeddingModelConnectionTest; keep both in sync.
+        checkState(
+                embeddingModel != null,
+                "EmbeddingModelSetup is not initialized. Cannot perform embed operation.");
+
+        // Copy so the caller's map is never mutated; tolerate callers that pass no parameters.
+        Map<String, Object> kwargs =
+                parameters == null ? new HashMap<>() : new HashMap<>(parameters);
+        kwargs.put("text", textArg);
+
+        return adapter.callMethod(embeddingModel, "embed", kwargs);
+    }
+
+    /** Returns the class name of {@code value}, or {@code "null"} when it is {@code null}. */
+    private static String describeType(Object value) {
+        return value == null ? "null" : value.getClass().getName();
+    }
+}
diff --git
a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java
new file mode 100644
index 0000000..33c4410
--- /dev/null
+++
b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model.python;
+
+import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup;
+import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.resource.python.PythonResourceAdapter;
+import org.apache.flink.agents.api.resource.python.PythonResourceWrapper;
+import pemja.core.object.PyObject;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Python-based implementation of EmbeddingModelSetup that bridges Java and Python embedding model
+ * setup functionality. This class wraps a Python embedding model setup object and provides Java
+ * interface compatibility while delegating actual embed operations to the underlying Python
+ * implementation.
+ */
+public class PythonEmbeddingModelSetup extends BaseEmbeddingModelSetup
+        implements PythonResourceWrapper {
+    private final PyObject embeddingModelSetup;
+    private final PythonResourceAdapter adapter;
+
+    /**
+     * Creates a new PythonEmbeddingModelSetup.
+     *
+     * @param adapter The Python resource adapter; used to invoke {@code embed} on the wrapped
+     *     Python object
+     * @param embeddingModelSetup The Python embedding model setup object
+     * @param descriptor The resource descriptor
+     * @param getResource Function to retrieve resources by name and type
+     */
+    public PythonEmbeddingModelSetup(
+            PythonResourceAdapter adapter,
+            PyObject embeddingModelSetup,
+            ResourceDescriptor descriptor,
+            BiFunction<String, ResourceType, Resource> getResource) {
+        super(descriptor, getResource);
+        this.embeddingModelSetup = embeddingModelSetup;
+        this.adapter = adapter;
+    }
+
+    /**
+     * Embeds a single text by calling {@code embed} on the wrapped Python object.
+     *
+     * @param text the text to embed; passed to Python as the keyword argument {@code text}
+     * @param parameters extra keyword arguments for the Python call; {@code null} is treated as
+     *     empty, and a {@code "text"} entry is overridden by {@code text}
+     * @return the embedding converted to a {@code float} array
+     * @throws IllegalStateException if the wrapped Python object is {@code null}
+     * @throws IllegalArgumentException if the Python call does not return a list of numbers
+     */
+    @Override
+    public float[] embed(String text, Map<String, Object> parameters) {
+        Object result = callEmbed(text, parameters);
+
+        if (result instanceof List) {
+            return EmbeddingModelUtils.toFloatArray((List<?>) result);
+        }
+
+        throw new IllegalArgumentException(
+                "Expected List from Python embed method, but got: " + describeType(result));
+    }
+
+    /**
+     * Embeds several texts with one call to {@code embed} on the wrapped Python object.
+     *
+     * @param texts the texts to embed; passed to Python as the keyword argument {@code text}
+     * @param parameters extra keyword arguments for the Python call; {@code null} is treated as
+     *     empty, and a {@code "text"} entry is overridden by {@code texts}
+     * @return one {@code float} array per element of the list returned by Python, in order
+     * @throws IllegalStateException if the wrapped Python object is {@code null}
+     * @throws IllegalArgumentException if the Python call does not return a list whose elements
+     *     are all lists of numbers
+     */
+    @Override
+    public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+        Object results = callEmbed(texts, parameters);
+
+        if (!(results instanceof List)) {
+            throw new IllegalArgumentException(
+                    "Expected List from Python embed method, but got: " + describeType(results));
+        }
+
+        List<?> rows = (List<?>) results;
+        List<float[]> embeddings = new ArrayList<>(rows.size());
+        for (Object row : rows) {
+            if (!(row instanceof List)) {
+                // describeType keeps a null row from turning into a NullPointerException here
+                throw new IllegalArgumentException(
+                        "Expected List value in embedding results, but got: "
+                                + describeType(row));
+            }
+            embeddings.add(EmbeddingModelUtils.toFloatArray((List<?>) row));
+        }
+        return embeddings;
+    }
+
+    /**
+     * Returns the Java-side parameters of this setup.
+     *
+     * @return always an empty, immutable map
+     */
+    @Override
+    public Map<String, Object> getParameters() {
+        return Map.of();
+    }
+
+    @Override
+    public Object getPythonResource() {
+        return embeddingModelSetup;
+    }
+
+    /** Invokes the Python {@code embed} method with {@code textArg} bound to {@code text}. */
+    private Object callEmbed(Object textArg, Map<String, Object> parameters) {
+        checkState(
+                embeddingModelSetup != null,
+                "EmbeddingModelSetup is not initialized. Cannot perform embed operation.");
+
+        // Copy so the caller's map is never mutated; tolerate callers that pass no parameters.
+        Map<String, Object> kwargs =
+                parameters == null ? new HashMap<>() : new HashMap<>(parameters);
+        kwargs.put("text", textArg);
+
+        return adapter.callMethod(embeddingModelSetup, "embed", kwargs);
+    }
+
+    /** Returns the class name of {@code value}, or {@code "null"} when it is {@code null}. */
+    private static String describeType(Object value) {
+        return value == null ? "null" : value.getClass().getName();
+    }
+}
diff --git
a/api/src/test/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtilsTest.java
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtilsTest.java
new file mode 100644
index 0000000..9c6f3e9
--- /dev/null
+++
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/EmbeddingModelUtilsTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model;
+
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+/** Unit tests for {@link EmbeddingModelUtils#toFloatArray}. */
+public class EmbeddingModelUtilsTest {
+
+    @Test
+    @DisplayName("Test converting List<Double> to float array")
+    void testToFloatArrayFromDoubleList() {
+        float[] converted = EmbeddingModelUtils.toFloatArray(Arrays.asList(1.0, 2.5, 3.7, 4.2));
+
+        assertNotNull(converted);
+        assertEquals(4, converted.length);
+        // Element-wise comparison with the same tolerance as a per-index check
+        assertArrayEquals(new float[] {1.0f, 2.5f, 3.7f, 4.2f}, converted, 0.0001f);
+    }
+
+    @Test
+    @DisplayName("Test converting List<Float> to float array")
+    void testToFloatArrayFromFloatList() {
+        float[] expected = {1.5f, 2.5f, 3.5f};
+        List<Float> input = Arrays.asList(1.5f, 2.5f, 3.5f);
+
+        float[] converted = EmbeddingModelUtils.toFloatArray(input);
+
+        assertNotNull(converted);
+        assertEquals(expected.length, converted.length);
+        assertArrayEquals(expected, converted, 0.0001f);
+    }
+
+    @Test
+    @DisplayName("Test converting mixed Number types to float array")
+    void testToFloatArrayFromMixedNumberList() {
+        // Integer, Double, Float and Long in a single list
+        List<Number> input = Arrays.<Number>asList(1, 2.5, 3.5f, 4L);
+
+        float[] converted = EmbeddingModelUtils.toFloatArray(input);
+
+        assertNotNull(converted);
+        assertEquals(4, converted.length);
+        assertArrayEquals(new float[] {1.0f, 2.5f, 3.5f, 4.0f}, converted, 0.0001f);
+    }
+
+    @Test
+    @DisplayName("Test converting empty list to empty float array")
+    void testToFloatArrayFromEmptyList() {
+        float[] converted = EmbeddingModelUtils.toFloatArray(new ArrayList<Double>());
+
+        assertNotNull(converted);
+        assertEquals(0, converted.length);
+    }
+
+    @Test
+    @DisplayName("Test converting single element list to float array")
+    void testToFloatArrayFromSingleElementList() {
+        float[] converted = EmbeddingModelUtils.toFloatArray(List.of(42.0));
+
+        assertNotNull(converted);
+        assertEquals(1, converted.length);
+        assertEquals(42.0f, converted[0], 0.0001f);
+    }
+
+    @Test
+    @DisplayName("Test exception when list contains non-numeric value")
+    void testToFloatArrayThrowsExceptionForNonNumericValue() {
+        List<Object> input = new ArrayList<>(Arrays.<Object>asList(1.0, "not a number", 3.0));
+
+        IllegalArgumentException thrown =
+                assertThrows(
+                        IllegalArgumentException.class,
+                        () -> EmbeddingModelUtils.toFloatArray(input));
+
+        String message = thrown.getMessage();
+        assertNotNull(message);
+        assertEquals(
+                "Expected numeric value in embedding result, but got: java.lang.String", message);
+    }
+}
diff --git
a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java
new file mode 100644
index 0000000..13ae25b
--- /dev/null
+++
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model.python;
+
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.resource.python.PythonResourceAdapter;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import pemja.core.object.PyObject;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/** Unit tests for {@link PythonEmbeddingModelConnection}, with the Python side mocked. */
+public class PythonEmbeddingModelConnectionTest {
+    /** Absolute tolerance used when comparing converted float values. */
+    private static final org.assertj.core.data.Offset<Float> TOLERANCE =
+            org.assertj.core.data.Offset.offset(0.0001f);
+
+    @Mock private PythonResourceAdapter mockAdapter;
+
+    @Mock private PyObject mockEmbeddingModel;
+
+    @Mock private ResourceDescriptor mockDescriptor;
+
+    @Mock private BiFunction<String, ResourceType, Resource> mockGetResource;
+
+    private PythonEmbeddingModelConnection pythonEmbeddingModelConnection;
+    private AutoCloseable mocks;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        mocks = MockitoAnnotations.openMocks(this);
+        pythonEmbeddingModelConnection =
+                new PythonEmbeddingModelConnection(
+                        mockAdapter, mockEmbeddingModel, mockDescriptor, mockGetResource);
+    }
+
+    @AfterEach
+    void tearDown() throws Exception {
+        if (mocks == null) {
+            return;
+        }
+        mocks.close();
+    }
+
+    /** Makes the mocked adapter return {@code pythonResult} for any {@code embed} call. */
+    private void stubEmbedResult(Object pythonResult) {
+        when(mockAdapter.callMethod(eq(mockEmbeddingModel), eq("embed"), any(Map.class)))
+                .thenReturn(pythonResult);
+    }
+
+    /** Builds a connection whose wrapped Python object is {@code null}. */
+    private PythonEmbeddingModelConnection newConnectionWithoutModel() {
+        return new PythonEmbeddingModelConnection(
+                mockAdapter, null, mockDescriptor, mockGetResource);
+    }
+
+    @Test
+    void testConstructor() {
+        assertThat(pythonEmbeddingModelConnection).isNotNull();
+        Object wrapped = pythonEmbeddingModelConnection.getPythonResource();
+        assertThat(wrapped).isEqualTo(mockEmbeddingModel);
+    }
+
+    @Test
+    void testGetPythonResourceWithNullEmbeddingModel() {
+        Object wrapped = newConnectionWithoutModel().getPythonResource();
+
+        assertThat(wrapped).isNull();
+    }
+
+    @Test
+    void testEmbedSingleText() {
+        String input = "test embedding text";
+        Map<String, Object> params = new HashMap<>();
+        params.put("temperature", 0.5);
+        stubEmbedResult(Arrays.asList(0.1, 0.2, 0.3, 0.4));
+
+        float[] embedding = pythonEmbeddingModelConnection.embed(input, params);
+
+        assertThat(embedding).isNotNull();
+        assertThat(embedding).hasSize(4);
+        assertThat(embedding[0]).isEqualTo(0.1f, TOLERANCE);
+        assertThat(embedding[1]).isEqualTo(0.2f, TOLERANCE);
+        assertThat(embedding[2]).isEqualTo(0.3f, TOLERANCE);
+        assertThat(embedding[3]).isEqualTo(0.4f, TOLERANCE);
+
+        // The text and the caller's parameters must both reach Python as keyword arguments
+        verify(mockAdapter)
+                .callMethod(
+                        eq(mockEmbeddingModel),
+                        eq("embed"),
+                        argThat(
+                                kwargs -> {
+                                    assertThat(kwargs).containsKey("text");
+                                    assertThat(kwargs).containsKey("temperature");
+                                    assertThat(kwargs.get("text")).isEqualTo(input);
+                                    assertThat(kwargs.get("temperature")).isEqualTo(0.5);
+                                    return true;
+                                }));
+    }
+
+    @Test
+    void testEmbedSingleTextWithEmptyParameters() {
+        stubEmbedResult(Arrays.asList(1.0, 2.0));
+
+        float[] embedding = pythonEmbeddingModelConnection.embed("test text", new HashMap<>());
+
+        assertThat(embedding).isNotNull();
+        assertThat(embedding).hasSize(2);
+    }
+
+    @Test
+    void testEmbedSingleTextWithNullEmbeddingModelThrowsException() {
+        PythonEmbeddingModelConnection connection = newConnectionWithoutModel();
+        String input = "test text";
+        Map<String, Object> params = new HashMap<>();
+
+        assertThatThrownBy(() -> connection.embed(input, params))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("EmbeddingModelSetup is not initialized")
+                .hasMessageContaining("Cannot perform embed operation");
+    }
+
+    @Test
+    void testEmbedSingleTextWithNonListResultThrowsException() {
+        String input = "test text";
+        Map<String, Object> params = new HashMap<>();
+        stubEmbedResult("invalid result");
+
+        assertThatThrownBy(() -> pythonEmbeddingModelConnection.embed(input, params))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Expected List from Python embed method")
+                .hasMessageContaining("java.lang.String");
+    }
+
+    @Test
+    void testEmbedMultipleTexts() {
+        List<String> inputs = Arrays.asList("text1", "text2", "text3");
+        Map<String, Object> params = new HashMap<>();
+        params.put("batch_size", 3);
+
+        List<List<Double>> pythonRows =
+                Arrays.asList(
+                        Arrays.asList(0.1, 0.2), Arrays.asList(0.3, 0.4), Arrays.asList(0.5, 0.6));
+        stubEmbedResult(pythonRows);
+
+        List<float[]> embeddings = pythonEmbeddingModelConnection.embed(inputs, params);
+
+        assertThat(embeddings).isNotNull();
+        assertThat(embeddings).hasSize(3);
+        assertThat(embeddings.get(0)).hasSize(2);
+        assertThat(embeddings.get(0)[0]).isEqualTo(0.1f, TOLERANCE);
+        assertThat(embeddings.get(1)[0]).isEqualTo(0.3f, TOLERANCE);
+        assertThat(embeddings.get(2)[0]).isEqualTo(0.5f, TOLERANCE);
+
+        // The whole list is forwarded under the same "text" keyword used for single inputs
+        verify(mockAdapter)
+                .callMethod(
+                        eq(mockEmbeddingModel),
+                        eq("embed"),
+                        argThat(
+                                kwargs -> {
+                                    assertThat(kwargs).containsKey("text");
+                                    assertThat(kwargs).containsKey("batch_size");
+                                    assertThat(kwargs.get("text")).isEqualTo(inputs);
+                                    assertThat(kwargs.get("batch_size")).isEqualTo(3);
+                                    return true;
+                                }));
+    }
+
+    @Test
+    void testEmbedMultipleTextsWithNullEmbeddingModelThrowsException() {
+        PythonEmbeddingModelConnection connection = newConnectionWithoutModel();
+        List<String> inputs = Arrays.asList("text1", "text2");
+        Map<String, Object> params = new HashMap<>();
+
+        assertThatThrownBy(() -> connection.embed(inputs, params))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("EmbeddingModelSetup is not initialized")
+                .hasMessageContaining("Cannot perform embed operation");
+    }
+
+    @Test
+    void testEmbedMultipleTextsWithNonListResultThrowsException() {
+        List<String> inputs = Arrays.asList("text1", "text2");
+        Map<String, Object> params = new HashMap<>();
+        stubEmbedResult("invalid result");
+
+        assertThatThrownBy(() -> pythonEmbeddingModelConnection.embed(inputs, params))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Expected List from Python embed method")
+                .hasMessageContaining("java.lang.String");
+    }
+
+    @Test
+    void testEmbedMultipleTextsWithNonListElementThrowsException() {
+        List<String> inputs = Arrays.asList("text1", "text2");
+        Map<String, Object> params = new HashMap<>();
+
+        // First row is valid, second row is a String instead of a list
+        List<Object> pythonRows = new ArrayList<>();
+        pythonRows.add(Arrays.asList(0.1, 0.2));
+        pythonRows.add("invalid element");
+        stubEmbedResult(pythonRows);
+
+        assertThatThrownBy(() -> pythonEmbeddingModelConnection.embed(inputs, params))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Expected List value in embedding results")
+                .hasMessageContaining("java.lang.String");
+    }
+
+    @Test
+    void testInheritanceFromBaseEmbeddingModelConnection() {
+        Class<?> baseType =
+                org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection.class;
+        assertThat(pythonEmbeddingModelConnection).isInstanceOf(baseType);
+    }
+
+    @Test
+    void testImplementsPythonResourceWrapper() {
+        Class<?> wrapperType =
+                org.apache.flink.agents.api.resource.python.PythonResourceWrapper.class;
+        assertThat(pythonEmbeddingModelConnection).isInstanceOf(wrapperType);
+    }
+}
diff --git
a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java
new file mode 100644
index 0000000..cedd2a8
--- /dev/null
+++
b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.embedding.model.python;
+
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.resource.python.PythonResourceAdapter;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import pemja.core.object.PyObject;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class PythonEmbeddingModelSetupTest {
+ @Mock private PythonResourceAdapter mockAdapter;
+
+ @Mock private PyObject mockEmbeddingModelSetup;
+
+ @Mock private ResourceDescriptor mockDescriptor;
+
+ @Mock private BiFunction<String, ResourceType, Resource> mockGetResource;
+
+ private PythonEmbeddingModelSetup pythonEmbeddingModelSetup;
+ private AutoCloseable mocks;
+
+ @BeforeEach
+ void setUp() throws Exception {
+ mocks = MockitoAnnotations.openMocks(this);
+ pythonEmbeddingModelSetup =
+ new PythonEmbeddingModelSetup(
+ mockAdapter, mockEmbeddingModelSetup, mockDescriptor,
mockGetResource);
+ }
+
+ @AfterEach
+ void tearDown() throws Exception {
+ if (mocks != null) {
+ mocks.close();
+ }
+ }
+
+ @Test
+ void testConstructor() {
+ assertThat(pythonEmbeddingModelSetup).isNotNull();
+ assertThat(pythonEmbeddingModelSetup.getPythonResource())
+ .isEqualTo(mockEmbeddingModelSetup);
+ }
+
+ @Test
+ void testGetPythonResourceWithNullEmbeddingModelSetup() {
+ PythonEmbeddingModelSetup setupWithNullModel =
+ new PythonEmbeddingModelSetup(mockAdapter, null,
mockDescriptor, mockGetResource);
+
+ Object result = setupWithNullModel.getPythonResource();
+
+ assertThat(result).isNull();
+ }
+
+ @Test
+ void testGetParameters() {
+ Map<String, Object> result = pythonEmbeddingModelSetup.getParameters();
+
+ assertThat(result).isNotNull();
+ assertThat(result).isEmpty();
+ }
+
+ @Test
+ void testEmbedSingleText() {
+ String text = "test embedding text";
+ Map<String, Object> parameters = new HashMap<>();
+ parameters.put("temperature", 0.5);
+
+ List<Double> pythonResult = Arrays.asList(0.1, 0.2, 0.3, 0.4);
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn(pythonResult);
+
+ float[] result = pythonEmbeddingModelSetup.embed(text, parameters);
+
+ assertThat(result).isNotNull();
+ assertThat(result).hasSize(4);
+ assertThat(result[0]).isEqualTo(0.1f,
org.assertj.core.data.Offset.offset(0.0001f));
+ assertThat(result[1]).isEqualTo(0.2f,
org.assertj.core.data.Offset.offset(0.0001f));
+ assertThat(result[2]).isEqualTo(0.3f,
org.assertj.core.data.Offset.offset(0.0001f));
+ assertThat(result[3]).isEqualTo(0.4f,
org.assertj.core.data.Offset.offset(0.0001f));
+
+ verify(mockAdapter)
+ .callMethod(
+ eq(mockEmbeddingModelSetup),
+ eq("embed"),
+ argThat(
+ kwargs -> {
+ assertThat(kwargs).containsKey("text");
+
assertThat(kwargs).containsKey("temperature");
+
assertThat(kwargs.get("text")).isEqualTo(text);
+
assertThat(kwargs.get("temperature")).isEqualTo(0.5);
+ return true;
+ }));
+ }
+
+ @Test
+ void testEmbedSingleTextWithEmptyParameters() {
+ String text = "test text";
+ Map<String, Object> parameters = new HashMap<>();
+
+ List<Double> pythonResult = Arrays.asList(1.0, 2.0);
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn(pythonResult);
+
+ float[] result = pythonEmbeddingModelSetup.embed(text, parameters);
+
+ assertThat(result).isNotNull();
+ assertThat(result).hasSize(2);
+ }
+
+ @Test
+ void testEmbedSingleTextWithNullEmbeddingModelSetupThrowsException() {
+ PythonEmbeddingModelSetup setupWithNullModel =
+ new PythonEmbeddingModelSetup(mockAdapter, null,
mockDescriptor, mockGetResource);
+
+ String text = "test text";
+ Map<String, Object> parameters = new HashMap<>();
+
+ assertThatThrownBy(() -> setupWithNullModel.embed(text, parameters))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessageContaining("EmbeddingModelSetup is not initialized")
+ .hasMessageContaining("Cannot perform embed operation");
+ }
+
+ @Test
+ void testEmbedSingleTextWithNonListResultThrowsException() {
+ String text = "test text";
+ Map<String, Object> parameters = new HashMap<>();
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn("invalid result");
+
+ assertThatThrownBy(() -> pythonEmbeddingModelSetup.embed(text,
parameters))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Expected List from Python embed method")
+ .hasMessageContaining("java.lang.String");
+ }
+
+ @Test
+ void testEmbedMultipleTexts() {
+ List<String> texts = Arrays.asList("text1", "text2", "text3");
+ Map<String, Object> parameters = new HashMap<>();
+ parameters.put("batch_size", 3);
+
+ List<List<Double>> pythonResult = new ArrayList<>();
+ pythonResult.add(Arrays.asList(0.1, 0.2));
+ pythonResult.add(Arrays.asList(0.3, 0.4));
+ pythonResult.add(Arrays.asList(0.5, 0.6));
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn(pythonResult);
+
+ List<float[]> result = pythonEmbeddingModelSetup.embed(texts,
parameters);
+
+ assertThat(result).isNotNull();
+ assertThat(result).hasSize(3);
+ assertThat(result.get(0)).hasSize(2);
+ assertThat(result.get(0)[0]).isEqualTo(0.1f,
org.assertj.core.data.Offset.offset(0.0001f));
+ assertThat(result.get(1)[0]).isEqualTo(0.3f,
org.assertj.core.data.Offset.offset(0.0001f));
+ assertThat(result.get(2)[0]).isEqualTo(0.5f,
org.assertj.core.data.Offset.offset(0.0001f));
+
+ verify(mockAdapter)
+ .callMethod(
+ eq(mockEmbeddingModelSetup),
+ eq("embed"),
+ argThat(
+ kwargs -> {
+ assertThat(kwargs).containsKey("text");
+
assertThat(kwargs).containsKey("batch_size");
+
assertThat(kwargs.get("text")).isEqualTo(texts);
+
assertThat(kwargs.get("batch_size")).isEqualTo(3);
+ return true;
+ }));
+ }
+
+ @Test
+ void testEmbedMultipleTextsWithNullEmbeddingModelSetupThrowsException() {
+ PythonEmbeddingModelSetup setupWithNullModel =
+ new PythonEmbeddingModelSetup(mockAdapter, null,
mockDescriptor, mockGetResource);
+
+ List<String> texts = Arrays.asList("text1", "text2");
+ Map<String, Object> parameters = new HashMap<>();
+
+ assertThatThrownBy(() -> setupWithNullModel.embed(texts, parameters))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessageContaining("EmbeddingModelSetup is not initialized")
+ .hasMessageContaining("Cannot perform embed operation");
+ }
+
+ @Test
+ void testEmbedMultipleTextsWithNonListResultThrowsException() {
+ List<String> texts = Arrays.asList("text1", "text2");
+ Map<String, Object> parameters = new HashMap<>();
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn("invalid result");
+
+ assertThatThrownBy(() -> pythonEmbeddingModelSetup.embed(texts,
parameters))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Expected List from Python embed method")
+ .hasMessageContaining("java.lang.String");
+ }
+
+ @Test
+ void testEmbedMultipleTextsWithNonListElementThrowsException() {
+ List<String> texts = Arrays.asList("text1", "text2");
+ Map<String, Object> parameters = new HashMap<>();
+
+ List<Object> pythonResult = new ArrayList<>();
+ pythonResult.add(Arrays.asList(0.1, 0.2));
+ pythonResult.add("invalid element");
+
+ when(mockAdapter.callMethod(eq(mockEmbeddingModelSetup), eq("embed"),
any(Map.class)))
+ .thenReturn(pythonResult);
+
+ assertThatThrownBy(() -> pythonEmbeddingModelSetup.embed(texts,
parameters))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Expected List value in embedding
results")
+ .hasMessageContaining("java.lang.String");
+ }
+
+ @Test
+ void testInheritanceFromBaseEmbeddingModelSetup() {
+ assertThat(pythonEmbeddingModelSetup)
+ .isInstanceOf(
+
org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup.class);
+ }
+
+ @Test
+ void testImplementsPythonResourceWrapper() {
+ assertThat(pythonEmbeddingModelSetup)
+ .isInstanceOf(
+
org.apache.flink.agents.api.resource.python.PythonResourceWrapper.class);
+ }
+}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
index 473b4bb..7ec8475 100644
---
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
+++
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
@@ -31,10 +31,9 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
-import java.util.Objects;
-import java.util.concurrent.TimeUnit;
import static
org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent.OLLAMA_MODEL;
+import static
org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel;
public class ChatModelCrossLanguageTest {
private static final Logger LOG =
LoggerFactory.getLogger(ChatModelCrossLanguageTest.class);
@@ -105,22 +104,4 @@ public class ChatModelCrossLanguageTest {
}
}
}
-
- public static boolean pullModel(String model) throws IOException {
- String path =
- Objects.requireNonNull(
- ChatModelCrossLanguageTest.class
- .getClassLoader()
- .getResource("ollama_pull_model.sh"))
- .getPath();
- ProcessBuilder builder = new ProcessBuilder("bash", path, model);
- Process process = builder.start();
- try {
- process.waitFor(120, TimeUnit.SECONDS);
- return process.exitValue() == 0;
- } catch (Exception e) {
- LOG.warn("Pull {} failed, will skip test", model);
- }
- return false;
- }
}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageAgent.java
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageAgent.java
new file mode 100644
index 0000000..a0c6d89
--- /dev/null
+++
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageAgent.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.resource.test;
+
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.agents.api.InputEvent;
+import org.apache.flink.agents.api.OutputEvent;
+import org.apache.flink.agents.api.agents.Agent;
+import org.apache.flink.agents.api.annotation.Action;
+import org.apache.flink.agents.api.annotation.EmbeddingModelConnection;
+import org.apache.flink.agents.api.annotation.EmbeddingModelSetup;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup;
+import
org.apache.flink.agents.api.embedding.model.python.PythonEmbeddingModelConnection;
+import
org.apache.flink.agents.api.embedding.model.python.PythonEmbeddingModelSetup;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Integration test agent that exercises a Python embedding model from Java.
+ *
+ * <p>For every input event the agent:
+ *
+ * <ul>
+ *   <li>resolves the {@code embeddingModel} resource, which is declared below as a Python Ollama
+ *       embedding model setup on top of a Python Ollama connection;
+ *   <li>embeds the input text once as a single string and once as a one-element batch;
+ *   <li>checks that every returned vector is non-empty and that the batch size matches;
+ *   <li>emits an {@link OutputEvent} whose {@code test_status} is {@code PASSED} or {@code
+ *       FAILED}.
+ * </ul>
+ *
+ * <p>Used for e2e testing of the embedding model subsystem.
+ */
+public class EmbeddingCrossLanguageAgent extends Agent {
+    public static final String OLLAMA_MODEL = "nomic-embed-text";
+
+    /** Python module that provides both the Ollama connection and the Ollama setup classes. */
+    private static final String PYTHON_OLLAMA_MODULE =
+            "flink_agents.integrations.embedding_models.local.ollama_embedding_model";
+
+    // Configured exactly once: a shared ObjectMapper should not be reconfigured after it has been
+    // used, and the action below may run many times.
+    private static final ObjectMapper MAPPER =
+            new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+    /** Declares the Python-backed Ollama embedding connection. */
+    @EmbeddingModelConnection
+    public static ResourceDescriptor embeddingConnection() {
+        return ResourceDescriptor.Builder.newBuilder(PythonEmbeddingModelConnection.class.getName())
+                .addInitialArgument("module", PYTHON_OLLAMA_MODULE)
+                .addInitialArgument("clazz", "OllamaEmbeddingModelConnection")
+                .build();
+    }
+
+    /** Declares the Python-backed Ollama embedding setup that uses {@link #embeddingConnection}. */
+    @EmbeddingModelSetup
+    public static ResourceDescriptor embeddingModel() {
+        return ResourceDescriptor.Builder.newBuilder(PythonEmbeddingModelSetup.class.getName())
+                .addInitialArgument("module", PYTHON_OLLAMA_MODULE)
+                .addInitialArgument("clazz", "OllamaEmbeddingModelSetup")
+                .addInitialArgument("connection", "embeddingConnection")
+                .addInitialArgument("model", OLLAMA_MODEL)
+                .build();
+    }
+
+    /**
+     * Main test action that processes input and validates embedding generation.
+     *
+     * <p>The input is either a JSON object with a {@code text} and an optional {@code id} field,
+     * or any other string, which is then used as the text itself.
+     *
+     * @param event the input event; its payload is cast to {@link String}
+     * @param ctx the runner context used to access memory, resources and to send events
+     * @throws AssertionError if the input contains no usable text or a validation check fails
+     * @throws Exception if embedding generation fails; the failure is re-thrown after a {@code
+     *     FAILED} output event has been sent
+     */
+    @Action(listenEvents = {InputEvent.class})
+    public static void testEmbeddingGeneration(InputEvent event, RunnerContext ctx)
+            throws Exception {
+        String input = (String) event.getInput();
+        Map<String, Object> inputData = parseInput(input);
+
+        // Tolerate non-string JSON values instead of failing with a ClassCastException.
+        Object rawText = inputData.get("text");
+        String text = rawText instanceof String ? (String) rawText : null;
+        Object rawId = inputData.get("id");
+        String id = rawId != null ? String.valueOf(rawId) : newTestId();
+
+        if (text == null || text.trim().isEmpty()) {
+            throw new AssertionError("Test input must contain valid text");
+        }
+
+        // Store test data in memory
+        ctx.getShortTermMemory().set("test_id", id);
+        ctx.getShortTermMemory().set("test_text", text);
+
+        try {
+            // Generate embedding using Ollama
+            BaseEmbeddingModelSetup embeddingModel =
+                    (BaseEmbeddingModelSetup)
+                            ctx.getResource(
+                                    "embeddingModel",
+                                    org.apache.flink.agents.api.resource.ResourceType
+                                            .EMBEDDING_MODEL);
+
+            float[] embedding = embeddingModel.embed(text);
+            System.out.printf("[TEST] Generated embedding with dimension: %d%n", embedding.length);
+            validateEmbeddingResult(id, text, embedding);
+
+            List<float[]> embeddings = embeddingModel.embed(List.of(text));
+            validateEmbeddingResults(id, List.of(text), embeddings);
+
+            // Create a minimal test result to avoid serialization issues
+            Map<String, Object> testResult = new HashMap<>();
+            testResult.put("test_status", "PASSED");
+            testResult.put("id", id);
+
+            ctx.sendEvent(new OutputEvent(testResult));
+
+            System.out.printf(
+                    "[TEST] Embedding generation test PASSED for: '%s'%n",
+                    text.substring(0, Math.min(50, text.length())));
+
+        } catch (Exception | AssertionError e) {
+            // The validators signal failures with AssertionError, which is not an Exception;
+            // catch it as well so that validation failures are reported as FAILED too.
+            Map<String, Object> testResult = new HashMap<>();
+            testResult.put("test_status", "FAILED");
+            testResult.put("error", e.getMessage());
+            testResult.put("id", id);
+
+            ctx.sendEvent(new OutputEvent(testResult));
+
+            System.err.printf("[TEST] Embedding generation test FAILED: %s%n", e.getMessage());
+            throw e; // Re-throw for test failure reporting
+        }
+    }
+
+    /**
+     * Parses the raw input as a JSON object, falling back to treating the whole input as the text
+     * when it is not a JSON object (including the JSON literal {@code null}).
+     */
+    private static Map<String, Object> parseInput(String input) {
+        try {
+            @SuppressWarnings("unchecked")
+            Map<String, Object> parsed = MAPPER.readValue(input, Map.class);
+            if (parsed != null) {
+                return parsed;
+            }
+        } catch (Exception ignored) {
+            // Not a JSON object: plain prompts are an expected input form, handled below.
+        }
+        Map<String, Object> fallback = new HashMap<>();
+        fallback.put("text", input);
+        fallback.put("id", newTestId());
+        return fallback;
+    }
+
+    /** Generates a time-based identifier for inputs that do not carry their own id. */
+    private static String newTestId() {
+        return "test_doc_" + System.currentTimeMillis();
+    }
+
+    /**
+     * Validate embedding result.
+     *
+     * @param id identifier of the test input, used for logging only
+     * @param text the embedded text, used for logging only
+     * @param embedding the vector to check
+     * @throws AssertionError if {@code embedding} is {@code null} or empty
+     */
+    public static void validateEmbeddingResult(String id, String text, float[] embedding) {
+
+        // Validation assertions for testing
+        if (embedding == null || embedding.length == 0) {
+            throw new AssertionError("Embedding cannot be null or empty");
+        }
+
+        System.out.printf(
+                "[TEST] Validated embedding: ID=%s, Dimension=%d, Text='%s...'%n",
+                id, embedding.length, text.substring(0, Math.min(30, text.length())));
+    }
+
+    /**
+     * Validate embedding results.
+     *
+     * @param id identifier of the test input, used for logging only
+     * @param texts the embedded texts
+     * @param embeddings the vectors to check, expected to be one per text
+     * @throws AssertionError if {@code embeddings} is {@code null} or empty, if the two lists
+     *     differ in size, or if any single vector is {@code null} or empty
+     */
+    public static void validateEmbeddingResults(
+            String id, List<String> texts, List<float[]> embeddings) {
+
+        // Validation assertions for testing
+        if (embeddings == null || embeddings.isEmpty()) {
+            throw new AssertionError("Embedding cannot be null or empty");
+        }
+
+        if (texts.size() != embeddings.size()) {
+            throw new AssertionError("Text and embedding lists must have the same size");
+        }
+
+        for (int i = 0; i < texts.size(); i++) {
+            validateEmbeddingResult(id, texts.get(i), embeddings.get(i));
+        }
+    }
+}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java
new file mode 100644
index 0000000..8806ac9
--- /dev/null
+++
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.resource.test;
+
+import org.apache.flink.agents.api.AgentsExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.util.CloseableIterator;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static
org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel;
+
+/**
+ * End-to-end test that applies {@link EmbeddingCrossLanguageAgent} to a DataStream of prompts and
+ * expects one {@code PASSED} output per prompt.
+ *
+ * <p>The test is skipped when the Ollama embedding model cannot be pulled.
+ */
+public class EmbeddingCrossLanguageTest {
+
+    /** Number of prompts fed into the pipeline, and therefore of expected outputs. */
+    private static final int EXPECTED_OUTPUT_COUNT = 10;
+
+    private final boolean ollamaReady;
+
+    public EmbeddingCrossLanguageTest() throws IOException {
+        ollamaReady = pullModel(EmbeddingCrossLanguageAgent.OLLAMA_MODEL);
+    }
+
+    @Test
+    public void testEmbeddingIntegration() throws Exception {
+        Assumptions.assumeTrue(
+                ollamaReady,
+                "Ollama model "
+                        + EmbeddingCrossLanguageAgent.OLLAMA_MODEL
+                        + " could not be pulled, skipping test");
+
+        // Create the execution environment
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        // The agent embeds each prompt verbatim; the wording only provides varied input text.
+        DataStream<String> inputStream =
+                env.fromData(
+                        "Generate embedding for: 'Machine learning'",
+                        "Generate embedding for: 'Deep learning techniques'",
+                        "Find texts similar to: 'neural networks'",
+                        "Produce embedding and return top-3 similar items for: 'natural language processing'",
+                        "Generate embedding for: 'hello world'",
+                        "Compare similarity between 'cat' and 'dog'",
+                        "Create embedding for: 'space exploration'",
+                        "Find nearest neighbors for: 'artificial intelligence'",
+                        "Generate embedding for: 'data science'",
+                        "Random embedding test");
+
+        // Create agents execution environment
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(env);
+
+        // Apply agent to the DataStream and use the prompt itself as the key
+        DataStream<Object> outputStream =
+                agentsEnv
+                        .fromDataStream(inputStream, (KeySelector<String, String>) value -> value)
+                        .apply(new EmbeddingCrossLanguageAgent())
+                        .toDataStream();
+
+        // Collect the results; the iterator is closed even if execution or an assertion fails.
+        try (CloseableIterator<Object> results = outputStream.collectAsync()) {
+            // Execute the pipeline
+            agentsEnv.execute();
+
+            checkResult(results);
+        }
+    }
+
+    /**
+     * Asserts that the pipeline produced {@link #EXPECTED_OUTPUT_COUNT} outputs and that each of
+     * them reports {@code PASSED}.
+     */
+    @SuppressWarnings("unchecked")
+    private void checkResult(CloseableIterator<Object> results) {
+        for (int received = 0; received < EXPECTED_OUTPUT_COUNT; received++) {
+            Assertions.assertTrue(
+                    results.hasNext(),
+                    String.format(
+                            "Output messages count %d is less than expected %d.",
+                            received, EXPECTED_OUTPUT_COUNT));
+            Map<String, Object> res = (Map<String, Object>) results.next();
+            Assertions.assertEquals("PASSED", res.get("test_status"));
+        }
+    }
+}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java
new file mode 100644
index 0000000..a42809e
--- /dev/null
+++
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.resource.test;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+public class OllamaPreparationUtils {
+ private static final Logger LOG =
LoggerFactory.getLogger(OllamaPreparationUtils.class);
+
+ public static boolean pullModel(String model) throws IOException {
+ String path =
+ Objects.requireNonNull(
+ OllamaPreparationUtils.class
+ .getClassLoader()
+ .getResource("ollama_pull_model.sh"))
+ .getPath();
+ ProcessBuilder builder = new ProcessBuilder("bash", path, model);
+ Process process = builder.start();
+ try {
+ process.waitFor(120, TimeUnit.SECONDS);
+ return process.exitValue() == 0;
+ } catch (Exception e) {
+ LOG.warn("Pull {} failed, will skip test", model);
+ }
+ return false;
+ }
+}
diff --git
a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java
b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java
index cc96e17..6ada023 100644
---
a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java
+++
b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java
@@ -23,6 +23,7 @@ import io.github.ollama4j.exceptions.OllamaException;
import io.github.ollama4j.models.embed.OllamaEmbedRequest;
import io.github.ollama4j.models.embed.OllamaEmbedResult;
import
org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection;
+import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
@@ -74,13 +75,7 @@ public class OllamaEmbeddingModelConnection extends
BaseEmbeddingModelConnection
List<Double> embedding = embeddings.get(0);
- // Convert to float array
- float[] result = new float[embedding.size()];
- for (int i = 0; i < embedding.size(); i++) {
- result[i] = embedding.get(i).floatValue();
- }
-
- return result;
+ return EmbeddingModelUtils.toFloatArray(embedding);
} catch (OllamaException e) {
throw new RuntimeException("Error generating embeddings for text:
" + text, e);
@@ -110,10 +105,7 @@ public class OllamaEmbeddingModelConnection extends
BaseEmbeddingModelConnection
// Convert to float arrays
List<float[]> results = new ArrayList<>();
for (List<Double> embedding : embeddings) {
- float[] result = new float[embedding.size()];
- for (int i = 0; i < embedding.size(); i++) {
- result[i] = embedding.get(i).floatValue();
- }
+ float[] result = EmbeddingModelUtils.toFloatArray(embedding);
results.add(result);
}
diff --git a/python/flink_agents/api/embedding_models/embedding_model.py
b/python/flink_agents/api/embedding_models/embedding_model.py
index c5384e9..2ec369a 100644
--- a/python/flink_agents/api/embedding_models/embedding_model.py
+++ b/python/flink_agents/api/embedding_models/embedding_model.py
@@ -16,7 +16,7 @@
# limitations under the License.
#################################################################################
from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Dict, Sequence
from pydantic import Field
from typing_extensions import override
@@ -45,7 +45,7 @@ class BaseEmbeddingModelConnection(Resource, ABC):
return ResourceType.EMBEDDING_MODEL_CONNECTION
@abstractmethod
- def embed(self, text: str, **kwargs: Any) -> list[float]:
+ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] |
list[list[float]]:
"""Generate embedding vector for a single text input.
Converts the input text into a high-dimensional vector representation
diff --git
a/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py
b/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py
index 1829703..174fd7f 100644
---
a/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py
+++
b/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from typing import Any, Dict
+from typing import Any, Dict, Sequence
from ollama import Client
from pydantic import Field
@@ -77,7 +77,7 @@ class
OllamaEmbeddingModelConnection(BaseEmbeddingModelConnection):
self.__client = Client(host=self.base_url,
timeout=self.request_timeout)
return self.__client
- def embed(self, text: str, **kwargs: Any) -> list[float]:
+ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] |
list[list[float]]:
"""Generate embedding vector for a single text query."""
# Extract specific parameters
model = kwargs.pop("model")
@@ -92,7 +92,9 @@ class
OllamaEmbeddingModelConnection(BaseEmbeddingModelConnection):
keep_alive=keep_alive,
options=kwargs,
)
- return list(response.embeddings[0])
+
+ embeddings = [list(embedding) for embedding in response.embeddings]
+ return embeddings[0] if isinstance(text, str) else embeddings
class OllamaEmbeddingModelSetup(BaseEmbeddingModelSetup):
diff --git
a/python/flink_agents/integrations/embedding_models/openai_embedding_model.py
b/python/flink_agents/integrations/embedding_models/openai_embedding_model.py
index 30b8ee1..73cf65f 100644
---
a/python/flink_agents/integrations/embedding_models/openai_embedding_model.py
+++
b/python/flink_agents/integrations/embedding_models/openai_embedding_model.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from typing import Any, Dict
+from typing import Any, Dict, Sequence
from openai import NOT_GIVEN, OpenAI
from pydantic import Field
@@ -110,7 +110,7 @@ class
OpenAIEmbeddingModelConnection(BaseEmbeddingModelConnection):
)
return self.__client
- def embed(self, text: str, **kwargs: Any) -> list[float]:
+ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] |
list[list[float]]:
"""Generate embedding vector for a single text query."""
# Extract OpenAI specific parameters
model = kwargs.pop("model")
@@ -127,7 +127,8 @@ class
OpenAIEmbeddingModelConnection(BaseEmbeddingModelConnection):
user=user if user is not None else NOT_GIVEN,
)
- return list(response.data[0].embedding)
+ embeddings = [list(embedding.embedding) for embedding in response.data]
+ return embeddings[0] if isinstance(text, str) else embeddings
class OpenAIEmbeddingModelSetup(BaseEmbeddingModelSetup):
diff --git a/python/flink_agents/plan/tests/test_agent_plan.py
b/python/flink_agents/plan/tests/test_agent_plan.py
index 0b28de6..1657a02 100644
--- a/python/flink_agents/plan/tests/test_agent_plan.py
+++ b/python/flink_agents/plan/tests/test_agent_plan.py
@@ -110,9 +110,11 @@ class MockChatModelImpl(BaseChatModelSetup): # noqa: D101
class MockEmbeddingModelConnection(BaseEmbeddingModelConnection): # noqa: D101
api_key: str
- def embed(self, text: str, **kwargs: Any) -> list[float]:
+ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float]:
"""Testing Implementation."""
- return [0.1234, -0.5678, 0.9012, -0.3456, 0.7890]
+ if isinstance(text, str):
+ return [0.1234, -0.5678, 0.9012, -0.3456, 0.7890]
+ return [[0.1234, -0.5678, 0.9012, -0.3456, 0.7890]]
class MockEmbeddingModelSetup(BaseEmbeddingModelSetup): # noqa: D101