This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 9018271  fix(llm): support choose template in /rag & /rag/graph api (#135)
9018271 is described below

commit 90182713537fd88946b76ec4a05c9ff44f1549d9
Author: HaoJin Yang <[email protected]>
AuthorDate: Thu Dec 19 12:58:32 2024 +0800

    fix(llm): support choose template in /rag & /rag/graph api (#135)
    
    Mark 3 TODOs:
    1. prompt not work in `/rag/graph`
    2. prompt-update not fresh in `/rag`
    3. prompt-update not work in `/rag/graph`
    
    ---------
    
    Co-authored-by: imbajin <[email protected]>
---
 .../src/hugegraph_llm/api/models/rag_requests.py   | 56 +++++++++-----
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     | 47 +++++++++---
 .../src/hugegraph_llm/demo/rag_demo/app.py         | 87 ++++++++++++---------
 .../src/hugegraph_llm/demo/rag_demo/rag_block.py   | 83 ++++++++++++--------
 .../src/hugegraph_llm/operators/graph_rag_task.py  | 43 +++++++----
 .../operators/gremlin_generate_task.py             |  5 +-
 .../operators/hugegraph_op/graph_rag_query.py      | 89 ++++++++++++----------
 7 files changed, 255 insertions(+), 155 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 9413890..1eaa1e2 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -17,33 +17,47 @@
 
 from typing import Optional, Literal
 
+from fastapi import Query
 from pydantic import BaseModel
 
+from hugegraph_llm.config import prompt
+
 
 class RAGRequest(BaseModel):
-    query: str = ""
-    raw_answer: bool = False
-    vector_only: bool = False
-    graph_only: bool = False
-    graph_vector_answer: bool = False
-    graph_ratio: float = 0.5
-    rerank_method: Literal["bleu", "reranker"] = "bleu"
-    near_neighbor_first: bool = False
-    custom_priority_info: str = ""
-    answer_prompt: Optional[str] = None
+    query: str = Query("", description="Query you want to ask")
+    raw_answer: bool = Query(False, description="Use LLM to generate answer directly")
+    vector_only: bool = Query(False, description="Use LLM to generate answer with vector")
+    graph_only: bool = Query(True, description="Use LLM to generate answer with graph RAG only")
+    graph_vector_answer: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG")
+    graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans")
+    rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
+    near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
+    with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
+    custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
+    answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
+    keywords_extract_prompt: Optional[str] = Query(
+        prompt.keywords_extract_prompt,
+        description="Prompt for extracting keywords from query.",
+    )
+    gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
+    gremlin_prompt: Optional[str] = Query(
+        prompt.gremlin_generate_prompt,
+        description="Prompt for the Text2Gremlin query.",
+    )
 
 
 class GraphRAGRequest(BaseModel):
-    query: str = ""
-    raw_answer: bool = True
-    vector_only: bool = False
-    graph_only: bool = False
-    graph_vector_answer: bool = False
-    graph_ratio: float = 0.5
-    rerank_method: Literal["bleu", "reranker"] = "bleu"
-    near_neighbor_first: bool = False
-    custom_priority_info: str = ""
-    answer_prompt: Optional[str] = None
+    query: str = Query("", description="Query you want to ask")
+    gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
+    with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
+    answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
+    rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
+    near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
+    custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
+    gremlin_prompt: Optional[str] = Query(
+        prompt.gremlin_generate_prompt,
+        description="Prompt for the Text2Gremlin query.",
+    )
 
 
 class GraphConfigRequest(BaseModel):
@@ -80,4 +94,4 @@ class RerankerConfigRequest(BaseModel):
 
 class LogStreamRequest(BaseModel):
     admin_token: Optional[str] = None
-    log_file: Optional[str] = 'llm-server.log'
+    log_file: Optional[str] = "llm-server.log"
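
With these models, a /rag request can now carry its own prompts and Gremlin-template
settings in the body, falling back to the prompt.* defaults when a field is omitted.
A minimal client-side sketch (the endpoint address and question are placeholder
assumptions, not part of this commit):

    import requests  # hypothetical client, for illustration only

    resp = requests.post(
        "http://127.0.0.1:8001/rag",  # assumed deployment address
        json={
            "query": "Tell me about Al Pacino.",
            "graph_only": True,          # note: now defaults to True in RAGRequest
            "with_gremlin_tmpl": True,   # use example template in text2gremlin
            "gremlin_tmpl_num": 1,       # number of Gremlin templates to use
            "rerank_method": "bleu",
        },
    )
    print(resp.json())
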
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 5a4ce0f..8a59c8b 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import json
 from typing import Literal
 
@@ -24,7 +25,8 @@ from hugegraph_llm.api.models.rag_requests import (
     RAGRequest,
     GraphConfigRequest,
     LLMConfigRequest,
-    RerankerConfigRequest, GraphRAGRequest,
+    RerankerConfigRequest,
+    GraphRAGRequest,
 )
 from hugegraph_llm.api.models.rag_response import RAGResponse
 from hugegraph_llm.config import llm_settings, huge_settings, prompt
@@ -32,25 +34,34 @@ from hugegraph_llm.utils.log import log
 
 
 def graph_rag_recall(
-        text: str,
-        rerank_method: Literal["bleu", "reranker"],
-        near_neighbor_first: bool,
-        custom_related_information: str
+    query: str,
+    gremlin_tmpl_num: int,
+    with_gremlin_tmpl: bool,
+    answer_prompt: str,  # FIXME: should be used in the query
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    gremlin_prompt: str,
 ) -> dict:
     from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+
     rag = RAGPipeline()
 
-    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb().merge_dedup_rerank(
+    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
+        with_gremlin_template=with_gremlin_tmpl,
+        num_gremlin_generate_example=gremlin_tmpl_num,
+        gremlin_prompt=gremlin_prompt,
+    ).merge_dedup_rerank(
         rerank_method=rerank_method,
         near_neighbor_first=near_neighbor_first,
         custom_related_information=custom_related_information,
     )
-    context = rag.run(verbose=True, query=text, graph_search=True)
+    context = rag.run(verbose=True, query=query, graph_search=True)
     return context
 
 
 def rag_http_api(
-        router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+    router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
@@ -60,11 +71,15 @@ def rag_http_api(
             req.vector_only,
             req.graph_only,
             req.graph_vector_answer,
+            req.with_gremlin_tmpl,
             req.graph_ratio,
             req.rerank_method,
             req.near_neighbor_first,
             req.custom_priority_info,
-            req.answer_prompt or prompt.answer_prompt
+            req.answer_prompt or prompt.answer_prompt,
+            req.keywords_extract_prompt or prompt.keywords_extract_prompt,
+            req.gremlin_tmpl_num,
+            req.gremlin_prompt or prompt.gremlin_generate_prompt,
         )
         return {
             key: value
@@ -76,14 +91,22 @@ def rag_http_api(
     def graph_rag_recall_api(req: GraphRAGRequest):
         try:
             result = graph_rag_recall(
-                text=req.query,
+                query=req.query,
+                gremlin_tmpl_num=req.gremlin_tmpl_num,
+                with_gremlin_tmpl=req.with_gremlin_tmpl,
+                answer_prompt=req.answer_prompt or prompt.answer_prompt,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
-                custom_related_information=req.custom_priority_info
+                custom_related_information=req.custom_priority_info,
+                gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
             )
 
             if isinstance(result, dict):
-                return {"graph_recall": result}
+                params = ["query", "keywords", "match_vids", 
"graph_result_flag", "gremlin", "graph_result",
+                          "vertex_degree_list"]
+                user_result = {key: result[key] for key in params if key in 
result}
+                return {"graph_recall": user_result}
+            # Note: Maybe only for qianfan/wenxin
             return {"graph_recall": json.dumps(result)}
 
         except TypeError as e:
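
The /rag/graph endpoint now whitelists what it returns: only the keys listed in
params survive into the response. A sketch of how a caller might check the filtered
shape (resp is assumed to come from a POST to /rag/graph, as in the earlier sketch):

    # keys taken from the whitelist above; everything else is dropped server-side
    expected_keys = {"query", "keywords", "match_vids", "graph_result_flag",
                     "gremlin", "graph_result", "vertex_degree_list"}
    body = resp.json()["graph_recall"]
    assert set(body) <= expected_keys
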
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index db6e2f2..3fe6c0f 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -58,11 +58,11 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
 
 # pylint: disable=C0301
 def init_rag_ui() -> gr.Interface:
-    with (gr.Blocks(
-            theme="default",
-            title="HugeGraph RAG Platform",
-            css=CSS,
-    ) as hugegraph_llm_ui):
+    with gr.Blocks(
+        theme="default",
+        title="HugeGraph RAG Platform",
+        css=CSS,
+    ) as hugegraph_llm_ui:
         gr.Markdown("# HugeGraph LLM RAG Demo")
 
         """
@@ -93,8 +93,12 @@ def init_rag_ui() -> gr.Interface:
         with gr.Tab(label="1. Build RAG Index 💡"):
             textbox_input_schema, textbox_info_extract_template = create_vector_graph_block()
         with gr.Tab(label="2. (Graph)RAG & User Functions 📖"):
-            textbox_inp, textbox_answer_prompt_input, textbox_keywords_extract_prompt_input, \
-            textbox_custom_related_information = create_rag_block()
+            (
+                textbox_inp,
+                textbox_answer_prompt_input,
+                textbox_keywords_extract_prompt_input,
+                textbox_custom_related_information,
+            ) = create_rag_block()
         with gr.Tab(label="3. Text2gremlin ⚙️"):
             textbox_gremlin_inp, textbox_gremlin_schema, textbox_gremlin_prompt = create_text2gremlin_block()
         with gr.Tab(label="4. Graph Tools 🚧"):
@@ -105,33 +109,46 @@ def init_rag_ui() -> gr.Interface:
         def refresh_ui_config_prompt() -> tuple:
             # we can use its __init__() for in-place reload
             # settings.from_env()
-            huge_settings.__init__() # pylint: disable=C2801
+            huge_settings.__init__()  # pylint: disable=C2801
             prompt.ensure_yaml_file_exists()
             return (
-                huge_settings.graph_ip, huge_settings.graph_port, huge_settings.graph_name, huge_settings.graph_user,
-                huge_settings.graph_pwd, huge_settings.graph_space, prompt.graph_schema, prompt.extract_graph_prompt,
-                prompt.default_question, prompt.answer_prompt, prompt.keywords_extract_prompt,
-                prompt.custom_rerank_info, prompt.default_question, huge_settings.graph_name,
-                prompt.gremlin_generate_prompt
+                huge_settings.graph_ip,
+                huge_settings.graph_port,
+                huge_settings.graph_name,
+                huge_settings.graph_user,
+                huge_settings.graph_pwd,
+                huge_settings.graph_space,
+                prompt.graph_schema,
+                prompt.extract_graph_prompt,
+                prompt.default_question,
+                prompt.answer_prompt,
+                prompt.keywords_extract_prompt,
+                prompt.custom_rerank_info,
+                prompt.default_question,
+                huge_settings.graph_name,
+                prompt.gremlin_generate_prompt,
             )
 
-        hugegraph_llm_ui.load(fn=refresh_ui_config_prompt, outputs=[  # pylint: disable=E1101
-            textbox_array_graph_config[0],
-            textbox_array_graph_config[1],
-            textbox_array_graph_config[2],
-            textbox_array_graph_config[3],
-            textbox_array_graph_config[4],
-            textbox_array_graph_config[5],
-            textbox_input_schema,
-            textbox_info_extract_template,
-            textbox_inp,
-            textbox_answer_prompt_input,
-            textbox_keywords_extract_prompt_input,
-            textbox_custom_related_information,
-            textbox_gremlin_inp,
-            textbox_gremlin_schema,
-            textbox_gremlin_prompt
-        ])
+        hugegraph_llm_ui.load(  # pylint: disable=E1101
+            fn=refresh_ui_config_prompt,
+            outputs=[
+                textbox_array_graph_config[0],
+                textbox_array_graph_config[1],
+                textbox_array_graph_config[2],
+                textbox_array_graph_config[3],
+                textbox_array_graph_config[4],
+                textbox_array_graph_config[5],
+                textbox_input_schema,
+                textbox_info_extract_template,
+                textbox_inp,
+                textbox_answer_prompt_input,
+                textbox_keywords_extract_prompt_input,
+                textbox_custom_related_information,
+                textbox_gremlin_inp,
+                textbox_gremlin_schema,
+                textbox_gremlin_prompt,
+            ],
+        )
 
     return hugegraph_llm_ui
 
@@ -153,15 +170,17 @@ if __name__ == "__main__":
 
     hugegraph_llm = init_rag_ui()
 
-    rag_http_api(api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config,
-                 apply_reranker_config)
+    rag_http_api(
+        api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config, apply_reranker_config
+    )
     admin_http_api(api_auth, log_stream)
 
     app.include_router(api_auth)
 
     # TODO: support multi-user login when need
-    app = gr.mount_gradio_app(app, hugegraph_llm, path="/",
-                              auth=("rag", admin_settings.user_token) if 
auth_enabled else None)
+    app = gr.mount_gradio_app(
+        app, hugegraph_llm, path="/", auth=("rag", admin_settings.user_token) 
if auth_enabled else None
+    )
 
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
     # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 3edaaf5..070f3b3 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -30,18 +30,20 @@ from hugegraph_llm.utils.log import log
 
 
 def rag_answer(
-        text: str,
-        raw_answer: bool,
-        vector_only_answer: bool,
-        graph_only_answer: bool,
-        graph_vector_answer: bool,
-        with_gremlin_template: bool,
-        graph_ratio: float,
-        rerank_method: Literal["bleu", "reranker"],
-        near_neighbor_first: bool,
-        custom_related_information: str,
-        answer_prompt: str,
-        keywords_extract_prompt: str,
+    text: str,
+    raw_answer: bool,
+    vector_only_answer: bool,
+    graph_only_answer: bool,
+    graph_vector_answer: bool,
+    with_gremlin_template: bool,
+    graph_ratio: float,
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    answer_prompt: str,
+    keywords_extract_prompt: str,
+    gremlin_tmpl_num: Optional[int] = 2,
+    gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
 ) -> Tuple:
     """
     Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -52,15 +54,17 @@ def rag_answer(
     5. Run the pipeline and return the results.
     """
     should_update_prompt = (
-        prompt.default_question != text or
-        prompt.answer_prompt != answer_prompt or
-        prompt.keywords_extract_prompt != keywords_extract_prompt
+        prompt.default_question != text
+        or prompt.answer_prompt != answer_prompt
+        or prompt.keywords_extract_prompt != keywords_extract_prompt
+        or prompt.gremlin_generate_prompt != gremlin_prompt
     )
     if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
         prompt.custom_rerank_info = custom_related_information
         prompt.default_question = text
         prompt.answer_prompt = answer_prompt
         prompt.keywords_extract_prompt = keywords_extract_prompt
+        prompt.gremlin_generate_prompt = gremlin_prompt
         prompt.update_yaml_file()
 
     vector_search = vector_only_answer or graph_vector_answer
@@ -74,9 +78,18 @@ def rag_answer(
         rag.query_vector_index()
     if graph_search:
         rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
-            huge_settings.graph_name).query_graphdb(with_gremlin_template=with_gremlin_template)
+            huge_settings.graph_name
+        ).query_graphdb(
+            with_gremlin_template=with_gremlin_template,
+            num_gremlin_generate_example=gremlin_tmpl_num,
+            gremlin_prompt=gremlin_prompt,
+        )
     # TODO: add more user-defined search strategies
-    rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, )
+    rag.merge_dedup_rerank(
+        graph_ratio,
+        rerank_method,
+        near_neighbor_first,
+    )
     rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
 
     try:
@@ -123,6 +136,7 @@ def create_rag_block():
                 graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
             with gr.Row():
                 with_gremlin_template_radio = gr.Radio(choices=[True, False], value=True, label="With Gremlin Template")
+
             def toggle_slider(enable):
                 return gr.update(interactive=enable)
 
@@ -164,16 +178,18 @@ def create_rag_block():
             near_neighbor_first,
             custom_related_information,
             answer_prompt_input,
-            keywords_extract_prompt_input
+            keywords_extract_prompt_input,
         ],
         outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
     )
 
-    gr.Markdown("""## 2. (Batch) Back-testing )
+    gr.Markdown(
+        """## 2. (Batch) Back-testing )
     > 1. Download the template file & fill in the questions you want to test.
     > 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines)
     > 3. The answer options are the same as the above RAG/Q&A frame 
-    """)
+    """
+    )
     tests_df_headers = [
         "Question",
         "Expected Answer",
@@ -214,18 +230,19 @@ def create_rag_block():
         return df
 
     def several_rag_answer(
-            is_raw_answer: bool,
-            is_vector_only_answer: bool,
-            is_graph_only_answer: bool,
-            is_graph_vector_answer: bool,
-            graph_ratio: float,
-            rerank_method: Literal["bleu", "reranker"],
-            near_neighbor_first: bool,
-            custom_related_information: str,
-            answer_prompt: str,
-            keywords_extract_prompt: str,
-            progress=gr.Progress(track_tqdm=True),
-            answer_max_line_count: int = 1,
+        is_raw_answer: bool,
+        is_vector_only_answer: bool,
+        is_graph_only_answer: bool,
+        is_graph_vector_answer: bool,
+        graph_ratio: float,
+        rerank_method: Literal["bleu", "reranker"],
+        near_neighbor_first: bool,
+        with_gremlin_template: bool,
+        custom_related_information: str,
+        answer_prompt: str,
+        keywords_extract_prompt: str,
+        answer_max_line_count: int = 1,
+        progress=gr.Progress(track_tqdm=True),
     ):
         df = pd.read_excel(questions_path, dtype=str)
         total_rows = len(df)
@@ -240,6 +257,7 @@ def create_rag_block():
                 graph_ratio,
                 rerank_method,
                 near_neighbor_first,
+                with_gremlin_template,
                 custom_related_information,
                 answer_prompt,
                 keywords_extract_prompt,
@@ -273,6 +291,7 @@ def create_rag_block():
             graph_ratio,
             rerank_method,
             near_neighbor_first,
+            with_gremlin_template_radio,
             custom_related_information,
             answer_prompt_input,
             keywords_extract_prompt_input,
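
For reference, rag_answer keeps its Gradio-facing calling convention, with the two
new trailing parameters optional so existing callers still work. A hedged sketch of
a direct call, assuming it returns the four answer variants wired to the outputs
above (the question text is a placeholder):

    from hugegraph_llm.config import prompt
    from hugegraph_llm.demo.rag_demo.rag_block import rag_answer

    # positional order must match the click-handler inputs list above
    raw, vector_only, graph_only, graph_vector = rag_answer(
        "Tell me about Al Pacino.",  # text
        False, False, True, False,   # raw / vector-only / graph-only / graph-vector
        True,                        # with_gremlin_template
        0.5, "bleu", False, "",      # graph_ratio, rerank_method, near_neighbor_first, custom info
        prompt.answer_prompt,
        prompt.keywords_extract_prompt,
        gremlin_tmpl_num=2,
        gremlin_prompt=prompt.gremlin_generate_prompt,
    )
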
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 8f5f81d..03ac9ae 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -32,6 +32,7 @@ from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery
 from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
 from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
 from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_qps
+from hugegraph_llm.config import prompt
 
 
 class RAGPipeline:
@@ -65,11 +66,11 @@ class RAGPipeline:
         return self
 
     def extract_keywords(
-            self,
-            text: Optional[str] = None,
-            max_keywords: int = 5,
-            language: str = "english",
-            extract_template: Optional[str] = None,
+        self,
+        text: Optional[str] = None,
+        max_keywords: int = 5,
+        language: str = "english",
+        extract_template: Optional[str] = None,
     ):
         """
         Add a keyword extraction operator to the pipeline.
@@ -125,6 +126,8 @@ class RAGPipeline:
         max_e_prop_len: int = 256,
         prop_to_match: Optional[str] = None,
         with_gremlin_template: bool = True,
+        num_gremlin_generate_example: int = 1,
+        gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
     ):
         """
         Add a graph RAG query operator to the pipeline.
@@ -137,9 +140,16 @@ class RAGPipeline:
         :return: Self-instance for chaining.
         """
         self._operators.append(
-            GraphRAGQuery(max_deep=max_deep, max_items=max_items, max_v_prop_len=max_v_prop_len,
-                          max_e_prop_len=max_e_prop_len, prop_to_match=prop_to_match,
-                          with_gremlin_template=with_gremlin_template)
+            GraphRAGQuery(
+                max_deep=max_deep,
+                max_items=max_items,
+                max_v_prop_len=max_v_prop_len,
+                max_e_prop_len=max_e_prop_len,
+                prop_to_match=prop_to_match,
+                with_gremlin_template=with_gremlin_template,
+                num_gremlin_generate_example=num_gremlin_generate_example,
+                gremlin_prompt=gremlin_prompt,
+            )
         )
         return self
 
@@ -151,7 +161,10 @@ class RAGPipeline:
         :return: Self-instance for chaining.
         """
         self._operators.append(
-            VectorIndexQuery(embedding=self._embedding, topk=max_items, )
+            VectorIndexQuery(
+                embedding=self._embedding,
+                topk=max_items,
+            )
         )
         return self
 
@@ -179,12 +192,12 @@ class RAGPipeline:
         return self
 
     def synthesize_answer(
-            self,
-            raw_answer: bool = False,
-            vector_only_answer: bool = True,
-            graph_only_answer: bool = False,
-            graph_vector_answer: bool = False,
-            answer_prompt: Optional[str] = None,
+        self,
+        raw_answer: bool = False,
+        vector_only_answer: bool = True,
+        graph_only_answer: bool = False,
+        graph_vector_answer: bool = False,
+        answer_prompt: Optional[str] = None,
     ):
         """
         Add an answer synthesis operator to the pipeline.
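
The pipeline side mirrors the API: query_graphdb now forwards the template switch,
example count, and prompt down to GraphRAGQuery. A minimal chaining sketch, matching
the graph_rag_recall usage earlier (graph name and query are placeholders):

    from hugegraph_llm.config import huge_settings, prompt
    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    rag = RAGPipeline()
    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
        with_gremlin_template=True,
        num_gremlin_generate_example=1,
        gremlin_prompt=prompt.gremlin_generate_prompt,
    ).merge_dedup_rerank(
        rerank_method="bleu",
        near_neighbor_first=False,
        custom_related_information="",
    )
    context = rag.run(verbose=True, query="Tell me about Al Pacino.", graph_search=True)
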
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
index dfbf085..95ce59f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
@@ -58,8 +58,9 @@ class GremlinGenerator:
         self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples))
         return self
 
-    def gremlin_generate_synthesize(self, schema, gremlin_prompt: Optional[str] = None,
-                                    vertices: Optional[List[str]] = None):
+    def gremlin_generate_synthesize(
+        self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None
+    ):
         self.operators.append(GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt))
         return self
 
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 33190b7..a4186c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -18,7 +18,7 @@
 import json
 from typing import Any, Dict, Optional, List, Set, Tuple
 
-from hugegraph_llm.config import huge_settings
+from hugegraph_llm.config import huge_settings, prompt
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
@@ -76,16 +76,17 @@ g.V().has('{prop}', within({keywords}))
 
 class GraphRAGQuery:
     def __init__(
-            self,
-            max_deep: int = 2,
-            max_items: int = int(huge_settings.max_items),
-            prop_to_match: Optional[str] = None,
-            llm: Optional[BaseLLM] = None,
-            embedding: Optional[BaseEmbedding] = None,
-            max_v_prop_len: int = 2048,
-            max_e_prop_len: int = 256,
-            with_gremlin_template: bool = True,
-            num_gremlin_generate_example: int = 1
+        self,
+        max_deep: int = 2,
+        max_items: int = int(huge_settings.max_items),
+        prop_to_match: Optional[str] = None,
+        llm: Optional[BaseLLM] = None,
+        embedding: Optional[BaseEmbedding] = None,
+        max_v_prop_len: int = 2048,
+        max_e_prop_len: int = 256,
+        with_gremlin_template: bool = True,
+        num_gremlin_generate_example: int = 1,
+        gremlin_prompt: Optional[str] = None,
     ):
         self._client = PyHugeClient(
             huge_settings.graph_ip,
@@ -108,6 +109,7 @@ class GraphRAGQuery:
         )
         self._num_gremlin_generate_example = num_gremlin_generate_example
         self._with_gremlin_template = with_gremlin_template
+        self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         self._init_client(context)
@@ -134,12 +136,8 @@ class GraphRAGQuery:
         self._gremlin_generator.clear()
         self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
         gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
-            context["simple_schema"],
-            vertices=vertices,
-        ).run(
-            query=query,
-            query_embedding=query_embedding
-        )
+            context["simple_schema"], vertices=vertices, 
gremlin_prompt=self._gremlin_prompt
+        ).run(query=query, query_embedding=query_embedding)
         if self._with_gremlin_template:
             gremlin = gremlin_response["result"]
         else:
@@ -154,10 +152,9 @@ class GraphRAGQuery:
             if context["graph_result"]:
                 context["graph_result_flag"] = 1
                 context["graph_context_head"] = (
-                    f"The following are graph query result "
-                    f"from gremlin query `{gremlin}`.\n"
+                    f"The following are graph query result " f"from gremlin 
query `{gremlin}`.\n"
                 )
-        except Exception as e: # pylint: disable=broad-except
+        except Exception as e:  # pylint: disable=broad-except
             log.error(e)
             context["graph_result"] = ""
         return context
@@ -285,8 +282,9 @@ class GraphRAGQuery:
 
         return subgraph, vertex_degree_list, subgraph_with_degree
 
-    def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
-                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, List[str]]:
+    def _process_path(
+        self, path: Any, use_id_to_match: bool, v_cache: Set[str], e_cache: Set[Tuple[str, str, str]]
+    ) -> Tuple[str, List[str]]:
         flat_rel = ""
         raw_flat_rel = path["objects"]
         assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd."
@@ -300,8 +298,7 @@ class GraphRAGQuery:
             if i % 2 == 0:
                 # Process each vertex
                 flat_rel, prior_edge_str_len, depth = self._process_vertex(
-                    item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match,
-                    v_cache
+                    item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match, v_cache
                 )
             else:
                 # Process each edge
@@ -311,17 +308,24 @@ class GraphRAGQuery:
 
         return flat_rel, nodes_with_degree
 
-    def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
-                        prior_edge_str_len: int, depth: int, nodes_with_degree: List[str],
-                        use_id_to_match: bool, v_cache: Set[str]) -> Tuple[str, int, int]:
+    def _process_vertex(
+        self,
+        item: Any,
+        flat_rel: str,
+        node_cache: Set[str],
+        prior_edge_str_len: int,
+        depth: int,
+        nodes_with_degree: List[str],
+        use_id_to_match: bool,
+        v_cache: Set[str],
+    ) -> Tuple[str, int, int]:
         matched_str = item["id"] if use_id_to_match else 
item["props"][self._prop_to_match]
         if matched_str in node_cache:
             flat_rel = flat_rel[:-prior_edge_str_len]
             return flat_rel, prior_edge_str_len, depth
 
         node_cache.add(matched_str)
-        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}"
-                              for k, v in item["props"].items() if v)
+        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for 
k, v in item["props"].items() if v)
 
         # TODO: we may remove label id or replace with label name
         if matched_str in v_cache:
@@ -335,20 +339,27 @@ class GraphRAGQuery:
         depth += 1
         return flat_rel, prior_edge_str_len, depth
 
-    def _process_edge(self, item: Any, path_str: str, raw_flat_rel: List[Any], i: int, use_id_to_match: bool,
-                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, int]:
-        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}"
-                              for k, v in item["props"].items() if v)
+    def _process_edge(
+        self,
+        item: Any,
+        path_str: str,
+        raw_flat_rel: List[Any],
+        i: int,
+        use_id_to_match: bool,
+        e_cache: Set[Tuple[str, str, str]],
+    ) -> Tuple[str, int]:
+        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for 
k, v in item["props"].items() if v)
         props_str = f"{{{props_str}}}" if props_str else ""
-        prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else (
-            raw_flat_rel)[i - 1]["props"][self._prop_to_match]
+        prev_matched_str = (
+            raw_flat_rel[i - 1]["id"] if use_id_to_match else (raw_flat_rel)[i 
- 1]["props"][self._prop_to_match]
+        )
 
-        edge_key = (item['inV'], item['label'], item['outV'])
+        edge_key = (item["inV"], item["label"], item["outV"])
         if edge_key not in e_cache:
             e_cache.add(edge_key)
             edge_label = f"{item['label']}{props_str}"
         else:
-            edge_label = item['label']
+            edge_label = item["label"]
 
         edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str 
else f"<--[{edge_label}]--"
         path_str += edge_str
@@ -365,8 +376,8 @@ class GraphRAGQuery:
         schema = self._get_graph_schema()
         vertex_props_str, edge_props_str = schema.split("\n")[:2]
         # TODO: rename to vertex (also need update in the schema)
-        vertex_props_str = vertex_props_str[len("Vertex properties: 
"):].strip("[").strip("]")
-        edge_props_str = edge_props_str[len("Edge properties: 
"):].strip("[").strip("]")
+        vertex_props_str = vertex_props_str[len("Vertex properties: ") 
:].strip("[").strip("]")
+        edge_props_str = edge_props_str[len("Edge properties: ") 
:].strip("[").strip("]")
         vertex_labels = self._extract_label_names(vertex_props_str)
         edge_labels = self._extract_label_names(edge_props_str)
         return vertex_labels, edge_labels
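
The reformatted slices at the end only move whitespace around the colon (black's
style); behavior is unchanged. A tiny worked example with an invented schema line:

    # invented input resembling the schema summary string parsed above
    vertex_props_str = "Vertex properties: [person, movie]"
    vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]")
    print(vertex_props_str)  # -> person, movie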
