This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch graphdata
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git

commit 0b86093ce5dd42fcbdde9d6e1bca4e5ad83cf3bf
Author: imbajin <[email protected]>
AuthorDate: Thu Sep 5 16:25:42 2024 +0800

    refact(llm): enhance the graph/gremlin query phase
---
 hugegraph-llm/README.md                            |  11 +-
 .../src/hugegraph_llm/demo/rag_web_demo.py         |   4 +-
 .../src/hugegraph_llm/operators/graph_rag_task.py  |  12 +-
 .../operators/hugegraph_op/graph_rag_query.py      | 139 ++++++++++-----------
 4 files changed, 78 insertions(+), 88 deletions(-)

diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index 25de21b..8df48bb 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -130,22 +130,19 @@ The methods of the `KgBuilder` class can be chained 
together to perform a sequen
 
 Run example like `python3 ./hugegraph_llm/examples/graph_rag_test.py`
 
-The `GraphRAG` class is used to integrate HugeGraph with large language models 
to provide retrieval-augmented generation capabilities.
+The `RAGPipeline` class is used to integrate HugeGraph with large language 
models to provide retrieval-augmented generation capabilities.
 Here is a brief usage guide:
 
 1. **Extract Keyword:**: Extract keywords and expand synonyms.
-    
+
     ```python
-    graph_rag.extract_keyword(text="Tell me about Al Pacino.").print_result()
+    graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result()
     ```
 
 2. **Query Graph for Rag**: Retrieve the corresponding keywords and their 
multi-degree associated relationships from HugeGraph.
 
      ```python
-     graph_rag.query_graph_for_rag(
-        max_deep=2,
-        max_items=30
-     ).print_result()
+     graph_rag.query_graph_db(max_deep=2, max_items=30).print_result()
      ```
 3. **Synthesize Answer**: Summarize the results and organize the language to 
