This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a8a379  refactor: optimize the process of graph rag (#101)
6a8a379 is described below

commit 6a8a379f9717b2f2f7ff17c9c3385b31ec1ce87c
Author: chenzihong <[email protected]>
AuthorDate: Fri Nov 15 00:34:08 2024 +0800

    refactor: optimize the process of graph rag (#101)
    
    - [x] replace print with log
    - [x] strip single quotes from keywords
    - [x] handle an empty keywords list as a special case
    
    * fix: an empty keywords list led to `g.V()` querying all vertices
    * remove ANSI color codes from log strings
    ---------
    
    Co-authored-by: imbajin <[email protected]>
---
 .../operators/hugegraph_op/graph_rag_query.py      | 46 ++++++++++++++--------
 .../operators/index_op/semantic_id_query.py        | 31 +++++++++------
 .../operators/llm_op/answer_synthesize.py          | 39 +++++++++---------
 .../operators/llm_op/keyword_extract.py            |  1 +
 4 files changed, 70 insertions(+), 47 deletions(-)

diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 9fc6cd7..d15a8c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -20,7 +20,8 @@ from hugegraph_llm.config import settings
 from hugegraph_llm.utils.log import log
 from pyhugegraph.client import PyHugeClient
 
-VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
+# TODO: remove 'as('subj)' step
+VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()"
 
 # TODO: we could use a simpler query (like kneighbor-api to get the edges)
 # TODO: test with profile()/explain() to speed up the query
@@ -137,7 +138,6 @@ class GraphRAGQuery:
             graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_graph_query_result(
                 query_paths=paths
             )
-            graph_chain_knowledge.update(vertex_knowledge)
             if vertex_degree_list:
                 vertex_degree_list[0].update(vertex_knowledge)
             else:
@@ -170,12 +170,9 @@ class GraphRAGQuery:
             "`vertexA --[links]--> vertexB <--[links]-- vertexC ...`"
             "extracted based on key entities as subject:\n"
         )
-
-        # TODO: replace print to log
-        verbose = context.get("verbose") or False
-        if verbose:
-            print("\033[93mKnowledge from Graph:")
-            print("\n".join(context["graph_result"]) + "\033[0m")
+        # TODO: set color for ↓ "\033[93mKnowledge from Graph:\033[0m"
+        log.debug("Knowledge from Graph:")
+        log.debug("\n".join(context["graph_result"]))
 
         return context
 
@@ -192,10 +189,12 @@ class GraphRAGQuery:
         subgraph = set()
         subgraph_with_degree = {}
         vertex_degree_list: List[Set[str]] = []
+        v_cache: Set[str] = set()
+        e_cache: Set[str] = set()
 
         for path in query_paths:
             # 1. Process each path
-            flat_rel, nodes_with_degree = self._process_path(path, 
use_id_to_match)
+            flat_rel, nodes_with_degree = self._process_path(path, 
use_id_to_match, v_cache, e_cache)
             subgraph.add(flat_rel)
             subgraph_with_degree[flat_rel] = nodes_with_degree
             # 2. Update vertex degree list
@@ -203,7 +202,8 @@ class GraphRAGQuery:
 
         return subgraph, vertex_degree_list, subgraph_with_degree
 
-    def _process_path(self, path: Any, use_id_to_match: bool) -> Tuple[str, 
List[str]]:
+    def _process_path(self, path: Any, use_id_to_match: bool, v_cache: 
Set[str],
+                      e_cache: Set[str]) -> Tuple[str, List[str]]:
         flat_rel = ""
         raw_flat_rel = path["objects"]
         assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should 
be odd."
@@ -217,19 +217,20 @@ class GraphRAGQuery:
             if i % 2 == 0:
                 # Process each vertex
                 flat_rel, prior_edge_str_len, depth = self._process_vertex(
-                    item, flat_rel, node_cache, prior_edge_str_len, depth, 
nodes_with_degree, use_id_to_match
+                    item, flat_rel, node_cache, prior_edge_str_len, depth, 
nodes_with_degree, use_id_to_match,
+                    v_cache
                 )
             else:
                 # Process each edge
                 flat_rel, prior_edge_str_len = self._process_edge(
-                    item, flat_rel, prior_edge_str_len, raw_flat_rel, 
i,use_id_to_match
+                    item, flat_rel, prior_edge_str_len, raw_flat_rel, 
i,use_id_to_match, e_cache
                 )
 
         return flat_rel, nodes_with_degree
 
     def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
                         prior_edge_str_len: int, depth: int, 
nodes_with_degree: List[str],
-                        use_id_to_match: bool) -> Tuple[str, int, int]:
+                        use_id_to_match: bool, v_cache: Set[str]) -> 
Tuple[str, int, int]:
         matched_str = item["id"] if use_id_to_match else 
