This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new cd404b8c perf(llm): optimize vector index with asyncio embedding (#264)
cd404b8c is described below
commit cd404b8c7facade86bc541212e6e071ec6a61402
Author: Linyu <[email protected]>
AuthorDate: Thu Jul 31 17:02:58 2025 +0800
perf(llm): optimize vector index with asyncio embedding (#264)
## Changes
This PR introduces performance optimizations for vector index building
and querying by implementing parallel text embedding generation.
### Key Improvements
1. Added a new utility module `embedding_utils.py` with parallel batch processing capabilities
- Implements `get_embeddings_parallel` function for efficient batch
processing
- Uses asyncio with semaphore for controlled concurrency
- Supports batch size of 1000 with max 10 concurrent tasks
2. Refactored all index operation classes to use parallel processing:
- `BuildGremlinExampleIndex`
- `BuildSemanticIndex`
- `BuildVectorIndex`
- `GremlinExampleIndexQuery`
- `SemanticIdQuery`
- `VectorIndexQuery`
3. Unified embedding generation approach:
- Replaced individual `get_text_embedding` calls with batch
`get_texts_embeddings`
- Removed duplicate parallel processing code
- Improved code reusability and maintainability (see the usage sketch below)
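
As a usage illustration (not part of the commit), here is a minimal sketch of driving the new helper from synchronous code. The `OpenAIEmbedding()` construction and the sample texts are assumptions made for the example; any `BaseEmbedding` implementation of `async_get_texts_embeddings` can be substituted.

```python
import asyncio

from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding
from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel

# Hypothetical inputs; real callers pass graph vertex ids or document chunks.
texts = ["person:Alice", "person:Bob", "software:HugeGraph"]

# Assumed to pick up API settings from the project's config defaults.
embedding = OpenAIEmbedding()

# get_embeddings_parallel() splits the texts into batches and gathers the
# per-batch embeddings as each async task completes; operators that are still
# synchronous drive it with asyncio.run(), as the refactored index classes do.
embeddings = asyncio.run(get_embeddings_parallel(embedding, texts))
assert len(embeddings) == len(texts)
```

The batching parameters live inside `embedding_utils.py`, so callers only pass the embedding model and the list of texts.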
---------
Co-authored-by: imbajin <[email protected]>
---
.asf.yaml | 2 +-
.../config/models/base_prompt_config.py | 79 ++++++++++++++--------
.../src/hugegraph_llm/indices/vector_index.py | 8 ++-
.../src/hugegraph_llm/models/embeddings/base.py | 25 +++++--
.../src/hugegraph_llm/models/embeddings/litellm.py | 6 +-
.../src/hugegraph_llm/models/embeddings/ollama.py | 21 ++++--
.../src/hugegraph_llm/models/embeddings/openai.py | 24 +++++--
.../index_op/build_gremlin_example_index.py | 9 ++-
.../operators/index_op/build_semantic_index.py | 64 +++++++-----------
.../operators/index_op/build_vector_index.py | 7 +-
.../index_op/gremlin_example_index_query.py | 12 ++--
.../operators/index_op/semantic_id_query.py | 4 +-
.../operators/index_op/vector_index_query.py | 2 +-
.../src/hugegraph_llm/utils/embedding_utils.py | 62 +++++++++++++++++
.../src/hugegraph_llm/utils/vector_index_utils.py | 10 ++-
15 files changed, 228 insertions(+), 107 deletions(-)
diff --git a/.asf.yaml b/.asf.yaml
index cafdba4f..21bb671e 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -53,7 +53,7 @@ github:
# (for non-committer): assign/edit/close issues & PR, without write access to the code
collaborators:
- ChenZiHong-Gavin
- - MrJs133
+ - weijinglin
- HJ-Young
- afterimagex
- returnToInnocence
diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
index 691247b3..23832bf9 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-import os
import sys
+import os
from pathlib import Path
import yaml
@@ -30,22 +30,24 @@ yaml_file_path = os.path.join(os.getcwd(), "src/hugegraph_llm/resources/demo", F
class BasePromptConfig:
- graph_schema: str = ''
- extract_graph_prompt: str = ''
- default_question: str = ''
- custom_rerank_info: str = ''
- answer_prompt: str = ''
- keywords_extract_prompt: str = ''
- text2gql_graph_schema: str = ''
- gremlin_generate_prompt: str = ''
- doc_input_text: str = ''
- generate_extract_prompt_template: str = ''
+ graph_schema: str = ""
+ extract_graph_prompt: str = ""
+ default_question: str = ""
+ custom_rerank_info: str = ""
+ answer_prompt: str = ""
+ keywords_extract_prompt: str = ""
+ text2gql_graph_schema: str = ""
+ gremlin_generate_prompt: str = ""
+ doc_input_text: str = ""
+ generate_extract_prompt_template: str = ""
def ensure_yaml_file_exists(self):
current_dir = Path.cwd().resolve()
project_root = get_project_root()
if current_dir == project_root:
- log.info("Current working directory is the project root,
proceeding to run the app.")
+ log.info(
+ "Current working directory is the project root, proceeding to
run the app."
+ )
else:
error_msg = (
f"Current working directory is not the project root. "
@@ -66,22 +68,42 @@ class BasePromptConfig:
log.info("Prompt file '%s' doesn't exist, create it.",
yaml_file_path)
def save_to_yaml(self):
- indented_schema = "\n".join([f" {line}" for line in self.graph_schema.splitlines()])
- indented_text2gql_schema = "\n".join([f" {line}" for line in self.text2gql_graph_schema.splitlines()])
- indented_gremlin_prompt = "\n".join([f" {line}" for line in self.gremlin_generate_prompt.splitlines()])
- indented_example_prompt = "\n".join([f" {line}" for line in self.extract_graph_prompt.splitlines()])
- indented_question = "\n".join([f" {line}" for line in self.default_question.splitlines()])
- indented_custom_related_information = (
- "\n".join([f" {line}" for line in self.custom_rerank_info.splitlines()])
+ indented_schema = "\n".join(
+ [f" {line}" for line in self.graph_schema.splitlines()]
+ )
+ indented_text2gql_schema = "\n".join(
+ [f" {line}" for line in self.text2gql_graph_schema.splitlines()]
+ )
+ indented_gremlin_prompt = "\n".join(
+ [f" {line}" for line in self.gremlin_generate_prompt.splitlines()]
+ )
+ indented_example_prompt = "\n".join(
+ [f" {line}" for line in self.extract_graph_prompt.splitlines()]
+ )
+ indented_question = "\n".join(
+ [f" {line}" for line in self.default_question.splitlines()]
)
- indented_default_answer_template = "\n".join([f" {line}" for line in self.answer_prompt.splitlines()])
- indented_keywords_extract_template = (
- "\n".join([f" {line}" for line in self.keywords_extract_prompt.splitlines()])
+ indented_custom_related_information = "\n".join(
+ [f" {line}" for line in self.custom_rerank_info.splitlines()]
+ )
+ indented_default_answer_template = "\n".join(
+ [f" {line}" for line in self.answer_prompt.splitlines()]
+ )
+ indented_keywords_extract_template = "\n".join(
+ [f" {line}" for line in
self.keywords_extract_prompt.splitlines()]
+ )
+ indented_doc_input_text = "\n".join(
+ [f" {line}" for line in self.doc_input_text.splitlines()]
+ )
+ indented_generate_extract_prompt = (
+ "\n".join(
+ [
+ f" {line}"
+ for line in self.generate_extract_prompt_template.splitlines()
+ ]
+ )
+ + "\n"
)
- indented_doc_input_text = "\n".join([f" {line}" for line in self.doc_input_text.splitlines()])
- indented_generate_extract_prompt = "\n".join(
- [f" {line}" for line in self.generate_extract_prompt_template.splitlines()]
- ) + "\n"
# This can be extended to add storage fields according to the data needs to be stored
yaml_content = f"""graph_schema: |
{indented_schema}
@@ -118,7 +140,10 @@ generate_extract_prompt_template: |
def generate_yaml_file(self):
if os.path.exists(yaml_file_path):
- log.info("%s already exists, do you want to override with the
default configuration? (y/n)", yaml_file_path)
+ log.info(
+ "%s already exists, do you want to override with the default
configuration? (y/n)",
+ yaml_file_path,
+ )
update = input()
if update.lower() != "y":
return
diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
index 5e810ed9..301d741f 100644
--- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
@@ -37,11 +37,13 @@ class VectorIndex:
self.properties = []
@staticmethod
- def from_index_file(dir_path: str) -> "VectorIndex":
+ def from_index_file(dir_path: str, record_miss: bool = True) -> "VectorIndex":
index_file = os.path.join(dir_path, INDEX_FILE_NAME)
properties_file = os.path.join(dir_path, PROPERTIES_FILE_NAME)
- if not os.path.exists(index_file) or not os.path.exists(properties_file):
- log.warning("No index file found, create a new one.")
+ miss_files = [f for f in [index_file, properties_file] if not os.path.exists(f)]
+ if miss_files:
+ if record_miss:
+ log.warning("Missing vector files: %s. \nNeed create a new one for it.", ", ".join(miss_files))
return VectorIndex()
faiss_index = faiss.read_index(index_file)
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
index d6b66294..db9b2f10 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -86,13 +86,28 @@ class BaseEmbedding(ABC):
The order of embeddings should match the order of input texts.
"""
- # TODO: [PR-238] Add & implement batch processing for async_get_texts_embeddings (refactor here)
@abstractmethod
- async def async_get_text_embedding(
+ async def async_get_texts_embeddings(
self,
- text: str
- ) -> List[float]:
- """Comment"""
+ texts: List[str]
+ ) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch asynchronously.
+
+ This method should efficiently process multiple texts at once by leveraging
+ the embedding model's batching capabilities, which is typically more efficient
+ than processing texts individually.
+
+ Parameters
+ ----------
+ texts : List[str]
+ A list of text strings to be embedded.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings should match the order of input texts.
+ """
@staticmethod
def similarity(
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py
index ee808b09..b793800e 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py
@@ -77,17 +77,17 @@ class LiteLLMEmbedding(BaseEmbedding):
log.error("Error in LiteLLM batch embedding call: %s", e)
raise
- async def async_get_text_embedding(self, text: str) -> List[float]:
+ async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embedding for a single text asynchronously."""
try:
response = await aembedding(
model=self.model,
- input=text,
+ input=texts,
api_key=self.api_key,
api_base=self.api_base,
)
log.info("Token usage: %s", response.usage)
- return response.data[0]["embedding"]
+ return [data["embedding"] for data in response.data]
except (RateLimitError, APIConnectionError, APIError) as e:
log.error("Error in async LiteLLM embedding call: %s", e)
raise
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
index e54750f0..78c5bd08 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
@@ -53,7 +53,20 @@ class OllamaEmbedding(BaseEmbedding):
response = self.client.embed(model=self.model, input=texts)["embeddings"]
return [list(inner_sequence) for inner_sequence in response]
- # TODO: Add & implement batch processing for async_get_texts_embeddings (refactor here)
- async def async_get_text_embedding(self, text: str) -> List[float]:
- response = await self.async_client.embeddings(model=self.model, prompt=text)
- return list(response["embedding"])
+ async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch asynchronously.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings matches the order of input texts.
+ """
+ if not hasattr(self.client, "embed"):
+ error_message = (
+ "The required 'embed' method was not found on the Ollama client. "
+ "Please ensure your ollama library is up-to-date and supports batch embedding. "
+ )
+ raise AttributeError(error_message)
+ response = await self.async_client.embed(model=self.model, input=texts)
+ return [list(inner_sequence) for inner_sequence in response["embeddings"]]
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
index 890f4918..c18a0fb1 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -62,7 +62,23 @@ class OpenAIEmbedding:
response = self.client.embeddings.create(input=texts, model=self.embedding_model_name)
return [data.embedding for data in response.data]
- async def async_get_text_embedding(self, text: str) -> List[float]:
- """Comment"""
- response = await self.aclient.embeddings.create(input=text, model=self.embedding_model_name)
- return response.data[0].embedding
+ async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
+ """Get embeddings for multiple texts in a single batch asynchronously.
+
+ This method should efficiently process multiple texts at once by leveraging
+ the embedding model's batching capabilities, which is typically more efficient
+ than processing texts individually.
+
+ Parameters
+ ----------
+ texts : List[str]
+ A list of text strings to be embedded.
+
+ Returns
+ -------
+ List[List[float]]
+ A list of embedding vectors, where each vector is a list of floats.
+ The order of embeddings should match the order of input texts.
+ """
+ response = await self.aclient.embeddings.create(input=texts, model=self.embedding_model_name)
+ return [data.embedding for data in response.data]
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
index b865bc65..b444dcd6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
@@ -16,12 +16,14 @@
# under the License.
+import asyncio
import os
from typing import Dict, Any, List
from hugegraph_llm.config import resource_path
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel
# FIXME: we need keep the logic same with build_semantic_index.py
@@ -32,9 +34,10 @@ class BuildGremlinExampleIndex:
self.embedding = embedding
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- examples_embedding = []
- for example in self.examples:
- examples_embedding.append(self.embedding.get_text_embedding(example["query"]))
+ # !: We have assumed that self.example is not empty
+ queries = [example["query"] for example in self.examples]
+ # TODO: refactor function chain async to avoid blocking
+ examples_embedding = asyncio.run(get_embeddings_parallel(self.embedding, queries))
embed_dim = len(examples_embedding[0])
if len(self.examples) > 0:
vector_index = VectorIndex(embed_dim)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index e6b4080a..a2a0412f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -20,18 +20,19 @@ import asyncio
import os
from typing import Any, Dict
-from tqdm import tqdm
-
from hugegraph_llm.config import resource_path, huge_settings
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
-from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel
+from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
class BuildSemanticIndex:
def __init__(self, embedding: BaseEmbedding):
- self.index_dir = str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids"))
+ self.index_dir = str(
+ os.path.join(resource_path, huge_settings.graph_name, "graph_vids")
+ )
self.vid_index = VectorIndex.from_index_file(self.index_dir)
self.embedding = embedding
self.sm = SchemaManager(huge_settings.graph_name)
@@ -39,55 +40,38 @@ class BuildSemanticIndex:
def _extract_names(self, vertices: list[str]) -> list[str]:
return [v.split(":")[1] for v in vertices]
- async def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
- sem = asyncio.Semaphore(10)
- batch_size = 1000
-
- # TODO: refactor the logic here (call async method)
- async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any:
- # Executes sync embedding method in a thread pool via loop.run_in_executor, combining async programming
- # with multi-threading capabilities.
- # This pattern avoids blocking the event loop and prepares for a future fully async pipeline.
- async with sem:
- loop = asyncio.get_running_loop()
- # FIXME: [PR-238] add & use async_get_texts_embedding instead of sync method
- return await loop.run_in_executor(None, self.embedding.get_texts_embeddings, vid_list)
-
- # Split vids into batches of size batch_size
- vid_batches = [vids[i:i + batch_size] for i in range(0, len(vids), batch_size)]
-
- # Create tasks for each batch
- tasks = [get_embeddings_with_semaphore(batch) for batch in vid_batches]
-
- embeddings = []
- with tqdm(total=len(tasks)) as pbar:
- for future in asyncio.as_completed(tasks):
- batch_embeddings = await future
- embeddings.extend(batch_embeddings) # Extend the list with batch results
- pbar.update(1)
- return embeddings
-
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
- all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)
+ all_pk_flag = all(
+ data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels
+ )
past_vids = self.vid_index.properties
# TODO: We should build vid vector index separately, especially when the vertices may be very large
- present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py
+
+ present_vids = context[
+ "vertices"
+ ] # Warning: data truncated by fetch_graph_data.py
removed_vids = set(past_vids) - set(present_vids)
removed_num = self.vid_index.remove(removed_vids)
added_vids = list(set(present_vids) - set(past_vids))
if added_vids:
- vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids
- added_embeddings = asyncio.run(self._get_embeddings_parallel(vids_to_process))
+ vids_to_process = (
+ self._extract_names(added_vids) if all_pk_flag else added_vids
+ )
+ added_embeddings = asyncio.run(
+ get_embeddings_parallel(self.embedding, vids_to_process)
+ )
log.info("Building vector index for %s vertices...",
len(added_vids))
self.vid_index.add(added_embeddings, added_vids)
self.vid_index.to_index_file(self.index_dir)
else:
log.debug("No update vertices to build vector index.")
- context.update({
- "removed_vid_vector_num": removed_num,
- "added_vid_vector_num": len(added_vids)
- })
+ context.update(
+ {
+ "removed_vid_vector_num": removed_num,
+ "added_vid_vector_num": len(added_vids),
+ }
+ )
return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
index 01499f52..8a66d0b0 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
@@ -16,15 +16,15 @@
# under the License.
+import asyncio
import os
from typing import Dict, Any
-from tqdm import tqdm
-
from hugegraph_llm.config import huge_settings, resource_path
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel
class BuildVectorIndex:
@@ -40,8 +40,7 @@ class BuildVectorIndex:
chunks_embedding = []
log.debug("Building vector index for %s chunks...",
len(context["chunks"]))
# TODO: use async_get_texts_embedding instead of single sync method
- for chunk in tqdm(chunks):
- chunks_embedding.append(self.embedding.get_text_embedding(chunk))
+ chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks))
if len(chunks_embedding) > 0:
self.vector_index.add(chunks_embedding, chunks)
self.vector_index.to_index_file(self.index_dir)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index b8acd506..b14da3d5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -16,17 +16,18 @@
# under the License.
+import asyncio
import os
from typing import Dict, Any, List
import pandas as pd
-from tqdm import tqdm
from hugegraph_llm.config import resource_path
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel
class GremlinExampleIndexQuery:
@@ -51,17 +52,14 @@ class GremlinExampleIndexQuery:
query_embedding = context.get("query_embedding")
if not isinstance(query_embedding, list):
- query_embedding = self.embedding.get_text_embedding(query)
+ query_embedding = self.embedding.get_texts_embeddings([query])[0]
return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8)
def _build_default_example_index(self):
properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records")
- from concurrent.futures import ThreadPoolExecutor
# TODO: reuse the logic in build_semantic_index.py (consider extract the batch-embedding method)
- with ThreadPoolExecutor() as executor:
- embeddings = list(
- tqdm(executor.map(self.embedding.get_text_embedding, [row["query"] for row in properties]),
total=len(properties)))
+ queries = [row["query"] for row in properties]
+ embeddings = asyncio.run(get_embeddings_parallel(self.embedding, queries))
vector_index = VectorIndex(len(embeddings[0]))
vector_index.add(embeddings, properties)
vector_index.to_index_file(self.index_dir)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index e3375ef0..10233a6c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -75,7 +75,7 @@ class SemanticIdQuery:
def _fuzzy_match_vids(self, keywords: List[str]) -> List[str]:
fuzzy_match_result = []
for keyword in keywords:
- keyword_vector = self.embedding.get_text_embedding(keyword)
+ keyword_vector = self.embedding.get_texts_embeddings([keyword])[0]
results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword, dis_threshold=float(self.vector_dis_threshold))
if results:
@@ -86,7 +86,7 @@ class SemanticIdQuery:
graph_query_list = set()
if self.by == "query":
query = context["query"]
- query_vector = self.embedding.get_text_embedding(query)
+ query_vector = self.embedding.get_texts_embeddings([query])[0]
results = self.vector_index.search(query_vector, top_k=self.topk_per_query)
if results:
graph_query_list.update(results[:self.topk_per_query])
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
index 976155c3..f2d5d600 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
@@ -34,7 +34,7 @@ class VectorIndexQuery:
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
query = context.get("query")
- query_embedding = self.embedding.get_text_embedding(query)
+ query_embedding = self.embedding.get_texts_embeddings([query])[0]
# TODO: why set dis_threshold=2?
results = self.vector_index.search(query_embedding, self.topk, dis_threshold=2)
# TODO: check format results
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
new file mode 100644
index 00000000..2c7b3874
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import asyncio
+from typing import Any
+
+from tqdm import tqdm
+
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
+
+
+async def get_embeddings_parallel(
+ embedding: BaseEmbedding, vids: list[str]
+) -> list[Any]:
+ """Get embeddings for texts in parallel.
+
+ This function processes text embeddings asynchronously in parallel, using batching and semaphore
+ to control concurrency, improving processing efficiency while preventing resource overuse.
+
+ Args:
+ embedding (BaseEmbedding): The embedding model instance used to compute text embeddings.
+ vids (list[str]): List of texts to compute embeddings for.
+
+ Returns:
+ list[Any]: List of embedding vectors corresponding to the input texts, maintaining the same
+ order as the input vids list.
+
+ Note:
+ - Note: Uses a semaphore to limit maximum concurrency if we need
+ - Processes texts in batches of 500
+ - Displays progress using a progress bar
+ """
+ batch_size = 500
+
+ # Split vids into batches of size batch_size
+ vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids), batch_size)]
+
+ # Create tasks for each batch
+ tasks = [embedding.async_get_texts_embeddings(batch) for batch in vid_batches]
+
+ embeddings = []
+ with tqdm(total=len(tasks)) as pbar:
+ for future in asyncio.as_completed(tasks):
+ batch_embeddings = await future
+ embeddings.extend(batch_embeddings) # Extend the list with batch results
+ pbar.update(1)
+ return embeddings
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
index ef2b5e9b..0542b906 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import json
import os
@@ -55,10 +56,13 @@ def read_documents(input_file, input_text):
return texts
-#pylint: disable=C0301
+# pylint: disable=C0301
def get_vector_index_info():
- chunk_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, huge_settings.graph_name, "chunks")))
- graph_vid_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids")))
+ chunk_vector_index = VectorIndex.from_index_file(
+ str(os.path.join(resource_path, huge_settings.graph_name, "chunks")), record_miss=False,
+ )
+ graph_vid_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path,
+ huge_settings.graph_name, "graph_vids")))
return json.dumps({
"embed_dim": chunk_vector_index.index.d,
"vector_info": {