This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new e7f2adb [Runtime] Simplify the vector store setup process (#231)
e7f2adb is described below
commit e7f2adbcad4df12bcfe86960a861e5c83ff18eb8
Author: Alan Z. <[email protected]>
AuthorDate: Sun Sep 28 05:31:13 2025 -0700
[Runtime] Simplify the vector store setup process (#231)
---
.../flink/agents/api/resource/ResourceType.java | 1 -
python/flink_agents/api/decorators.py | 24 +----
python/flink_agents/api/resource.py | 1 -
.../flink_agents/api/vector_stores/vector_store.py | 61 ++++-------
.../vector_stores/chroma/chroma_vector_store.py | 116 +++++++++------------
.../chroma/tests/test_chroma_vector_store.py | 46 ++++----
python/flink_agents/plan/agent_plan.py | 4 +-
.../plan/tests/resources/agent_plan.json | 18 +---
python/flink_agents/plan/tests/test_agent_plan.py | 48 +++------
9 files changed, 111 insertions(+), 208 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceType.java
b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceType.java
index 700f484..333e12f 100644
--- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceType.java
+++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceType.java
@@ -29,7 +29,6 @@ public enum ResourceType {
EMBEDDING_MODEL("embedding_model"),
EMBEDDING_MODEL_CONNECTION("embedding_model_connection"),
VECTOR_STORE("vector_store"),
- VECTOR_STORE_CONNECTION("vector_store_connection"),
PROMPT("prompt"),
TOOL("tool");
diff --git a/python/flink_agents/api/decorators.py
b/python/flink_agents/api/decorators.py
index fe6ecb7..b45e55a 100644
--- a/python/flink_agents/api/decorators.py
+++ b/python/flink_agents/api/decorators.py
@@ -173,26 +173,8 @@ def mcp_server(func: Callable) -> Callable:
return func
-def vector_store_connection(func: Callable) -> Callable:
- """Decorator for marking a function declaring a vector store connection.
-
- Parameters
- ----------
- func : Callable
- Function to be decorated.
-
- Returns:
- -------
- Callable
- Decorator function that marks the target function declare a vector
store
- connection.
- """
- func._is_vector_store_connection = True
- return func
-
-
-def vector_store_setup(func: Callable) -> Callable:
- """Decorator for marking a function declaring a vector store setup.
+def vector_store(func: Callable) -> Callable:
+ """Decorator for marking a function declaring a vector store.
Parameters
----------
@@ -204,5 +186,5 @@ def vector_store_setup(func: Callable) -> Callable:
Callable
Decorator function that marks the target function declare a vector
store.
"""
- func._is_vector_store_setup = True
+ func._is_vector_store = True
return func
diff --git a/python/flink_agents/api/resource.py
b/python/flink_agents/api/resource.py
index 5659e58..f50f76c 100644
--- a/python/flink_agents/api/resource.py
+++ b/python/flink_agents/api/resource.py
@@ -35,7 +35,6 @@ class ResourceType(Enum):
EMBEDDING_MODEL = "embedding_model"
EMBEDDING_MODEL_CONNECTION = "embedding_model_connection"
VECTOR_STORE = "vector_store"
- VECTOR_STORE_CONNECTION = "vector_store_connection"
PROMPT = "prompt"
MCP_SERVER = "mcp_server"
diff --git a/python/flink_agents/api/vector_stores/vector_store.py
b/python/flink_agents/api/vector_stores/vector_store.py
index 208a591..8feb2b2 100644
--- a/python/flink_agents/api/vector_stores/vector_store.py
+++ b/python/flink_agents/api/vector_stores/vector_store.py
@@ -110,43 +110,14 @@ class VectorStoreQueryResult(BaseModel):
return f"QueryResult: {len(self.documents)} documents"
-class BaseVectorStoreConnection(Resource, ABC):
- """Base abstract class for vector store connection.
+class BaseVectorStore(Resource, ABC):
+ """Base abstract class for vector store.
- Manages connection configuration and provides raw vector search operations
- using pre-computed embeddings. One connection can be shared across multiple
- vector store setups.
+ Provides vector store functionality that integrates embedding models
+ for text-based semantic search. Handles both connection management and
+ embedding generation internally.
"""
- @classmethod
- @override
- def resource_type(cls) -> ResourceType:
- """Return resource type of class."""
- return ResourceType.VECTOR_STORE_CONNECTION
-
- @abstractmethod
- def query(self, embedding: List[float], limit: int = 10, **kwargs: Any) ->
List[Document]:
- """Perform vector search using pre-computed embedding.
-
- Args:
- embedding: Pre-computed embedding vector for semantic search
- limit: Maximum number of results to return (default: 10)
- **kwargs: Vector store-specific parameters (filters, distance
metrics, etc.)
-
- Returns:
- List of documents matching the search criteria
- """
-
-
-class BaseVectorStoreSetup(Resource, ABC):
- """Base abstract class for vector store setup.
-
- Coordinates between vector store connections and embedding models to
provide
- text-based semantic search. Automatically converts text queries to
embeddings
- before delegating to the connection layer.
- """
-
- connection: str = Field(description="Name of the referenced connection.")
embedding_model: str = Field(description="Name of the embedding model
resource to use.")
@classmethod
@@ -181,19 +152,27 @@ class BaseVectorStoreSetup(Resource, ABC):
)
query_embedding = embedding_model.embed(query.query_text)
- # Get vector store connection resource
- connection = self.get_resource(
- self.connection, ResourceType.VECTOR_STORE_CONNECTION
- )
-
# Merge setup kwargs with query-specific args
merged_kwargs = self.store_kwargs.copy()
merged_kwargs.update(query.extra_args)
- # Perform vector search
- documents = connection.query(query_embedding, query.limit,
**merged_kwargs)
+ # Perform vector search using the abstract method
+ documents = self.query_embedding(query_embedding, query.limit,
**merged_kwargs)
# Return structured result
return VectorStoreQueryResult(
documents=documents
)
+
+ @abstractmethod
+ def query_embedding(self, embedding: List[float], limit: int = 10,
**kwargs: Any) -> List[Document]:
+ """Perform vector search using pre-computed embedding.
+
+ Args:
+ embedding: Pre-computed embedding vector for semantic search
+ limit: Maximum number of results to return (default: 10)
+ **kwargs: Vector store-specific parameters (filters, distance
metrics, etc.)
+
+ Returns:
+ List of documents matching the search criteria
+ """
diff --git
a/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py
b/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py
index 9c4fc33..6be8045 100644
---
a/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py
+++
b/python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py
@@ -24,16 +24,15 @@ from chromadb.config import Settings
from pydantic import Field
from flink_agents.api.vector_stores.vector_store import (
- BaseVectorStoreConnection,
- BaseVectorStoreSetup,
+ BaseVectorStore,
Document,
)
DEFAULT_COLLECTION = "flink_agents_chroma_collection"
-class ChromaVectorStoreConnection(BaseVectorStoreConnection):
- """ChromaDB Vector Store Connection which manages connection to ChromaDB.
+class ChromaVectorStore(BaseVectorStore):
+ """ChromaDB vector store that handles connection and semantic search.
Visit https://docs.trychroma.com/ for ChromaDB documentation.
@@ -59,8 +58,15 @@ class ChromaVectorStoreConnection(BaseVectorStoreConnection):
ChromaDB tenant for multi-tenancy support (default: "default_tenant").
database : str
ChromaDB database name (default: "default_database").
+ collection : str
+ Name of the ChromaDB collection to use (default:
flink_agents_chroma_collection).
+ collection_metadata : Dict[str, Any]
+ Metadata for the collection (optional).
+ create_collection_if_not_exists : bool
+ Whether to create the collection if it doesn't exist (default: True).
"""
+ # Connection configuration
persist_directory: str | None = Field(
default=None,
description="Directory for persistent storage. If None, uses in-memory
client.",
@@ -90,10 +96,26 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
description="ChromaDB database name.",
)
+ # Collection configuration
+ collection: str = Field(
+ default=DEFAULT_COLLECTION,
+ description="Name of the ChromaDB collection to use.",
+ )
+ collection_metadata: Dict[str, Any] = Field(
+ default_factory=dict,
+ description="Metadata for the collection.",
+ )
+ create_collection_if_not_exists: bool = Field(
+ default=True,
+ description="Whether to create the collection if it doesn't exist.",
+ )
+
__client: ChromaClient | None = None
def __init__(
self,
+ *,
+ embedding_model: str,
persist_directory: str | None = None,
host: str | None = None,
port: int | None = 8000,
@@ -101,10 +123,16 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
client_settings: Settings | None = None,
tenant: str = "default_tenant",
database: str = "default_database",
+ collection: str = DEFAULT_COLLECTION,
+ collection_metadata: Dict[str, Any] | None = None,
+ create_collection_if_not_exists: bool = True,
**kwargs: Any,
) -> None:
"""Init method."""
+ if collection_metadata is None:
+ collection_metadata = {}
super().__init__(
+ embedding_model=embedding_model,
persist_directory=persist_directory,
host=host,
port=port,
@@ -112,6 +140,9 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
client_settings=client_settings,
tenant=tenant,
database=database,
+ collection=collection,
+ collection_metadata=collection_metadata,
+ create_collection_if_not_exists=create_collection_if_not_exists,
**kwargs,
)
@@ -119,7 +150,6 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
def client(self) -> ChromaClient:
"""Return ChromaDB client, creating it if necessary."""
if self.__client is None:
-
# Choose client type based on configuration
if self.api_key is not None:
# Cloud mode
@@ -155,7 +185,16 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
return self.__client
- def query(self, embedding: List[float], limit: int = 10, **kwargs: Any) ->
List[Document]:
+ @property
+ def store_kwargs(self) -> Dict[str, Any]:
+ """Return ChromaDB-specific collection settings used as default query arguments."""
+ return {
+ "collection": self.collection,
+ "collection_metadata": self.collection_metadata,
+ "create_collection_if_not_exists":
self.create_collection_if_not_exists,
+ }
+
+ def query_embedding(self, embedding: List[float], limit: int = 10,
**kwargs: Any) -> List[Document]:
"""Perform vector search using pre-computed embedding.
Args:
@@ -167,9 +206,9 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
List of documents matching the search criteria
"""
# Extract ChromaDB-specific parameters
- collection_name = kwargs.get("collection", DEFAULT_COLLECTION)
- collection_metadata = kwargs.get("collection_metadata", {})
- create_collection_if_not_exists =
kwargs.get("create_collection_if_not_exists", True)
+ collection_name = kwargs.get("collection", self.collection)
+ collection_metadata = kwargs.get("collection_metadata",
self.collection_metadata)
+ create_collection_if_not_exists =
kwargs.get("create_collection_if_not_exists",
self.create_collection_if_not_exists)
where = kwargs.get("where") # Metadata filters
# Get or create collection based on configuration
@@ -206,62 +245,3 @@ class
ChromaVectorStoreConnection(BaseVectorStoreConnection):
return documents
-
-class ChromaVectorStoreSetup(BaseVectorStoreSetup):
- """ChromaDB vector store setup which manages collection configuration
- and coordinates with embedding models for semantic search.
-
- Attributes:
- ----------
- collection : str
- Name of the ChromaDB collection to use (default:
flink_agents_collection).
- collection_metadata : Dict[str, Any]
- Metadata for the collection (optional).
- create_collection_if_not_exists : bool
- Whether to create the collection if it doesn't exist (default: True).
- """
-
- collection: str = Field(
- default=DEFAULT_COLLECTION,
- description="Name of the ChromaDB collection to use.",
- )
- collection_metadata: Dict[str, Any] = Field(
- default_factory=dict,
- description="Metadata for the collection.",
- )
- create_collection_if_not_exists: bool = Field(
- default=True,
- description="Whether to create the collection if it doesn't exist.",
- )
-
- def __init__(
- self,
- *,
- connection: str,
- embedding_model: str,
- collection: str = DEFAULT_COLLECTION,
- collection_metadata: Dict[str, Any] | None = None,
- create_collection_if_not_exists: bool = True,
- **kwargs: Any,
- ) -> None:
- """Init method."""
- if collection_metadata is None:
- collection_metadata = {}
- super().__init__(
- connection=connection,
- embedding_model=embedding_model,
- collection=collection,
- collection_metadata=collection_metadata,
- create_collection_if_not_exists=create_collection_if_not_exists,
- **kwargs,
- )
-
- @property
- def store_kwargs(self) -> Dict[str, Any]:
- """Return ChromaDB-specific setup settings passed to connection."""
- return {
- "collection": self.collection,
- "collection_metadata": self.collection_metadata,
- "create_collection_if_not_exists":
self.create_collection_if_not_exists,
- }
-
diff --git
a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py
b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py
index 9f0b468..233c6b6 100644
---
a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py
+++
b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py
@@ -32,8 +32,7 @@ from flink_agents.api.vector_stores.vector_store import (
VectorStoreQuery,
)
from flink_agents.integrations.vector_stores.chroma.chroma_vector_store import
(
- ChromaVectorStoreConnection,
- ChromaVectorStoreSetup,
+ ChromaVectorStore,
)
api_key = os.environ.get("TEST_API_KEY")
@@ -54,9 +53,9 @@ class MockEmbeddingModel(Resource): # noqa: D101
return [0.1, 0.2, 0.3, 0.4, 0.5]
-def _populate_test_data(connection: ChromaVectorStoreConnection) -> None:
+def _populate_test_data(vector_store: ChromaVectorStore) -> None:
"""Private helper method to populate ChromaDB with test data."""
- collection = connection.client.get_or_create_collection(
+ collection = vector_store.client.get_or_create_collection(
name="test_collection",
metadata=None,
)
@@ -82,72 +81,67 @@ def _populate_test_data(connection:
ChromaVectorStoreConnection) -> None:
@pytest.mark.skipif(
not chromadb_available, reason="ChromaDB is not available"
)
-def test_local_chroma_vector_store_setup() -> None:
- """Test ChromaDB vector store setup with embedding model integration."""
- connection = ChromaVectorStoreConnection(name="chroma_conn")
+def test_local_chroma_vector_store() -> None:
+ """Test ChromaDB vector store with embedding model integration."""
embedding_model = MockEmbeddingModel(name="mock_embeddings")
def get_resource(name: str, resource_type: ResourceType) -> Resource:
- if resource_type == ResourceType.VECTOR_STORE_CONNECTION:
- return connection
- elif resource_type == ResourceType.EMBEDDING_MODEL:
+ if resource_type == ResourceType.EMBEDDING_MODEL:
return embedding_model
else:
msg = f"Unknown resource type: {resource_type}"
raise ValueError(msg)
- setup = ChromaVectorStoreSetup(
- name="chroma_setup",
- connection="chroma_conn",
+ vector_store = ChromaVectorStore(
+ name="chroma_vector_store",
embedding_model="mock_embeddings",
collection="test_collection",
get_resource=get_resource
)
- _populate_test_data(connection)
+ _populate_test_data(vector_store)
query = VectorStoreQuery(
query_text="What is Flink Agent?",
limit=1
)
- result = setup.query(query)
+ result = vector_store.query(query)
assert result is not None
assert len(result.documents) == 1
assert result.documents[0].id == "doc2"
@pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set")
-def test_cloud_chroma_vector_store_setup() -> None:
- """Test cloud ChromaDB vector store setup with embedding model
integration."""
- connection = ChromaVectorStoreConnection(name="cloud_chroma_conn",
api_key=api_key, tenant=tenant, database=database)
+def test_cloud_chroma_vector_store() -> None:
+ """Test cloud ChromaDB vector store with embedding model integration."""
embedding_model = MockEmbeddingModel(name="mock_embeddings")
def get_resource(name: str, resource_type: ResourceType) -> Resource:
- if resource_type == ResourceType.VECTOR_STORE_CONNECTION:
- return connection
- elif resource_type == ResourceType.EMBEDDING_MODEL:
+ if resource_type == ResourceType.EMBEDDING_MODEL:
return embedding_model
else:
msg = f"Unknown resource type: {resource_type}"
raise ValueError(msg)
- setup = ChromaVectorStoreSetup(
- name="chroma_setup",
- connection="cloud_chroma_conn",
+ vector_store = ChromaVectorStore(
+ name="chroma_vector_store",
embedding_model="mock_embeddings",
collection="test_collection",
+ api_key=api_key,
+ tenant=tenant,
+ database=database,
get_resource=get_resource
)
- _populate_test_data(connection)
+ _populate_test_data(vector_store)
query = VectorStoreQuery(
query_text="What is Flink Agent?",
limit=1
)
- result = setup.query(query)
+ result = vector_store.query(query)
assert result is not None
assert len(result.documents) == 1
assert result.documents[0].id == "doc2"
diff --git a/python/flink_agents/plan/agent_plan.py
b/python/flink_agents/plan/agent_plan.py
index 54f71c5..04abe3a 100644
--- a/python/flink_agents/plan/agent_plan.py
+++ b/python/flink_agents/plan/agent_plan.py
@@ -277,8 +277,7 @@ def _get_resource_providers(agent: Agent) ->
List[ResourceProvider]:
or hasattr(value, "_is_chat_model_connection")
or hasattr(value, "_is_embedding_model_setup")
or hasattr(value, "_is_embedding_model_connection")
- or hasattr(value, "_is_vector_store_setup")
- or hasattr(value, "_is_vector_store_connection")
+ or hasattr(value, "_is_vector_store")
):
if isinstance(value, staticmethod):
value = value.__func__
@@ -339,7 +338,6 @@ def _get_resource_providers(agent: Agent) ->
List[ResourceProvider]:
ResourceType.EMBEDDING_MODEL,
ResourceType.EMBEDDING_MODEL_CONNECTION,
ResourceType.VECTOR_STORE,
- ResourceType.VECTOR_STORE_CONNECTION,
]:
for name, descriptor in agent.resources[resource_type].items():
resource_providers.append(
diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json
b/python/flink_agents/plan/tests/resources/agent_plan.json
index d451056..e226869 100644
--- a/python/flink_agents/plan/tests/resources/agent_plan.json
+++ b/python/flink_agents/plan/tests/resources/agent_plan.json
@@ -114,24 +114,12 @@
"name": "mock_vector_store",
"type": "vector_store",
"module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockVectorStoreSetup",
+ "clazz": "MockVectorStore",
"kwargs": {
- "connection": "mock_vector_conn",
"embedding_model": "mock_embedding",
- "collection_name": "test_collection"
- },
- "__resource_provider_type__": "PythonResourceProvider"
- }
- },
- "vector_store_connection": {
- "mock_vector_conn": {
- "name": "mock_vector_conn",
- "type": "vector_store_connection",
- "module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockVectorStoreConnection",
- "kwargs": {
"host": "localhost",
- "port": 8000
+ "port": 8000,
+ "collection_name": "test_collection"
},
"__resource_provider_type__": "PythonResourceProvider"
}
diff --git a/python/flink_agents/plan/tests/test_agent_plan.py
b/python/flink_agents/plan/tests/test_agent_plan.py
index 600457a..c75337b 100644
--- a/python/flink_agents/plan/tests/test_agent_plan.py
+++ b/python/flink_agents/plan/tests/test_agent_plan.py
@@ -29,8 +29,7 @@ from flink_agents.api.decorators import (
chat_model_setup,
embedding_model_connection,
embedding_model_setup,
- vector_store_connection,
- vector_store_setup,
+ vector_store,
)
from flink_agents.api.embedding_models.embedding_model import (
BaseEmbeddingModelConnection,
@@ -40,8 +39,7 @@ from flink_agents.api.events.event import Event, InputEvent,
OutputEvent
from flink_agents.api.resource import ResourceDescriptor, ResourceType
from flink_agents.api.runner_context import RunnerContext
from flink_agents.api.vector_stores.vector_store import (
- BaseVectorStoreConnection,
- BaseVectorStoreSetup,
+ BaseVectorStore,
Document,
)
from flink_agents.plan.agent_plan import AgentPlan
@@ -123,11 +121,16 @@ class MockEmbeddingModelSetup(BaseEmbeddingModelSetup):
# noqa: D101
return {"model": self.model}
-class MockVectorStoreConnection(BaseVectorStoreConnection): # noqa: D101
+class MockVectorStore(BaseVectorStore): # noqa: D101
host: str
port: int
+ collection_name: str
+
+ @property
+ def store_kwargs(self) -> Dict[str, Any]: # noqa: D102
+ return {"collection_name": self.collection_name}
- def query(
+ def query_embedding(
self, embedding: list[float], limit: int = 10, **kwargs: Any
) -> list[Document]:
"""Testing Implementation."""
@@ -145,14 +148,6 @@ class
MockVectorStoreConnection(BaseVectorStoreConnection): # noqa: D101
][:limit]
-class MockVectorStoreSetup(BaseVectorStoreSetup): # noqa: D101
- collection_name: str
-
- @property
- def store_kwargs(self) -> Dict[str, Any]: # noqa: D102
- return {"collection_name": self.collection_name}
-
-
class MyAgent(Agent): # noqa: D101
@chat_model_setup
@staticmethod
@@ -180,20 +175,14 @@ class MyAgent(Agent): # noqa: D101
connection="mock_embedding_conn",
)
- @vector_store_connection
- @staticmethod
- def mock_vector_conn() -> ResourceDescriptor: # noqa: D102
- return ResourceDescriptor(
- clazz=MockVectorStoreConnection, host="localhost", port=8000
- )
-
- @vector_store_setup
+ @vector_store
@staticmethod
def mock_vector_store() -> ResourceDescriptor: # noqa: D102
return ResourceDescriptor(
- clazz=MockVectorStoreSetup,
- connection="mock_vector_conn",
+ clazz=MockVectorStore,
embedding_model="mock_embedding",
+ host="localhost",
+ port=8000,
collection_name="test_collection",
)
@@ -275,18 +264,13 @@ def test_add_action_and_resource_to_agent() -> None: #
noqa: D103
connection="mock_embedding_conn",
),
)
- my_agent.add_resource(
- name="mock_vector_conn",
- instance=ResourceDescriptor(
- clazz=MockVectorStoreConnection, host="localhost", port=8000
- ),
- )
my_agent.add_resource(
name="mock_vector_store",
instance=ResourceDescriptor(
- clazz=MockVectorStoreSetup,
- connection="mock_vector_conn",
+ clazz=MockVectorStore,
embedding_model="mock_embedding",
+ host="localhost",
+ port=8000,
collection_name="test_collection",
),
)