item["props"][self._prop_to_match]
         if matched_str in node_cache:
             flat_rel = flat_rel[:-prior_edge_str_len]
@@ -237,7 +238,12 @@ class GraphRAGQuery:
 
         node_cache.add(matched_str)
         props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
-        node_str = f"{item['id']}{{{props_str}}}"
+        # TODO: we may remove label id or replace with label name
+        if matched_str in v_cache:
+            node_str = matched_str
+        else:
+            v_cache.add(matched_str)
+            node_str = f"{item['id']}{{{props_str}}}"
         flat_rel += node_str
         nodes_with_degree.append(node_str)
         depth += 1
@@ -245,16 +251,22 @@ class GraphRAGQuery:
         return flat_rel, prior_edge_str_len, depth
 
     def _process_edge(self, item: Any, flat_rel: str, prior_edge_str_len: int,
-                      raw_flat_rel: List[Any], i: int, use_id_to_match: bool) 
-> Tuple[str, int]:
+                      raw_flat_rel: List[Any], i: int, use_id_to_match: bool, 
e_cache: Set[str]) -> Tuple[str, int]:
         props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
         props_str = f"{{{props_str}}}" if len(props_str) > 0 else ""
         prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else (
             raw_flat_rel)[i - 1]["props"][self._prop_to_match]
 
+        if item["label"] in e_cache:
+            edge_str = f"{item['label']}"
+        else:
+            e_cache.add(item["label"])
+            edge_str = f"{item['label']}{props_str}"
+
         if item["outV"] == prev_matched_str:
-            edge_str = f" --[{item['label']}{props_str}]--> "
+            edge_str = f" --[{edge_str}]--> "
         else:
-            edge_str = f" <--[{item['label']}{props_str}]-- "
+            edge_str = f" <--[{edge_str}]-- "
 
         flat_rel += edge_str
         prior_edge_str_len = len(edge_str)
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index 04804e6..d7b5b89 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -17,7 +17,6 @@
 
 
 import os
-from copy import deepcopy
 from typing import Dict, Any, Literal, List, Tuple
 
 from hugegraph_llm.config import resource_path, settings
@@ -28,7 +27,8 @@ from pyhugegraph.client import PyHugeClient
 
 
 class SemanticIdQuery:
-    ID_QUERY_TEMPL = "g.V({vids_str})"
+    ID_QUERY_TEMPL = "g.V({vids_str}).limit(8)"
+
     def __init__(
             self,
             embedding: BaseEmbedding,
@@ -52,31 +52,39 @@ class SemanticIdQuery:
         )
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        graph_query_list = []
+        graph_query_list = set()
         if self.by == "query":
             query = context["query"]
             query_vector = self.embedding.get_text_embedding(query)
             results = self.vector_index.search(query_vector, 
top_k=self.topk_per_query)
             if results:
-                graph_query_list.extend(results[:self.topk_per_query])
+                graph_query_list.update(results[:self.topk_per_query])
         else:  # by keywords
-            exact_match_vids, unmatched_vids = 
self._exact_match_vids(context["keywords"])
-            graph_query_list.extend(exact_match_vids)
+            keywords = context.get("keywords", [])
+            if not keywords:
+                context["match_vids"] = []
+                return context
+
+            exact_match_vids, unmatched_vids = self._exact_match_vids(keywords)
+            graph_query_list.update(exact_match_vids)
             fuzzy_match_vids = self._fuzzy_match_vids(unmatched_vids)
             log.debug("Fuzzy match vids: %s", fuzzy_match_vids)
-            graph_query_list.extend(fuzzy_match_vids)
-        context["match_vids"] = list(set(graph_query_list))
+            graph_query_list.update(fuzzy_match_vids)
+        context["match_vids"] = list(graph_query_list)
         return context
 
     def _exact_match_vids(self, keywords: List[str]) -> Tuple[List[str], 
List[str]]:
+        assert keywords, "keywords can't be empty, please check the logic"
+        # TODO: we should add a global GraphSchemaCache to avoid calling the 
server every time
         vertex_label_num = len(self._client.schema().getVertexLabels())