answer the question.
 
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index c4c68c0..3cedc50 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -81,9 +81,9 @@ def rag_answer(
         return "", "", "", ""
     searcher = RAGPipeline()
     if vector_search:
-        searcher.query_vector_index_for_rag()
+        searcher.query_vector_index()
     if graph_search:
-        searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
+        searcher.extract_keywords().keywords_to_vid().query_graph_db()
     # TODO: add more user-defined search strategies
     searcher.merge_dedup_rerank(
         graph_ratio, rerank_method, near_neighbor_first, 
custom_related_information
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py 
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 91bc7b3..c464af2 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -70,7 +70,7 @@ class RAGPipeline:
         )
         return self
 
-    def extract_keyword(
+    def extract_keywords(
             self,
             text: Optional[str] = None,
             max_keywords: int = 5,
@@ -99,7 +99,7 @@ class RAGPipeline:
         )
         return self
 
-    def match_keyword_to_id(
+    def keywords_to_vid(
         self,
         by: Literal["query", "keywords"] = "keywords",
         topk_per_keyword: int = 1,
@@ -108,6 +108,8 @@ class RAGPipeline:
         """
         Add a semantic ID query operator to the pipeline.
 
+        :param topk_per_query: Top K results per query.
+        :param by: Match by query or keywords.
         :param topk_per_keyword: Top K results per keyword.
         :return: Self-instance for chaining.
         """
@@ -121,7 +123,7 @@ class RAGPipeline:
         )
         return self
 
-    def query_graph_for_rag(
+    def query_graph_db(
         self,
         max_deep: int = 2,
         max_items: int = 30,
@@ -144,7 +146,7 @@ class RAGPipeline:
         )
         return self
 
-    def query_vector_index_for_rag(self, max_items: int = 3):
+    def query_vector_index(self, max_items: int = 3):
         """
         Add a vector index query operator to the pipeline.
 
@@ -230,7 +232,7 @@ class RAGPipeline:
         :return: Final context after all operators have been executed.
         """
         if len(self._operators) == 0:
-            self.extract_keyword().query_graph_for_rag().synthesize_answer()
+            self.extract_keywords().query_graph_db().synthesize_answer()
 
         context = kwargs
         context["llm"] = self._llm
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index fe225c2..1cb7c17 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -23,61 +23,59 @@ from hugegraph_llm.config import settings
 from pyhugegraph.client import PyHugeClient
 
 
-class GraphRAGQuery:
-    VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()"
-    # ID_RAG_GREMLIN_QUERY_TEMPL = 
"g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
-    # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 
'props').by(label()).by(id()).by(valueMap().by(
-    # unfold()))).by(project('label', 'inV', 'outV', 
'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
-    # ).by(unfold()))).limit({max_items}).toList()"
+VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
+# ID_RAG_GREMLIN_QUERY_TEMPL = 
"g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
+# ).as('obj')).times({max_deep}).path().by(project('label', 'id', 
'props').by(label()).by(id()).by(valueMap().by(
+# unfold()))).by(project('label', 'inV', 'outV', 
'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
+# ).by(unfold()))).limit({max_items}).toList()"
+
+# TODO: we could use a simpler query (like kneighbor-api to get the edges)
+# TODO: use dedup() to filter duplicate paths
+ID_QUERY_NEIGHBOR_TPL = """
+g.V({keywords}).as('subj')
+.repeat(
+   bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'id', 'props')
+   .by(label())
+   .by(id())
+   .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+   .by(label())
+   .by(inV().id())
+   .by(outV().id())
+   .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    # TODO: we could use a simpler query (like kneighbor-api to get the edges)
-    ID_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().hasId({keywords}).as('subj')
-    .repeat(
-       bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'id', 'props')
-       .by(label())
-       .by(id())
-       .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-       .by(label())
-       .by(inV().id())
-       .by(outV().id())
-       .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
+PROPERTY_QUERY_NEIGHBOR_TPL = """
+g.V().has('{prop}', within({keywords})).as('subj')
+.repeat(
+   bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'props')
+   .by(label())
+   .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+   .by(label())
+   .by(inV().values('{prop}'))
+   .by(outV().values('{prop}'))
+   .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    PROP_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().has('{prop}', within({keywords})).as('subj')
-    .repeat(
-       bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'props')
-       .by(label())
-       .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-       .by(label())
-       .by(inV().values('{prop}'))
-       .by(outV().values('{prop}'))
-       .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
 
-    def __init__(
-        self,
-        max_deep: int = 2,
-        max_items: int = 30,
-        prop_to_match: Optional[str] = None,
-    ):
+class GraphRAGQuery:
+
+    def __init__(self, max_deep: int = 2, max_items: int = 30, prop_to_match: 
Optional[str] = None):
         self._client = PyHugeClient(
             settings.graph_ip,
             settings.graph_port,
@@ -119,36 +117,33 @@ class GraphRAGQuery:
         edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
 
         use_id_to_match = self._prop_to_match is None
-
         if not use_id_to_match:
             assert keywords is not None, "No keywords for graph query."
             keywords_str = ",".join("'" + kw + "'" for kw in keywords)
-            rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
+            gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format(
                 prop=self._prop_to_match,
                 keywords=keywords_str,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_knowledge_from_query_result(
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_graph_from_query_result(
                 query_result=result
             )
         else:
             assert entrance_vids is not None, "No entrance vertices for query."
-            rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
-                keywords=entrance_vids,
-            )
-            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            vertex_knowledge = 
self._format_knowledge_from_vertex(query_result=result)
-            rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
+            gremlin_query = VERTEX_QUERY_TPL.format(keywords=entrance_vids)
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            vertex_knowledge = 
self._format_graph_from_vertex(query_result=result)
+            gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
                 keywords=entrance_vids,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_knowledge_from_query_result(
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_graph_from_query_result(
                 query_result=result
             )
             graph_chain_knowledge.update(vertex_knowledge)
@@ -172,7 +167,7 @@ class GraphRAGQuery:
 
         return context
 
-    def _format_knowledge_from_vertex(self, query_result: List[Any]) -> 
Set[str]:
+    def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
         knowledge = set()
         for item in query_result:
             props_str = ", ".join(f"{k}: {v}" for k, v in 
item["properties"].items())
@@ -180,8 +175,8 @@ class GraphRAGQuery:
             knowledge.add(node_str)
         return knowledge
 
-    def _format_knowledge_from_query_result(
-        self, query_result: List[Any]
+    def _format_graph_from_query_result(
+            self, query_result: List[Any]
     ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
         use_id_to_match = self._prop_to_match is None
         knowledge = set()
@@ -234,18 +229,14 @@ class GraphRAGQuery:
     def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
         schema = self._get_graph_schema()
         node_props_str, edge_props_str = schema.split("\n")[:2]
-        node_props_str = node_props_str[len("Node properties: ") 
:].strip("[").strip("]")
-        edge_props_str = edge_props_str[len("Edge properties: ") 
:].strip("[").strip("]")
+        node_props_str = node_props_str[len("Node properties: 
"):].strip("[").strip("]")
+        edge_props_str = edge_props_str[len("Edge properties: 
"):].strip("[").strip("]")
         node_labels = self._extract_label_names(node_props_str)
         edge_labels = self._extract_label_names(edge_props_str)
         return node_labels, edge_labels
 
     @staticmethod
-    def _extract_label_names(
-        source: str,
-        head: str = "name: ",
-        tail: str = ", ",
-    ) -> List[str]:
+    def _extract_label_names(source: str, head: str = "name: ", tail: str = ", 
") -> List[str]:
         result = []
         for s in source.split(head):
             end = s.find(tail)

Reply via email to