This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 8a61cf7 feat(llm): support batch embedding (#238)
8a61cf7 is described below
commit 8a61cf7991fe9dc264ad197e8ed1fd52e61494a9
Author: Linyu <[email protected]>
AuthorDate: Wed May 21 19:49:15 2025 +0800
feat(llm): support batch embedding (#238)
## Implement Batch Embedding by modifying the underlying LLM interaction interface
- [✔] ollama
- [✔] openai
- [✔] qianfan
I've changed the original per-text concurrent calls in build_semantic_index.py into a single batched call and performed a simple test.
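A minimal usage sketch of the new batch interface (the `embedding` instance and the texts are illustrative, not from this patch):

```python
# `embedding` can be any BaseEmbedding implementation touched by this PR,
# e.g. OllamaEmbedding, OpenAIEmbedding, or QianFanEmbedding.
texts = ["vertex: Alice", "vertex: Bob"]
vectors = embedding.get_texts_embeddings(texts)  # one batched call instead of N single calls
assert len(vectors) == len(texts)  # each vector is a List[float]; order matches the inputs
```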
close #233
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/pyproject.toml | 2 +-
hugegraph-llm/requirements.txt | 2 +-
.../src/hugegraph_llm/models/embeddings/base.py | 23 ++++++++++++++++
.../src/hugegraph_llm/models/embeddings/ollama.py | 26 +++++++++++++++++-
.../src/hugegraph_llm/models/embeddings/openai.py | 24 +++++++++++++++++
.../src/hugegraph_llm/models/embeddings/qianfan.py | 8 ++++++
.../index_op/build_gremlin_example_index.py | 3 ++-
.../operators/index_op/build_semantic_index.py | 31 +++++++++++++---------
.../index_op/gremlin_example_index_query.py | 2 +-
9 files changed, 103 insertions(+), 18 deletions(-)
diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index 3f72aa5..dfb2681 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -39,7 +39,7 @@ documentation = "https://hugegraph.apache.org/docs/quickstart/hugegraph-ai/"
[tool.poetry.dependencies]
python = "^3.10,<3.12"
openai = "~1.61.0"
-ollama = "~0.2.1"
+ollama = "~0.4.8"
qianfan = "~0.3.18"
retry = "~0.9.2"
tiktoken = ">=0.7.0"
diff --git a/hugegraph-llm/requirements.txt b/hugegraph-llm/requirements.txt
index 7467ec6..3abe63e 100644
--- a/hugegraph-llm/requirements.txt
+++ b/hugegraph-llm/requirements.txt
@@ -1,5 +1,5 @@
openai~=1.61.0
-ollama~=0.2.1
+ollama~=0.4.8
qianfan~=0.3.18
retry~=0.9.2
tiktoken>=0.7.0
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
index 2ea8786..73e973e 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -60,6 +60,29 @@ class BaseEmbedding(ABC):
) -> List[float]:
"""Comment"""
+ @abstractmethod
+ def get_texts_embeddings(
+ self,
+ texts: List[str]
+ ) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch.
+
+ This method should efficiently process multiple texts at once by leveraging
+ the embedding model's batching capabilities, which is typically more efficient
+ than processing texts individually.
+
+ Parameters
+ ----------
+ texts : List[str]
+ A list of text strings to be embedded.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings should match the order of input texts.
+ """
+
@abstractmethod
async def async_get_text_embedding(
self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
index 81e11cc..062e098 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
@@ -40,7 +40,31 @@ class OllamaEmbedding(BaseEmbedding):
text: str
) -> List[float]:
"""Comment"""
- return list(self.client.embeddings(model=self.model, prompt=text)["embedding"])
+ return list(self.client.embed(model=self.model, input=text)["embeddings"][0])
+
+ def get_texts_embeddings(
+ self,
+ texts: List[str]
+ ) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch.
+
+ This method efficiently processes multiple texts at once by leveraging
+ Ollama's batching capabilities, which is more efficient than processing
+ texts individually.
+
+ Parameters
+ ----------
+ texts : List[str]
+ A list of text strings to be embedded.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings matches the order of input texts.
+ """
+ response = self.client.embed(model=self.model, input=texts)["embeddings"]
+ return [list(inner_sequence) for inner_sequence in response]
async def async_get_text_embedding(
self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
index aacef1e..890f491 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -38,6 +38,30 @@ class OpenAIEmbedding:
response = self.client.embeddings.create(input=text, model=self.embedding_model_name)
return response.data[0].embedding
+ def get_texts_embeddings(
+ self,
+ texts: List[str]
+ ) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch.
+
+ This method efficiently processes multiple texts at once by leveraging
+ OpenAI's batching capabilities, which is more efficient than processing
+ texts individually.
+
+ Parameters
+ ----------
+ texts : List[str]
+ A list of text strings to be embedded.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings matches the order of input texts.
+ """
+ response = self.client.embeddings.create(input=texts, model=self.embedding_model_name)
+ return [data.embedding for data in response.data]
+
async def async_get_text_embedding(self, text: str) -> List[float]:
"""Comment"""
response = await self.aclient.embeddings.create(input=text, model=self.embedding_model_name)
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
index 1745eb2..99eeb59 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
@@ -51,6 +51,14 @@ class QianFanEmbedding:
)
return response["body"]["data"][0]["embedding"]
+ def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
+ """ Usage refer:
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
+ response = self.client.do(
+ model=self.embedding_model_name,
+ texts=texts
+ )
+ return [data["embedding"] for data in response["body"]["data"]]
+
async def async_get_text_embedding(self, text: str) -> List[float]:
""" Usage refer:
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
response = await self.client.ado(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
index 4e15274..b865bc6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
@@ -19,11 +19,12 @@
import os
from typing import Dict, Any, List
-from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.config import resource_path
from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
+# FIXME: we need to keep the logic the same as build_semantic_index.py
class BuildGremlinExampleIndex:
def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, str]]):
self.index_dir = os.path.join(resource_path, "gremlin_examples")
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index f8a911e..ce64442 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -25,9 +25,8 @@ from tqdm import tqdm
from hugegraph_llm.config import resource_path, huge_settings
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
-from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
from hugegraph_llm.utils.log import log
-
+from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
class BuildSemanticIndex:
def __init__(self, embedding: BaseEmbedding):
@@ -41,35 +40,38 @@ class BuildSemanticIndex:
async def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
sem = asyncio.Semaphore(10)
-
- async def get_embedding_with_semaphore(vid: str) -> Any:
+ batch_size = 1000
+ async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any:
# Executes sync embedding method in a thread pool via loop.run_in_executor, combining async programming
# with multi-threading capabilities.
# This pattern avoids blocking the event loop and prepares for a future fully async pipeline.
async with sem:
loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, self.embedding.get_text_embedding, vid)
+ return await loop.run_in_executor(None, self.embedding.get_texts_embeddings, vid_list)
+
+ # Split vids into batches of size batch_size
+ vid_batches = [vids[i:i + batch_size] for i in range(0, len(vids), batch_size)]
+
+ # Create tasks for each batch
+ tasks = [get_embeddings_with_semaphore(batch) for batch in vid_batches]
- tasks = [get_embedding_with_semaphore(vid) for vid in vids]
embeddings = []
with tqdm(total=len(tasks)) as pbar:
for future in asyncio.as_completed(tasks):
- embedding = await future
- embeddings.append(embedding)
+ batch_embeddings = await future
+ embeddings.extend(batch_embeddings)  # Extend the list with batch results
pbar.update(1)
return embeddings
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
- all_pk_flag = all(data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels)
+ all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)
past_vids = self.vid_index.properties
# TODO: We should build vid vector index separately, especially when the vertices may be very large
- present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py
+ present_vids = context["vertices"]  # Warning: data truncated by fetch_graph_data.py
removed_vids = set(past_vids) - set(present_vids)
removed_num = self.vid_index.remove(removed_vids)
- if removed_vids:
- self.vid_index.to_index_file(self.index_dir)
added_vids = list(set(present_vids) - set(past_vids))
if added_vids:
@@ -80,5 +82,8 @@ class BuildSemanticIndex:
self.vid_index.to_index_file(self.index_dir)
else:
log.debug("No update vertices to build vector index.")
- context.update({"removed_vid_vector_num": removed_num,
"added_vid_vector_num": len(added_vids)})
+ context.update({
+ "removed_vid_vector_num": removed_num,
+ "added_vid_vector_num": len(added_vids)
+ })
return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index 31d9c2b..b8acd50 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -57,7 +57,7 @@ class GremlinExampleIndexQuery:
def _build_default_example_index(self):
properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records")
from concurrent.futures import ThreadPoolExecutor
- # TODO: use asyncio for IO tasks
+ # TODO: reuse the logic in build_semantic_index.py (consider extracting the batch-embedding method)
with ThreadPoolExecutor() as executor:
embeddings = list(
tqdm(executor.map(self.embedding.get_text_embedding, [row["query"] for row in properties]),