This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new d6d7990  feat(llm): handle 'island nodes' extraction in 2-step graph 
queries and add asynchronous methods to the four types of generation functions 
in the rag web demo. (#58)
d6d7990 is described below

commit d6d799015f3469f401f36a2379a6732b2a71ac6f
Author: vichayturen <[email protected]>
AuthorDate: Mon Aug 12 14:24:28 2024 +0800

    feat(llm): handle 'island nodes' extraction in 2-step graph queries and add 
asynchronous methods to the four types of generation functions in the rag web 
demo. (#58)
    
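    1. Fix some bugs and prompts.
    2. Add a two-stage graph RAG query: first fetch the entrance vertices
       themselves (so isolated "island" nodes are still returned), then
       traverse their neighborhoods to gather more context.
    3. Add an optional `topk_per_keyword` parameter for matching keywords to
       vertex IDs.
    4. Add asynchronous methods to the four types of generation functions in
       the RAG web demo.
    5. Rename the build-mode options and UI labels, and add optional JSON
       formatting for Gremlin query output.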
    
    ---------
    
    Co-authored-by: imbajin <[email protected]>
---
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  92 ++++++++--------
 .../src/hugegraph_llm/enums/id_strategy.py         |   1 +
 .../src/hugegraph_llm/indices/graph_index.py       |   4 +-
 .../src/hugegraph_llm/models/embeddings/base.py    |   7 ++
 .../src/hugegraph_llm/models/embeddings/ollama.py  |   9 ++
 .../src/hugegraph_llm/models/embeddings/openai.py  |   5 +
 .../src/hugegraph_llm/models/embeddings/qianfan.py |   8 ++
 .../src/hugegraph_llm/models/llms/base.py          |   8 ++
 .../src/hugegraph_llm/models/llms/ollama.py        |  21 ++++
 .../src/hugegraph_llm/models/llms/openai.py        |  30 ++++++
 .../src/hugegraph_llm/models/llms/qianfan.py       |  17 +++
 .../operators/common_op/check_schema.py            |   3 +-
 .../operators/common_op/merge_dedup_rerank.py      |   1 +
 .../src/hugegraph_llm/operators/graph_rag_task.py  |   3 +-
 .../operators/hugegraph_op/graph_rag_query.py      | 117 ++++++++++++---------
 .../operators/index_op/build_semantic_index.py     |  11 +-
 .../operators/index_op/semantic_id_query.py        |   9 +-
 .../operators/llm_op/answer_synthesize.py          |  66 ++++++++----
 .../operators/llm_op/disambiguate_data.py          |   4 +-
 .../hugegraph_llm/operators/llm_op/info_extract.py |   3 +-
 .../operators/llm_op/property_graph_extract.py     |  10 +-
 .../src/hugegraph_llm/utils/hugegraph_utils.py     |   6 +-
 style/pylint.conf                                  |   2 +-
 23 files changed, 303 insertions(+), 134 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 08a0d2d..01bc85c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -50,10 +50,12 @@ def convert_bool_str(string):
     raise gr.Error(f"Invalid boolean string: {string}")
 
 
+# TODO: enhance/distinguish the "graph_rag" name to avoid confusion
 def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
               graph_only_answer: str, graph_vector_answer):
     vector_search = convert_bool_str(vector_only_answer) or 
convert_bool_str(graph_vector_answer)
     graph_search = convert_bool_str(graph_only_answer) or 
convert_bool_str(graph_vector_answer)
+
     if raw_answer == "false" and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
@@ -68,6 +70,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: 
str,
         graph_only_answer=convert_bool_str(graph_only_answer),
         graph_vector_answer=convert_bool_str(graph_vector_answer)
     ).run(verbose=True, query=text)
+
     try:
         context = searcher.run(verbose=True, query=text)
         return (
@@ -76,9 +79,12 @@ def graph_rag(text: str, raw_answer: str, 
vector_only_answer: str,
             context.get("graph_only_answer", ""),
             context.get("graph_vector_answer", "")
         )
-    except Exception as e:  # pylint: disable=broad-exception-caught
+    except ValueError as e:
         log.error(e)
         raise gr.Error(str(e))
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        log.error(e)
+        raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
 def build_kg(file, schema, example_prompt, build_mode):  # pylint: 
disable=too-many-branches
@@ -93,10 +99,11 @@ def build_kg(file, schema, example_prompt, build_mode):  # 
pylint: disable=too-m
             text += para.text
             text += "\n"
     elif full_path.endswith(".pdf"):
-        raise gr.Error("PDF will be supported later!")
+        raise gr.Error("PDF will be supported later! Try to upload text/docx 
now")
     else:
         raise gr.Error("Please input txt or docx file.")
     builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), 
get_hg_client())
+
     if build_mode != "Rebuild vertex index":
         if schema:
             try:
@@ -108,19 +115,22 @@ def build_kg(file, schema, example_prompt, build_mode):  
# pylint: disable=too-m
         else:
             return "ERROR: please input schema."
     builder.chunk_split(text, "paragraph", "zh")
-    if build_mode == "Rebuild vertex index":
+
+    # TODO: avoid hardcoding the "build_mode" strings (use var/constant 
instead)
+    if build_mode == "Rebuild Vector":
         builder.fetch_graph_data()
     else:
         builder.extract_info(example_prompt, "property_graph")
-    if build_mode != "Test":
-        if build_mode in ("Clear and import", "Rebuild vertex index"):
+    # "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
+    if build_mode != "Test Mode":
+        if build_mode in ("Clear and Import", "Rebuild Vector"):
             clean_vector_index()
         builder.build_vector_index()
-    if build_mode == "Clear and import":
+    if build_mode == "Clear and Import":
         clean_hg_data()
-    if build_mode in ("Clear and import", "Import"):
+    if build_mode in ("Clear and Import", "Import Mode"):
         builder.commit_to_hugegraph()
-    if build_mode != "Test":
+    if build_mode != "Test Mode":
         builder.build_vertex_id_semantic_index()
     log.debug(builder.operators)
     try:
@@ -319,22 +329,18 @@ if __name__ == "__main__":
 
 
         gr.Markdown(
-            """## 1. build knowledge graph
+            """## 1. Build vector/graph RAG (💡)
 - Document: Input document file which should be TXT or DOCX.
 - Schema: Accepts two types of text as below:
     - User-defined JSON format Schema.
-    - Specify the name of the HugeGraph graph instance, and it will
-    automatically extract the schema of the graph.
+    - Specify the name of the HugeGraph graph instance, it will automatically 
get the schema from it.
 - Info extract head: The head of prompt of info extracting.
 - Build mode: 
-    - Test: Only extract vertices and edges from file without building vector 
index or 
-    importing into HugeGraph.
-    - Clear and Import: Clear the vector index and data of HugeGraph and then 
extract and 
-    import new data.
-    - Import: Extract the data and append it to HugeGraph and vector index 
without clearing 
-    anything.
-    - Rebuild vertex index: Do not clear the HugeGraph data, but only clear 
vector index 
-    and build new one.
+    - Test Mode: Only extract vertices and edges from the file into memory 
(without building the vector index or 
+    writing data into HugeGraph)
+    - Import Mode: Extract the data and append it to HugeGraph & the vector 
index (without clearing any existing data)
+    - Clear and Import: Clear all existed RAG data(vector + graph), then 
rebuild them from the current input
+    - Rebuild Vector: Only rebuild vector index. (keep the graph data intact)
 """
         )
 
@@ -380,10 +386,9 @@ if __name__ == "__main__":
             info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT,
                                                label="Info extract head")
             with gr.Column():
-                mode = gr.Radio(choices=["Test", "Clear and import", "Import",
-                                         "Rebuild vertex index"],
-                                value="Test", label="Build mode")
-                btn = gr.Button("Build knowledge graph")
+                mode = gr.Radio(choices=["Test Mode", "Import Mode", "Clear 
and Import", "Rebuild Vector"],
+                                value="Test Mode", label="Build mode")
+                btn = gr.Button("Build Vector/Graph RAG")
         with gr.Row():
             out = gr.Textbox(label="Output", show_copy_button=True)
         btn.click(  # pylint: disable=no-member
@@ -392,40 +397,43 @@ if __name__ == "__main__":
             outputs=out
         )
 
-        gr.Markdown("""## 2. Retrieval augmented generation by hugegraph""")
+        gr.Markdown("""## 2. RAG with HugeGraph 📖""")
         with gr.Row():
             with gr.Column(scale=2):
-                inp = gr.Textbox(value="Tell me about Sarah.", 
label="Question")
-                raw_out = gr.Textbox(label="Raw LLM Answer", 
show_copy_button=True)
-                vector_only_out = gr.Textbox(label="Vector-only answer", 
show_copy_button=True)
-                graph_only_out = gr.Textbox(label="Graph-only answer", 
show_copy_button=True)
-                graph_vector_out = gr.Textbox(label="Graph-Vector answer", 
show_copy_button=True)
+                inp = gr.Textbox(value="Tell me about Sarah.", 
label="Question", show_copy_button=True)
+                raw_out = gr.Textbox(label="Basic LLM Answer", 
show_copy_button=True)
+                vector_only_out = gr.Textbox(label="Vector-only Answer", 
show_copy_button=True)
+                graph_only_out = gr.Textbox(label="Graph-only Answer", 
show_copy_button=True)
+                graph_vector_out = gr.Textbox(label="Graph-Vector Answer", 
show_copy_button=True)
             with gr.Column(scale=1):
                 raw_radio = gr.Radio(choices=["true", "false"], value="false",
-                                     label="Raw LLM answer")
+                                     label="Basic LLM Answer")
                 vector_only_radio = gr.Radio(choices=["true", "false"], 
value="true",
-                                             label="Vector-only answer")
+                                             label="Vector-only Answer")
                 graph_only_radio = gr.Radio(choices=["true", "false"], 
value="false",
-                                            label="Graph-only answer")
+                                            label="Graph-only Answer")
                 graph_vector_radio = gr.Radio(choices=["true", "false"], 
value="false",
-                                              label="Graph-Vector answer")
-                btn = gr.Button("Retrieval augmented generation")
-        btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, 
graph_only_radio,  # pylint: disable=no-member
+                                              label="Graph-Vector Answer")
+                btn = gr.Button("Answer Question")
+        btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, 
graph_only_radio, # pylint: disable=no-member
                                         graph_vector_radio],
                   outputs=[raw_out, vector_only_out, graph_only_out, 
graph_vector_out])
 
-        gr.Markdown("""## 3. Others """)
+        gr.Markdown("""## 3. Others (🚧) """)
         with gr.Row():
-            inp = []
+            with gr.Column():
+                inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin 
query", show_copy_button=True)
+                format = gr.Checkbox(label="Format JSON", value=True)
             out = gr.Textbox(label="Output", show_copy_button=True)
-        btn = gr.Button("Initialize HugeGraph test data")
-        btn.click(fn=init_hg_test_data, inputs=inp, outputs=out)  # pylint: 
disable=no-member
+        btn = gr.Button("Run gremlin query on HugeGraph")
+        btn.click(fn=run_gremlin_query, inputs=[inp, format], outputs=out)  # 
pylint: disable=no-member
 
         with gr.Row():
-            inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query")
+            inp = []
             out = gr.Textbox(label="Output", show_copy_button=True)
-        btn = gr.Button("Run gremlin query on HugeGraph")
-        btn.click(fn=run_gremlin_query, inputs=inp, outputs=out)  # pylint: 
disable=no-member
+        btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)")
+        btn.click(fn=init_hg_test_data, inputs=inp, outputs=out)  # pylint: 
disable=no-member
+
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
     # Note: set reload to False in production environment
     uvicorn.run(app, host=args.host, port=args.port)
diff --git a/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py 
b/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py
index 8db4374..5f3cadb 100644
--- a/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py
+++ b/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py
@@ -19,6 +19,7 @@
 from enum import Enum
 
 
+# Note: we don't support the "UUID" strategy for now
 class IdStrategy(Enum):
     AUTOMATIC = "AUTOMATIC"
     CUSTOMIZE_NUMBER = "CUSTOMIZE_NUMBER"
diff --git a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py 
b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py
index 3961204..74269fc 100644
--- a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py
@@ -36,12 +36,12 @@ class GraphIndex:
     def clear_graph(self):
         self.client.gremlin().exec("g.V().drop()")
 
+    # TODO: replace triples with a more specific graph element type & 
implement it
     def add_triples(self, triples: list):
-        # TODO
         pass
 
+    # TODO: replace triples with a more specific graph element type & 
implement it
     def search_triples(self, max_deep: int = 2):
-        # TODO
         pass
 
     def execute_gremlin_query(self, query: str):
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
index 15bc4ea..2ea8786 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -60,6 +60,13 @@ class BaseEmbedding(ABC):
     ) -> List[float]:
         """Comment"""
 
+    @abstractmethod
+    async def async_get_text_embedding(
+            self,
+            text: str
+    ) -> List[float]:
+        """Comment"""
+
     @staticmethod
     def similarity(
             embedding1: Union[List[float], np.ndarray],
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
index f87502c..81e11cc 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
@@ -32,6 +32,7 @@ class OllamaEmbedding(BaseEmbedding):
     ):
         self.model = model
         self.client = ollama.Client(host=f"http://{host}:{port}";, **kwargs)
+        self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}";, 
**kwargs)
         self.embedding_dimension = None
 
     def get_text_embedding(
@@ -40,3 +41,11 @@ class OllamaEmbedding(BaseEmbedding):
     ) -> List[float]:
         """Comment"""
         return list(self.client.embeddings(model=self.model, 
prompt=text)["embedding"])
+
+    async def async_get_text_embedding(
+            self,
+            text: str
+    ) -> List[float]:
+        """Comment"""
+        response = await self.async_client.embeddings(model=self.model, 
prompt=text)
+        return list(response["embedding"])
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
index 267effa..2a092e7 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -38,3 +38,8 @@ class OpenAIEmbedding:
         """Comment"""
         response = self.client.create(input=text, 
model=self.embedding_model_name)
         return response.data[0].embedding
+
+    async def async_get_text_embedding(self, text: str) -> List[float]:
+        """Comment"""
+        response = await self.client.acreate(input=text, 
model=self.embedding_model_name)
+        return response.data[0].embedding
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
index c86a920..2f41fe5 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
@@ -49,3 +49,11 @@ class QianFanEmbedding:
             texts=[text]
         )
         return response["body"]["data"][0]["embedding"]
+
+    async def async_get_text_embedding(self, text: str) -> List[float]:
+        """ Usage refer: 
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""";
+        response = await self.client.ado(
+            model=self.embedding_model_name,
+            texts=[text]
+        )
+        return response["body"]["data"][0]["embedding"]
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
index 0e05156..04c1c27 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
@@ -31,6 +31,14 @@ class BaseLLM(ABC):
     ) -> str:
         """Comment"""
 
+    @abstractmethod
+    async def agenerate(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+    ) -> str:
+        """Comment"""
+
     @abstractmethod
     def generate_streaming(
         self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index f94e268..5965599 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -29,6 +29,7 @@ class OllamaClient(BaseLLM):
     def __init__(self, model: str, host: str = "127.0.0.1", port: int = 11434, 
**kwargs):
         self.model = model
         self.client = ollama.Client(host=f"http://{host}:{port}";, **kwargs)
+        self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}";, 
**kwargs)
 
     @retry(tries=3, delay=1)
     def generate(
@@ -50,6 +51,26 @@ class OllamaClient(BaseLLM):
             print(f"Retrying LLM call {e}")
             raise e
 
+    @retry(tries=3, delay=1)
+    async def agenerate(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+    ) -> str:
+        """Comment"""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            response = await self.async_client.chat(
+                model=self.model,
+                messages=messages,
+            )
+            return response["message"]["content"]
+        except Exception as e:
+            print(f"Retrying LLM call {e}")
+            raise e
+
     def generate_streaming(
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 4f50e96..36cb11d 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -73,6 +73,36 @@ class OpenAIChat(BaseLLM):
             log.error("Retrying LLM call %s", e)
             raise e
 
+    @retry(tries=3, delay=1)
+    async def agenerate(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+    ) -> str:
+        """Generate a response to the query messages/prompt."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            completions = await openai.ChatCompletion.acreate(
+                model=self.model,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                messages=messages,
+            )
+            return completions.choices[0].message.content
+        # catch context length / do not retry
+        except openai.error.InvalidRequestError as e:
+            log.critical("Fatal: %s", e)
+            return str(f"Error: {e}")
+        # catch authorization errors / do not retry
+        except openai.error.AuthenticationError:
+            log.critical("The provided OpenAI API key is invalid")
+            return "Error: The provided OpenAI API key is invalid"
+        except Exception as e:
+            log.error("Retrying LLM call %s", e)
+            raise e
+
     def generate_streaming(
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
index bebfa1a..25d5e21 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -49,6 +49,23 @@ class QianfanClient(BaseLLM):
             )
         return response.body["result"]
 
+    @retry(tries=3, delay=1)
+    async def agenerate(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+    ) -> str:
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        response = await self.chat_comp.ado(model=self.chat_model, 
messages=messages)
+        if response.code != 200:
+            raise Exception(
+                f"Request failed with code {response.code}, message: 
{response.body['error_msg']}"
+            )
+        return response.body["result"]
+
     def generate_streaming(
             self,
             messages: Optional[List[Dict[str, Any]]] = None,
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py 
b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
index 616c7b1..7a1f64a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
@@ -31,7 +31,8 @@ class CheckSchema:
             raise ValueError("Input data is not a dictionary.")
         if "vertexlabels" not in schema or "edgelabels" not in schema:
             raise ValueError("Input data does not contain 'vertexlabels' or 
'edgelabels'.")
-        if not isinstance(schema["vertexlabels"], list) or not 
isinstance(schema["edgelabels"], list):
+        if not isinstance(schema["vertexlabels"], list) or not 
isinstance(schema["edgelabels"],
+                                                                          
list):
             raise ValueError("'vertexlabels' or 'edgelabels' in input data is 
not a list.")
         for vertex in schema["vertexlabels"]:
             if not isinstance(vertex, dict):
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py 
b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index a65a709..19ad4e4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -35,6 +35,7 @@ class MergeDedupRerank:
         self.topk = topk
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        # TODO: exact > fuzzy; vertex > 1-depth-neighbour > 2-depth-neighbour; 
priority vertices
         query = context.get("query")
 
         vector_result = context.get("vector_result", [])
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py 
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index e343fad..e3dcd99 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -58,10 +58,11 @@ class GraphRAG:
         )
         return self
 
-    def match_keyword_to_id(self):
+    def match_keyword_to_id(self, topk_per_keyword: int = 1):
         self._operators.append(
             SemanticIdQuery(
                 embedding=self._embedding,
+                topk_per_keyword=topk_per_keyword
             )
         )
         return self
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index b76637c..2c81ee0 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -24,45 +24,50 @@ from pyhugegraph.client import PyHugeClient
 
 
 class GraphRAGQuery:
-    ID_RAG_GREMLIN_QUERY_TEMPL = (
-        "g.V().hasId({keywords}).as('subj')"
-        ".repeat("
-        "   bothE({edge_labels}).as('rel').otherV().as('obj')"
-        ").times({max_deep})"
-        ".path()"
-        ".by(project('label', 'id', 'props')"
-        "   .by(label())"
-        "   .by(id())"
-        "   .by(valueMap().by(unfold()))"
-        ")"
-        ".by(project('label', 'inV', 'outV', 'props')"
-        "   .by(label())"
-        "   .by(inV().id())"
-        "   .by(outV().id())"
-        "   .by(valueMap().by(unfold()))"
-        ")"
-        ".limit({max_items})"
-        ".toList()"
+    VERTEX_GREMLIN_QUERY_TEMPL = (
+        "g.V().hasId({keywords}).as('subj').toList()"
     )
-    PROP_RAG_GREMLIN_QUERY_TEMPL = (
-        "g.V().has('{prop}', within({keywords})).as('subj')"
-        ".repeat("
-        "   bothE({edge_labels}).as('rel').otherV().as('obj')"
-        ").times({max_deep})"
-        ".path()"
-        ".by(project('label', 'props')"
-        "   .by(label())"
-        "   .by(valueMap().by(unfold()))"
-        ")"
-        ".by(project('label', 'inV', 'outV', 'props')"
-        "   .by(label())"
-        "   .by(inV().values('{prop}'))"
-        "   .by(outV().values('{prop}'))"
-        "   .by(valueMap().by(unfold()))"
-        ")"
-        ".limit({max_items})"
-        ".toList()"
+    # TODO: we could use a simpler query (like kneighbor-api to get the edges)
+    ID_RAG_GREMLIN_QUERY_TEMPL = """
+    g.V().hasId({keywords}).as('subj')
+    .repeat(
+       bothE({edge_labels}).as('rel').otherV().as('obj')
+    ).times({max_deep})
+    .path()
+    .by(project('label', 'id', 'props')
+       .by(label())
+       .by(id())
+       .by(valueMap().by(unfold()))
     )
+    .by(project('label', 'inV', 'outV', 'props')
+       .by(label())
+       .by(inV().id())
+       .by(outV().id())
+       .by(valueMap().by(unfold()))
+    )
+    .limit({max_items})
+    .toList()
+    """
+
+    PROP_RAG_GREMLIN_QUERY_TEMPL = """
+    g.V().has('{prop}', within({keywords})).as('subj')
+    .repeat(
+       bothE({edge_labels}).as('rel').otherV().as('obj')
+    ).times({max_deep})
+    .path()
+    .by(project('label', 'props')
+       .by(label())
+       .by(valueMap().by(unfold()))
+    )
+    .by(project('label', 'inV', 'outV', 'props')
+       .by(label())
+       .by(inV().values('{prop}'))
+       .by(outV().values('{prop}'))
+       .by(valueMap().by(unfold()))
+    )
+    .limit({max_items})
+    .toList()
+    """
 
     def __init__(
             self,
@@ -93,10 +98,10 @@ class GraphRAGQuery:
                 user = context.get("user") or "admin"
                 pwd = context.get("pwd") or "admin"
                 self._client = PyHugeClient(ip=ip, port=port, graph=graph, 
user=user, pwd=pwd)
-        assert self._client is not None, "No graph for query."
+        assert self._client is not None, "No valid graph to search."
 
         keywords = context.get("keywords")
-        assert keywords is not None, "No keywords for query."
+        assert keywords is not None, "No keywords for graph query."
         entrance_vids = context.get("entrance_vids")
         assert entrance_vids is not None, "No entrance vertices for query."
 
@@ -114,25 +119,29 @@ class GraphRAGQuery:
 
         if not use_id_to_match:
             keywords_str = ",".join("'" + kw + "'" for kw in keywords)
-            rag_gremlin_query_template = self.PROP_RAG_GREMLIN_QUERY_TEMPL
-            rag_gremlin_query = rag_gremlin_query_template.format(
+            rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
                 prop=self._prop_to_match,
                 keywords=keywords_str,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+            knowledge: Set[str] = 
self._format_knowledge_from_query_result(query_result=result)
         else:
-            rag_gremlin_query_template = self.ID_RAG_GREMLIN_QUERY_TEMPL
-            rag_gremlin_query = rag_gremlin_query_template.format(
+            rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
+                keywords=entrance_vids,
+            )
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+            knowledge: Set[str] = 
self._format_knowledge_from_vertex(query_result=result)
+            rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
                 keywords=entrance_vids,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-
-        result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-        knowledge: Set[str] = 
self._format_knowledge_from_query_result(query_result=result)
+            result: List[Any] = 
self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+            
knowledge.update(self._format_knowledge_from_query_result(query_result=result))
 
         context["graph_result"] = list(knowledge)
         context["synthesize_context_head"] = (
@@ -142,17 +151,23 @@ class GraphRAGQuery:
             "extracted based on key entities as subject:"
         )
 
+        # TODO: replace print to log
         verbose = context.get("verbose") or False
         if verbose:
-            print("\033[93mKNOWLEDGE FROM GRAPH:")
+            print("\033[93mKnowledge from Graph:")
             print("\n".join(rel for rel in context["graph_result"]) + 
"\033[0m")
 
         return context
 
-    def _format_knowledge_from_query_result(
-            self,
-            query_result: List[Any],
-    ) -> Set[str]:
+    def _format_knowledge_from_vertex(self, query_result: List[Any]) -> 
Set[str]:
+        knowledge = set()
+        for item in query_result:
+            props_str = ", ".join(f"{k}: {v}" for k, v in 
item["properties"].items())
+            node_str = f"{item['id']}{{{props_str}}}"
+            knowledge.add(node_str)
+        return knowledge
+
+    def _format_knowledge_from_query_result(self, query_result: List[Any]) -> 
Set[str]:
         use_id_to_match = self._prop_to_match is None
         knowledge = set()
         for line in query_result:
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index c70c866..a439d46 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -32,9 +32,14 @@ class BuildSemanticIndex:
         self.embedding = embedding
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        vids = [vertex["id"] for vertex in context["vertices"]]
-        if len(vids) > 0:
-            log.debug("Building vector index for %s vertices...", len(vids))
+        if len(context["vertices"]) > 0:
+            log.debug("Building vector index for %s vertices...", 
len(context["vertices"]))
+            vids = []
+            vids_embedding = []
+            for vertex in context["vertices"]:
+                vertex_text = f"{vertex['label']}\n{vertex['properties']}"
+                
vids_embedding.append(self.embedding.get_text_embedding(vertex_text))
+                vids.append(vertex["id"])
             vids_embedding = [self.embedding.get_text_embedding(vid) for vid 
in vids]
             log.debug("Vector index built for %s vertices.", len(vids))
             if os.path.exists(self.index_file) and 
os.path.exists(self.content_file):
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index 8b4e9d4..4eaef4e 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -25,19 +25,20 @@ from hugegraph_llm.models.embeddings.base import 
BaseEmbedding
 
 
 class SemanticIdQuery:
-    def __init__(self, embedding: BaseEmbedding):
+    def __init__(self, embedding: BaseEmbedding, topk_per_keyword: int = 1):
         index_file = str(os.path.join(resource_path, settings.graph_name, 
"vid.faiss"))
         content_file = str(os.path.join(resource_path, settings.graph_name, 
"vid.pkl"))
         self.vector_index = VectorIndex.from_index_file(index_file, 
content_file)
         self.embedding = embedding
+        self._topk_per_keyword = topk_per_keyword
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         keywords = context["keywords"]
         graph_query_entrance = []
         for keyword in keywords:
             query_vector = self.embedding.get_text_embedding(keyword)
-            results = self.vector_index.search(query_vector, top_k=1)
+            results = self.vector_index.search(query_vector, 
top_k=self._topk_per_keyword)
             if results:
-                graph_query_entrance.append(results[0])
-        context["entrance_vids"] = graph_query_entrance
+                graph_query_entrance.extend(results[:self._topk_per_keyword])
+        context["entrance_vids"] = list(set(graph_query_entrance))
         return context
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 8bcf05e..f885182 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -16,6 +16,7 @@
 # under the License.
 
 
+import asyncio
 from typing import Any, Dict, Optional
 
 from hugegraph_llm.models.llms.base import BaseLLM
@@ -84,22 +85,33 @@ class AnswerSynthesize:
             return {"answer": response}
 
         vector_result = context.get("vector_result", [])
-        vector_result_context = ("The following are paragraphs related to the 
query:\n"
-                                 + "\n".join([f"{i + 1}. {res}"
-                                              for i, res in 
enumerate(vector_result)]))
+        if len(vector_result) == 0:
+            vector_result_context = "There are no paragraphs related to the 
query."
+        else:
+            vector_result_context = ("The following are paragraphs related to 
the query:\n"
+                                     + "\n".join([f"{i + 1}. {res}"
+                                                  for i, res in 
enumerate(vector_result)]))
         graph_result = context.get("graph_result", [])
-        graph_result_context = ("The following are subgraph related to the 
query:\n"
-                                + "\n".join([f"{i + 1}. {res}"
-                                             for i, res in 
enumerate(graph_result)]))
+        if len(graph_result) == 0:
+            graph_result_context = "There are no knowledge from HugeGraph 
related to the query."
+        else:
+            graph_result_context = (
+                    "The following are knowledge from HugeGraph related to the 
query:\n"
+                    + "\n".join([f"{i + 1}. {res}"
+                                 for i, res in enumerate(graph_result)]))
+        context = asyncio.run(self.async_generate(context, context_head_str, 
context_tail_str,
+                                                  vector_result_context, 
graph_result_context))
 
-        verbose = context.get("verbose") or False
+        return context
 
+    async def async_generate(self, context: Dict[str, Any], context_head_str: 
str,
+                             context_tail_str: str, vector_result_context: str,
+                             graph_result_context: str):
+        verbose = context.get("verbose") or False
+        task_cache = {}
         if self._raw_answer:
             prompt = self._question
-            response = self._llm.generate(prompt=prompt)
-            context["raw_answer"] = response
-            if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
+            task_cache["raw_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=prompt))
         if self._vector_only_answer:
             context_str = (f"{context_head_str}\n"
                            f"{vector_result_context}\n"
@@ -109,10 +121,7 @@ class AnswerSynthesize:
                 context_str=context_str,
                 query_str=self._question,
             )
-            response = self._llm.generate(prompt=prompt)
-            context["vector_only_answer"] = response
-            if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
+            task_cache["vector_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=prompt))
         if self._graph_only_answer:
             context_str = (f"{context_head_str}\n"
                            f"{graph_result_context}\n"
@@ -122,10 +131,7 @@ class AnswerSynthesize:
                 context_str=context_str,
                 query_str=self._question,
             )
-            response = self._llm.generate(prompt=prompt)
-            context["graph_only_answer"] = response
-            if verbose:
-                print(f"\033[91mANSWER: {response}\033[0m")
+            task_cache["graph_only_task"] = 
asyncio.create_task(self._llm.agenerate(prompt=prompt))
         if self._graph_vector_answer:
             context_body_str = 
f"{vector_result_context}\n{graph_result_context}"
             context_str = (f"{context_head_str}\n"
@@ -136,9 +142,27 @@ class AnswerSynthesize:
                 context_str=context_str,
                 query_str=self._question,
             )
-            response = self._llm.generate(prompt=prompt)
+            task_cache["graph_vector_task"] = asyncio.create_task(
+                self._llm.agenerate(prompt=prompt)
+            )
+        if task_cache.get("raw_task"):
+            response = await task_cache["raw_task"]
+            context["raw_answer"] = response
+            if verbose:
+                print(f"\033[91mANSWER: {response}\033[0m")
+        if task_cache.get("vector_only_task"):
+            response = await task_cache["vector_only_task"]
+            context["vector_only_answer"] = response
+            if verbose:
+                print(f"\033[91mANSWER: {response}\033[0m")
+        if task_cache.get("graph_only_task"):
+            response = await task_cache["graph_only_task"]
+            context["graph_only_answer"] = response
+            if verbose:
+                print(f"\033[91mANSWER: {response}\033[0m")
+        if task_cache.get("graph_vector_task"):
+            response = await task_cache["graph_vector_task"]
             context["graph_vector_answer"] = response
             if verbose:
                 print(f"\033[91mANSWER: {response}\033[0m")
-
         return context
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
index e34b637..44bb69d 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
@@ -52,5 +52,7 @@ class DisambiguateData:
             llm_output = self.llm.generate(prompt=prompt)
             data["triples"] = []
             extract_triples_by_regex(llm_output, data)
-            print(f"LLM {self.__class__.__name__} input:{prompt} \n output: 
{llm_output} \n data: {data}")
+            print(
+                f"LLM {self.__class__.__name__} input:{prompt} \n"
+                f" output: {llm_output} \n data: {data}")
         return data
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
index 6f0e3af..1424143 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
@@ -152,7 +152,8 @@ class InfoExtract:
 
         for sentence in chunks:
             proceeded_chunk = self.extract_triples_by_llm(schema, sentence)
-            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__, sentence, proceeded_chunk)
+            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__,
+                      sentence, proceeded_chunk)
             if schema:
                 extract_triples_by_regex_with_schema(schema, proceeded_chunk, 
context)
             else:
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
index 077ee43..171b48a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
@@ -66,10 +66,13 @@ def generate_extract_property_graph_prompt(text, 
schema=None) -> str:
     return f"""---
 
 Following the full instructions above, try to extract the following text from 
the given schema, output the JSON result:
+# Input
 ## Text:
 {text}
-## Graph schema:
-{schema}"""
+## Graph schema
+{schema}
+
+# Output"""
 
 
 def split_text(text: str) -> List[str]:
@@ -98,7 +101,8 @@ class PropertyGraphExtract:
         items = []
         for chunk in chunks:
             proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk)
-            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__, chunk, proceeded_chunk)
+            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__, chunk,
+                      proceeded_chunk)
             items.extend(self._extract_and_filter_label(schema, 
proceeded_chunk))
         items = self.filter_item(schema, items)
         for item in items:
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py 
b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
index e1d51cb..d942bf9 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
@@ -14,15 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import json
 
 from pyhugegraph.client import PyHugeClient
 from hugegraph_llm.config import settings
 
 
-def run_gremlin_query(query):
+def run_gremlin_query(query, format=False):
     res = get_hg_client().gremlin().exec(query)
-    return res
+    return json.dumps(res, indent=4, ensure_ascii=False) if format else res
 
 
 def get_hg_client():
diff --git a/style/pylint.conf b/style/pylint.conf
index 2ae3281..80ebc61 100644
--- a/style/pylint.conf
+++ b/style/pylint.conf
@@ -337,7 +337,7 @@ indent-after-paren=4
 indent-string='    '
 
 # Maximum number of characters on a single line.
-max-line-length=100
+max-line-length=120
 
 # Maximum number of lines in a module.
 max-module-lines=1000


Reply via email to