-        possible_vids = deepcopy(keywords)
+        possible_vids = set(keywords)
         for i in range(vertex_label_num):
-            possible_vids.extend([f"{i+1}:{keyword}" for keyword in keywords])
+            possible_vids.update([f"{i + 1}:{keyword}" for keyword in 
keywords])
 
         vids_str = ",".join([f"'{vid}'" for vid in possible_vids])
         resp = 
self._client.gremlin().exec(SemanticIdQuery.ID_QUERY_TEMPL.format(vids_str=vids_str))
         searched_vids = [v['id'] for v in resp['data']]
+
         unsearched_keywords = set(keywords)
         for vid in searched_vids:
             for keyword in unsearched_keywords:
@@ -91,5 +99,6 @@ class SemanticIdQuery:
             keyword_vector = self.embedding.get_text_embedding(keyword)
             results = self.vector_index.search(keyword_vector, 
top_k=self.topk_per_keyword)
             if results:
+                # FIXME: type mismatch, got 'list[dict[str, Any]]' instead
                 fuzzy_match_result.extend(results[:self.topk_per_keyword])
-        return fuzzy_match_result # FIXME: type mismatch, got 'list[dict[str, 
Any]]' instead
+        return fuzzy_match_result
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index baf61e6..0f894cb 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -105,25 +105,26 @@ class AnswerSynthesize:
                              graph_result_context: str):
         # pylint: disable=R0912 (too-many-branches)
         verbose = context.get("verbose") or False
-        # TODO: replace task_cache with a better name
-        task_cache = {}
+
+        # async_tasks stores the async tasks for different answer types
+        async_tasks = {}
         if self._raw_answer:
             final_prompt = self._question
-            task_cache["raw_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+            async_tasks["raw_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._vector_only_answer:
             context_str = (f"{context_head_str}\n"
                            f"{vector_result_context}\n"
                            f"{context_tail_str}".strip("\n"))
 
             final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
-            task_cache["vector_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+            async_tasks["vector_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._graph_only_answer:
             context_str = (f"{context_head_str}\n"
                            f"{graph_result_context}\n"
                            f"{context_tail_str}".strip("\n"))
 
             final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
-            task_cache["graph_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
+            async_tasks["graph_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._graph_vector_answer:
             context_body_str = 
f"{vector_result_context}\n{graph_result_context}"
             if context.get("graph_ratio", 0.5) < 0.5:
@@ -133,30 +134,30 @@ class AnswerSynthesize:
                            f"{context_tail_str}".strip("\n"))
 
             final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
-            task_cache["graph_vector_task"] = asyncio.create_task(
+            async_tasks["graph_vector_task"] = asyncio.create_task(
                 self._llm.agenerate(prompt=final_prompt)
             )
-        # TODO: use log.debug instead of print
-        if task_cache.get("raw_task"):
-            response = await task_cache["raw_task"]
+
+        if async_tasks.get("raw_task"):
+            response = await async_tasks["raw_task"]
             context["raw_answer"] = response
             if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
-        if task_cache.get("vector_only_task"):
-            response = await task_cache["vector_only_task"]
+                log.debug(f"ANSWER: {response}")
+        if async_tasks.get("vector_only_task"):
+            response = await async_tasks["vector_only_task"]
             context["vector_only_answer"] = response
             if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
-        if task_cache.get("graph_only_task"):
-            response = await task_cache["graph_only_task"]
+                log.debug(f"ANSWER: {response}")
+        if async_tasks.get("graph_only_task"):
+            response = await async_tasks["graph_only_task"]
             context["graph_only_answer"] = response
             if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
-        if task_cache.get("graph_vector_task"):
-            response = await task_cache["graph_vector_task"]
+                log.debug(f"ANSWER: {response}")
+        if async_tasks.get("graph_vector_task"):
+            response = await async_tasks["graph_vector_task"]
             context["graph_vector_answer"] = response
             if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
+                log.debug(f"ANSWER: {response}")
 
         ops = sum([self._raw_answer, self._vector_only_answer, 
self._graph_only_answer, self._graph_vector_answer])
         context['call_count'] = context.get('call_count', 0) + ops
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index 2cad98d..cdfa6b5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -83,6 +83,7 @@ class KeywordExtract:
             response=response, lowercase=False, start_token="KEYWORDS:"
         )
         keywords.union(self._expand_synonyms(keywords=keywords))
+        keywords = {k.replace("'", "") for k in keywords}
         context["keywords"] = list(keywords)
 
         verbose = context.get("verbose") or False

Reply via email to