This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 0c9c15c2 fix(llm): refactor embedding parallelization to preserve
order (#295)
0c9c15c2 is described below
commit 0c9c15c20fee3748718c5a77b3120469b6b06d02
Author: imbajin <[email protected]>
AuthorDate: Thu Jul 31 17:30:41 2025 +0800
fix(llm): refactor embedding parallelization to preserve order (#295)
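Reworked get_embeddings_parallel to use asyncio.gather instead of
asyncio.as_completed for batch processing, ensuring output order matches
input: as_completed yields batches in completion order, so the combined
embeddings could be misaligned with the input vids, whereas gather returns
results in the order the tasks were passed in. Added a helper for
batch progress updates and improved progress bar accuracy.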
---
.../src/hugegraph_llm/utils/embedding_utils.py | 35 ++++++++++++++--------
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
index 2c7b3874..4209890f 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
@@ -24,9 +24,13 @@ from tqdm import tqdm
from hugegraph_llm.models.embeddings.base import BaseEmbedding
-async def get_embeddings_parallel(
- embedding: BaseEmbedding, vids: list[str]
-) -> list[Any]:
+async def _get_batch_with_progress(embedding: BaseEmbedding, batch: list[str],
pbar: tqdm) -> list[Any]:
+ result = await embedding.async_get_texts_embeddings(batch)
+ pbar.update(1)
+ return result
+
+
+async def get_embeddings_parallel(embedding: BaseEmbedding, vids: list[str])
-> list[Any]:
"""Get embeddings for texts in parallel.
This function processes text embeddings asynchronously in parallel, using
batching and semaphore
@@ -43,20 +47,27 @@ async def get_embeddings_parallel(
Note:
- A semaphore can be used to limit maximum concurrency if needed
- Processes texts in batches of 500
- - Displays progress using a progress bar
+ - Displays progress using a progress bar that updates as each batch
completes
+ - Uses asyncio.gather() to preserve order correspondence between input
and output
"""
batch_size = 500
# Split vids into batches of size batch_size
vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids),
batch_size)]
- # Create tasks for each batch
- tasks = [embedding.async_get_texts_embeddings(batch) for batch in
vid_batches]
-
embeddings = []
- with tqdm(total=len(tasks)) as pbar:
- for future in asyncio.as_completed(tasks):
- batch_embeddings = await future
- embeddings.extend(batch_embeddings) # Extend the list with batch
results
- pbar.update(1)
+ with tqdm(total=len(vid_batches)) as pbar:
+ # Create tasks for each batch with progress bar updates
+ tasks = [
+ _get_batch_with_progress(embedding, batch, pbar)
+ for batch in vid_batches
+ ]
+
+ # Use asyncio.gather() to preserve order
+ batch_results = await asyncio.gather(*tasks)
+
+ # Combine all batch results in order
+ for batch_embeddings in batch_results:
+ embeddings.extend(batch_embeddings)
+
return embeddings