This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 6a8a379 refactor: optimize the process of graph rag (#101)
6a8a379 is described below
commit 6a8a379f9717b2f2f7ff17c9c3385b31ec1ce87c
Author: chenzihong <[email protected]>
AuthorDate: Fri Nov 15 00:34:08 2024 +0800
refactor: optimize the process of graph rag (#101)
- [x] replace print with log
- [x] delete single quotes in keywords
- [x] special handling is required when the keywords list is empty
* fix: an empty keywords list made g.V() scan all vertices (sketched below, after the diffstat)
* remove log color strings
---------
Co-authored-by: imbajin <[email protected]>
---
.../operators/hugegraph_op/graph_rag_query.py | 46 ++++++++++++++--------
.../operators/index_op/semantic_id_query.py | 31 +++++++++------
.../operators/llm_op/answer_synthesize.py | 39 +++++++++---------
.../operators/llm_op/keyword_extract.py | 1 +
4 files changed, 70 insertions(+), 47 deletions(-)
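The core of the change shows up in the query templates: with no ids, a bare g.V() traversal touches every vertex in the graph. Below is a minimal sketch of the before/after behavior; the keyword ids are hypothetical and `client` stands in for a connected PyHugeClient, which the operator drives via client.gremlin().exec(...) exactly as the diff shows.

    # Hypothetical, pre-quoted vertex ids; `client` is a placeholder handle.
    OLD_TPL = "g.V({keywords}).as('subj').toList()"
    NEW_TPL = "g.V({keywords}).limit(8).as('subj').toList()"

    keyword_ids = ["'1:Tom'", "'1:Jerry'"]
    gremlin = NEW_TPL.format(keywords=",".join(keyword_ids))
    # resp = client.gremlin().exec(gremlin)
    # With an empty id list the old template degenerates to
    # "g.V().as('subj').toList()", i.e. a full vertex scan; limit(8) bounds
    # the result, and SemanticIdQuery now returns early before such a
    # query is ever built.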
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 9fc6cd7..d15a8c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -20,7 +20,8 @@ from hugegraph_llm.config import settings
from hugegraph_llm.utils.log import log
from pyhugegraph.client import PyHugeClient
-VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
+# TODO: remove the as('subj') step
+VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()"
# TODO: we could use a simpler query (like kneighbor-api to get the edges)
# TODO: test with profile()/explain() to speed up the query
@@ -137,7 +138,6 @@ class GraphRAGQuery:
graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_query_result(
query_paths=paths
)
- graph_chain_knowledge.update(vertex_knowledge)
if vertex_degree_list:
vertex_degree_list[0].update(vertex_knowledge)
else:
@@ -170,12 +170,9 @@ class GraphRAGQuery:
"`vertexA --[links]--> vertexB <--[links]-- vertexC ...`"
"extracted based on key entities as subject:\n"
)
-
- # TODO: replace print to log
- verbose = context.get("verbose") or False
- if verbose:
- print("\033[93mKnowledge from Graph:")
- print("\n".join(context["graph_result"]) + "\033[0m")
+ # TODO: set color for ↓ "\033[93mKnowledge from Graph:\033[0m"
+ log.debug("Knowledge from Graph:")
+ log.debug("\n".join(context["graph_result"]))
return context
@@ -192,10 +189,12 @@ class GraphRAGQuery:
subgraph = set()
subgraph_with_degree = {}
vertex_degree_list: List[Set[str]] = []
+ v_cache: Set[str] = set()
+ e_cache: Set[str] = set()
for path in query_paths:
# 1. Process each path
- flat_rel, nodes_with_degree = self._process_path(path, use_id_to_match)
+ flat_rel, nodes_with_degree = self._process_path(path, use_id_to_match, v_cache, e_cache)
subgraph.add(flat_rel)
subgraph_with_degree[flat_rel] = nodes_with_degree
# 2. Update vertex degree list
@@ -203,7 +202,8 @@ class GraphRAGQuery:
return subgraph, vertex_degree_list, subgraph_with_degree
- def _process_path(self, path: Any, use_id_to_match: bool) -> Tuple[str, List[str]]:
+ def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
+ e_cache: Set[str]) -> Tuple[str, List[str]]:
flat_rel = ""
raw_flat_rel = path["objects"]
assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd."
@@ -217,19 +217,20 @@ class GraphRAGQuery:
if i % 2 == 0:
# Process each vertex
flat_rel, prior_edge_str_len, depth = self._process_vertex(
- item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match
+ item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match,
+ v_cache
)
else:
# Process each edge
flat_rel, prior_edge_str_len = self._process_edge(
- item, flat_rel, prior_edge_str_len, raw_flat_rel, i,use_id_to_match
+ item, flat_rel, prior_edge_str_len, raw_flat_rel, i,use_id_to_match, e_cache
)
return flat_rel, nodes_with_degree
def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
prior_edge_str_len: int, depth: int,
nodes_with_degree: List[str],
- use_id_to_match: bool) -> Tuple[str, int, int]:
+ use_id_to_match: bool, v_cache: Set[str]) -> Tuple[str, int, int]:
matched_str = item["id"] if use_id_to_match else
item["props"][self._prop_to_match]
if matched_str in node_cache:
flat_rel = flat_rel[:-prior_edge_str_len]
@@ -237,7 +238,12 @@ class GraphRAGQuery:
node_cache.add(matched_str)
props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
- node_str = f"{item['id']}{{{props_str}}}"
+ # TODO: we may remove label id or replace with label name
+ if matched_str in v_cache:
+ node_str = matched_str
+ else:
+ v_cache.add(matched_str)
+ node_str = f"{item['id']}{{{props_str}}}"
flat_rel += node_str
nodes_with_degree.append(node_str)
depth += 1
@@ -245,16 +251,22 @@ class GraphRAGQuery:
return flat_rel, prior_edge_str_len, depth
def _process_edge(self, item: Any, flat_rel: str, prior_edge_str_len: int,
- raw_flat_rel: List[Any], i: int, use_id_to_match: bool) -> Tuple[str, int]:
+ raw_flat_rel: List[Any], i: int, use_id_to_match: bool, e_cache: Set[str]) -> Tuple[str, int]:
props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
props_str = f"{{{props_str}}}" if len(props_str) > 0 else ""
prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else (
raw_flat_rel)[i - 1]["props"][self._prop_to_match]
+ if item["label"] in e_cache:
+ edge_str = f"{item['label']}"
+ else:
+ e_cache.add(item["label"])
+ edge_str = f"{item['label']}{props_str}"
+
if item["outV"] == prev_matched_str:
- edge_str = f" --[{item['label']}{props_str}]--> "
+ edge_str = f" --[{edge_str}]--> "
else:
- edge_str = f" <--[{item['label']}{props_str}]-- "
+ edge_str = f" <--[{edge_str}]-- "
flat_rel += edge_str
prior_edge_str_len = len(edge_str)
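To see what the new v_cache/e_cache bookkeeping buys, here is a condensed sketch (with hypothetical helper names) of the rendering rule: the first occurrence of a vertex id or edge label carries its properties, and later occurrences shrink to the bare id or label, keeping the flattened path text compact.

    from typing import Dict, Set

    def render_vertex(vid: str, props: Dict[str, str], v_cache: Set[str]) -> str:
        # First sighting: id plus properties; afterwards the id alone suffices.
        if vid in v_cache:
            return vid
        v_cache.add(vid)
        props_str = ", ".join(f"{k}: {v}" for k, v in props.items())
        return f"{vid}{{{props_str}}}"

    def render_edge(label: str, props: Dict[str, str], e_cache: Set[str]) -> str:
        # Same rule keyed on the edge label; empty props render as the bare label.
        if label in e_cache:
            return label
        e_cache.add(label)
        props_str = ", ".join(f"{k}: {v}" for k, v in props.items())
        return f"{label}{{{props_str}}}" if props else label

    v_cache: Set[str] = set()
    print(render_vertex("1:Tom", {"age": "29"}, v_cache))  # 1:Tom{age: 29}
    print(render_vertex("1:Tom", {"age": "29"}, v_cache))  # 1:Tom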
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index 04804e6..d7b5b89 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -17,7 +17,6 @@
import os
-from copy import deepcopy
from typing import Dict, Any, Literal, List, Tuple
from hugegraph_llm.config import resource_path, settings
@@ -28,7 +27,8 @@ from pyhugegraph.client import PyHugeClient
class SemanticIdQuery:
- ID_QUERY_TEMPL = "g.V({vids_str})"
+ ID_QUERY_TEMPL = "g.V({vids_str}).limit(8)"
+
def __init__(
self,
embedding: BaseEmbedding,
@@ -52,31 +52,39 @@ class SemanticIdQuery:
)
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- graph_query_list = []
+ graph_query_list = set()
if self.by == "query":
query = context["query"]
query_vector = self.embedding.get_text_embedding(query)
results = self.vector_index.search(query_vector, top_k=self.topk_per_query)
if results:
- graph_query_list.extend(results[:self.topk_per_query])
+ graph_query_list.update(results[:self.topk_per_query])
else: # by keywords
- exact_match_vids, unmatched_vids = self._exact_match_vids(context["keywords"])
- graph_query_list.extend(exact_match_vids)
+ keywords = context.get("keywords", [])
+ if not keywords:
+ context["match_vids"] = []
+ return context
+
+ exact_match_vids, unmatched_vids = self._exact_match_vids(keywords)
+ graph_query_list.update(exact_match_vids)
fuzzy_match_vids = self._fuzzy_match_vids(unmatched_vids)
log.debug("Fuzzy match vids: %s", fuzzy_match_vids)
- graph_query_list.extend(fuzzy_match_vids)
- context["match_vids"] = list(set(graph_query_list))
+ graph_query_list.update(fuzzy_match_vids)
+ context["match_vids"] = list(graph_query_list)
return context
def _exact_match_vids(self, keywords: List[str]) -> Tuple[List[str], List[str]]:
+ assert keywords, "keywords can't be empty, please check the logic"
+ # TODO: we should add a global GraphSchemaCache to avoid calling the server every time
vertex_label_num = len(self._client.schema().getVertexLabels())
- possible_vids = deepcopy(keywords)
+ possible_vids = set(keywords)
for i in range(vertex_label_num):
- possible_vids.extend([f"{i+1}:{keyword}" for keyword in keywords])
+ possible_vids.update([f"{i + 1}:{keyword}" for keyword in keywords])
vids_str = ",".join([f"'{vid}'" for vid in possible_vids])
resp = self._client.gremlin().exec(SemanticIdQuery.ID_QUERY_TEMPL.format(vids_str=vids_str))
searched_vids = [v['id'] for v in resp['data']]
+
unsearched_keywords = set(keywords)
for vid in searched_vids:
for keyword in unsearched_keywords:
@@ -91,5 +99,6 @@ class SemanticIdQuery:
keyword_vector = self.embedding.get_text_embedding(keyword)
results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
if results:
+ # FIXME: type mismatch, got 'list[dict[str, Any]]' instead
fuzzy_match_result.extend(results[:self.topk_per_keyword])
- return fuzzy_match_result  # FIXME: type mismatch, got 'list[dict[str, Any]]' instead
+ return fuzzy_match_result
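The _exact_match_vids change above swaps a deepcopy-and-list for a set, and tries each keyword both as a raw vertex id and with every numeric vertex-label prefix. A standalone sketch of that expansion (the helper name and counts are illustrative):

    from typing import List, Set

    def candidate_vids(keywords: List[str], vertex_label_num: int) -> Set[str]:
        # Try each keyword as a raw id and as "<label_id>:<keyword>", since
        # HugeGraph vertex ids are commonly prefixed with the label id.
        possible: Set[str] = set(keywords)
        for i in range(vertex_label_num):
            possible.update(f"{i + 1}:{keyword}" for keyword in keywords)
        return possible

    vids = candidate_vids(["Tom"], vertex_label_num=2)  # {'Tom', '1:Tom', '2:Tom'}
    vids_str = ",".join(f"'{vid}'" for vid in vids)
    # vids_str feeds g.V({vids_str}).limit(8), mirroring ID_QUERY_TEMPL above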
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index baf61e6..0f894cb 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -105,25 +105,26 @@ class AnswerSynthesize:
graph_result_context: str):
# pylint: disable=R0912 (too-many-branches)
verbose = context.get("verbose") or False
- # TODO: replace task_cache with a better name
- task_cache = {}
+
+ # async_tasks stores the async tasks for different answer types
+ async_tasks = {}
if self._raw_answer:
final_prompt = self._question
- task_cache["raw_task"] =
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+ async_tasks["raw_task"] =
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._vector_only_answer:
context_str = (f"{context_head_str}\n"
f"{vector_result_context}\n"
f"{context_tail_str}".strip("\n"))
final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
- task_cache["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+ async_tasks["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_only_answer:
context_str = (f"{context_head_str}\n"
f"{graph_result_context}\n"
f"{context_tail_str}".strip("\n"))
final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
- task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+ async_tasks["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_vector_answer:
context_body_str = f"{vector_result_context}\n{graph_result_context}"
if context.get("graph_ratio", 0.5) < 0.5:
@@ -133,30 +134,30 @@ class AnswerSynthesize:
f"{context_tail_str}".strip("\n"))
final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
- task_cache["graph_vector_task"] = asyncio.create_task(
+ async_tasks["graph_vector_task"] = asyncio.create_task(
self._llm.agenerate(prompt=final_prompt)
)
- # TODO: use log.debug instead of print
- if task_cache.get("raw_task"):
- response = await task_cache["raw_task"]
+
+ if async_tasks.get("raw_task"):
+ response = await async_tasks["raw_task"]
context["raw_answer"] = response
if verbose:
- print(f"\033[91mANSWER: {response}\033[0m")
- if task_cache.get("vector_only_task"):
- response = await task_cache["vector_only_task"]
+ log.debug(f"ANSWER: {response}")
+ if async_tasks.get("vector_only_task"):
+ response = await async_tasks["vector_only_task"]
context["vector_only_answer"] = response
if verbose:
- print(f"\033[91mANSWER: {response}\033[0m")
- if task_cache.get("graph_only_task"):
- response = await task_cache["graph_only_task"]
+ log.debug(f"ANSWER: {response}")
+ if async_tasks.get("graph_only_task"):
+ response = await async_tasks["graph_only_task"]
context["graph_only_answer"] = response
if verbose:
- print(f"\033[91mANSWER: {response}\033[0m")
- if task_cache.get("graph_vector_task"):
- response = await task_cache["graph_vector_task"]
+ log.debug(f"ANSWER: {response}")
+ if async_tasks.get("graph_vector_task"):
+ response = await async_tasks["graph_vector_task"]
context["graph_vector_answer"] = response
if verbose:
- print(f"\033[91mANSWER: {response}\033[0m")
+ log.debug(f"ANSWER: {response}")
ops = sum([self._raw_answer, self._vector_only_answer,
self._graph_only_answer, self._graph_vector_answer])
context['call_count'] = context.get('call_count', 0) + ops
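The rename from task_cache to async_tasks also makes the concurrency pattern easier to read: asyncio.create_task() starts every enabled generation immediately, so the later awaits overlap the LLM calls instead of serializing them. A toy, self-contained illustration, where agenerate is a stand-in for self._llm.agenerate:

    import asyncio

    async def agenerate(prompt: str) -> str:
        await asyncio.sleep(0.1)  # pretend this is a slow LLM call
        return f"answer to: {prompt}"

    async def main() -> None:
        # Tasks start running as soon as they are created...
        async_tasks = {
            "raw_task": asyncio.create_task(agenerate("raw prompt")),
            "graph_only_task": asyncio.create_task(agenerate("graph prompt")),
        }
        # ...so these awaits finish in roughly 0.1s total, not 0.2s.
        for name, task in async_tasks.items():
            print(name, "->", await task)

    asyncio.run(main())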
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index 2cad98d..cdfa6b5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -83,6 +83,7 @@ class KeywordExtract:
response=response, lowercase=False, start_token="KEYWORDS:"
)
keywords.union(self._expand_synonyms(keywords=keywords))
+ keywords = {k.replace("'", "") for k in keywords}
context["keywords"] = list(keywords)
verbose = context.get("verbose") or False
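The one-line keyword_extract.py change ties back to the quoting fix: extracted keywords are later embedded inside single-quoted Gremlin string literals, so a stray apostrophe would leave an unbalanced literal. A small illustration:

    keywords = {"Tom's", "Jerry"}
    keywords = {k.replace("'", "") for k in keywords}
    vids_str = ",".join(f"'{vid}'" for vid in sorted(keywords))
    print(vids_str)  # 'Jerry','Toms' (no unbalanced quotes inside the literals)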