This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 9018271 fix(llm): support choose template in /rag & /rag/graph api
(#135)
9018271 is described below
commit 90182713537fd88946b76ec4a05c9ff44f1549d9
Author: HaoJin Yang <[email protected]>
AuthorDate: Thu Dec 19 12:58:32 2024 +0800
fix(llm): support choose template in /rag & /rag/graph api (#135)
Mark 3 TODOs:
1. Custom prompts do not take effect in `/rag/graph`
2. Prompt updates are not refreshed in `/rag`
3. Prompt updates do not take effect in `/rag/graph`
---------
Co-authored-by: imbajin <[email protected]>
---
.../src/hugegraph_llm/api/models/rag_requests.py | 56 +++++++++-----
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 47 +++++++++---
.../src/hugegraph_llm/demo/rag_demo/app.py | 87 ++++++++++++---------
.../src/hugegraph_llm/demo/rag_demo/rag_block.py | 83 ++++++++++++--------
.../src/hugegraph_llm/operators/graph_rag_task.py | 43 +++++++----
.../operators/gremlin_generate_task.py | 5 +-
.../operators/hugegraph_op/graph_rag_query.py | 89 ++++++++++++----------
7 files changed, 255 insertions(+), 155 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 9413890..1eaa1e2 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -17,33 +17,47 @@
from typing import Optional, Literal
+from fastapi import Query
from pydantic import BaseModel
+from hugegraph_llm.config import prompt
+
class RAGRequest(BaseModel):
- query: str = ""
- raw_answer: bool = False
- vector_only: bool = False
- graph_only: bool = False
- graph_vector_answer: bool = False
- graph_ratio: float = 0.5
- rerank_method: Literal["bleu", "reranker"] = "bleu"
- near_neighbor_first: bool = False
- custom_priority_info: str = ""
- answer_prompt: Optional[str] = None
+ query: str = Query("", description="Query you want to ask")
+ raw_answer: bool = Query(False, description="Use LLM to generate answer
directly")
+ vector_only: bool = Query(False, description="Use LLM to generate answer
with vector")
+ graph_only: bool = Query(True, description="Use LLM to generate answer
with graph RAG only")
+ graph_vector_answer: bool = Query(False, description="Use LLM to generate
answer with vector & GraphRAG")
+ graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans &
vector ans")
+ rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
+ near_neighbor_first: bool = Query(False, description="Prioritize near
neighbors in the search results.")
+ with_gremlin_tmpl: bool = Query(True, description="Use example template in
text2gremlin")
+ custom_priority_info: str = Query("", description="Custom information to
prioritize certain results.")
+ answer_prompt: Optional[str] = Query(prompt.answer_prompt,
description="Prompt to guide the answer generation.")
+ keywords_extract_prompt: Optional[str] = Query(
+ prompt.keywords_extract_prompt,
+ description="Prompt for extracting keywords from query.",
+ )
+ gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates
to use.")
+ gremlin_prompt: Optional[str] = Query(
+ prompt.gremlin_generate_prompt,
+ description="Prompt for the Text2Gremlin query.",
+ )
class GraphRAGRequest(BaseModel):
- query: str = ""
- raw_answer: bool = True
- vector_only: bool = False
- graph_only: bool = False
- graph_vector_answer: bool = False
- graph_ratio: float = 0.5
- rerank_method: Literal["bleu", "reranker"] = "bleu"
- near_neighbor_first: bool = False
- custom_priority_info: str = ""
- answer_prompt: Optional[str] = None
+ query: str = Query("", description="Query you want to ask")
+ gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates
to use.")
+ with_gremlin_tmpl: bool = Query(True, description="Use example template in
text2gremlin")
+ answer_prompt: Optional[str] = Query(prompt.answer_prompt,
description="Prompt to guide the answer generation.")
+ rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
+ near_neighbor_first: bool = Query(False, description="Prioritize near
neighbors in the search results.")
+ custom_priority_info: str = Query("", description="Custom information to
prioritize certain results.")
+ gremlin_prompt: Optional[str] = Query(
+ prompt.gremlin_generate_prompt,
+ description="Prompt for the Text2Gremlin query.",
+ )
class GraphConfigRequest(BaseModel):
@@ -80,4 +94,4 @@ class RerankerConfigRequest(BaseModel):
class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
- log_file: Optional[str] = 'llm-server.log'
+ log_file: Optional[str] = "llm-server.log"
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 5a4ce0f..8a59c8b 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import json
from typing import Literal
@@ -24,7 +25,8 @@ from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
- RerankerConfigRequest, GraphRAGRequest,
+ RerankerConfigRequest,
+ GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import llm_settings, huge_settings, prompt
@@ -32,25 +34,34 @@ from hugegraph_llm.utils.log import log
def graph_rag_recall(
- text: str,
- rerank_method: Literal["bleu", "reranker"],
- near_neighbor_first: bool,
- custom_related_information: str
+ query: str,
+ gremlin_tmpl_num: int,
+ with_gremlin_tmpl: bool,
+ answer_prompt: str, # FIXME: should be used in the query
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str,
+ gremlin_prompt: str,
) -> dict:
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+
rag = RAGPipeline()
-
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb().merge_dedup_rerank(
+
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
+ with_gremlin_template=with_gremlin_tmpl,
+ num_gremlin_generate_example=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ ).merge_dedup_rerank(
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
)
- context = rag.run(verbose=True, query=text, graph_search=True)
+ context = rag.run(verbose=True, query=query, graph_search=True)
return context
def rag_http_api(
- router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf,
apply_embedding_conf, apply_reranker_conf
+ router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf,
apply_embedding_conf, apply_reranker_conf
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
@@ -60,11 +71,15 @@ def rag_http_api(
req.vector_only,
req.graph_only,
req.graph_vector_answer,
+ req.with_gremlin_tmpl,
req.graph_ratio,
req.rerank_method,
req.near_neighbor_first,
req.custom_priority_info,
- req.answer_prompt or prompt.answer_prompt
+ req.answer_prompt or prompt.answer_prompt,
+ req.keywords_extract_prompt or prompt.keywords_extract_prompt,
+ req.gremlin_tmpl_num,
+ req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
return {
key: value
@@ -76,14 +91,22 @@ def rag_http_api(
def graph_rag_recall_api(req: GraphRAGRequest):
try:
result = graph_rag_recall(
- text=req.query,
+ query=req.query,
+ gremlin_tmpl_num=req.gremlin_tmpl_num,
+ with_gremlin_tmpl=req.with_gremlin_tmpl,
+ answer_prompt=req.answer_prompt or prompt.answer_prompt,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
- custom_related_information=req.custom_priority_info
+ custom_related_information=req.custom_priority_info,
+ gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
)
if isinstance(result, dict):
- return {"graph_recall": result}
+ params = ["query", "keywords", "match_vids",
"graph_result_flag", "gremlin", "graph_result",
+ "vertex_degree_list"]
+ user_result = {key: result[key] for key in params if key in
result}
+ return {"graph_recall": user_result}
+ # Note: Maybe only for qianfan/wenxin
return {"graph_recall": json.dumps(result)}
except TypeError as e:
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index db6e2f2..3fe6c0f 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -58,11 +58,11 @@ def authenticate(credentials: HTTPAuthorizationCredentials
= Depends(sec)):
# pylint: disable=C0301
def init_rag_ui() -> gr.Interface:
- with (gr.Blocks(
- theme="default",
- title="HugeGraph RAG Platform",
- css=CSS,
- ) as hugegraph_llm_ui):
+ with gr.Blocks(
+ theme="default",
+ title="HugeGraph RAG Platform",
+ css=CSS,
+ ) as hugegraph_llm_ui:
gr.Markdown("# HugeGraph LLM RAG Demo")
"""
@@ -93,8 +93,12 @@ def init_rag_ui() -> gr.Interface:
with gr.Tab(label="1. Build RAG Index 💡"):
textbox_input_schema, textbox_info_extract_template =
create_vector_graph_block()
with gr.Tab(label="2. (Graph)RAG & User Functions 📖"):
- textbox_inp, textbox_answer_prompt_input,
textbox_keywords_extract_prompt_input, \
- textbox_custom_related_information = create_rag_block()
+ (
+ textbox_inp,
+ textbox_answer_prompt_input,
+ textbox_keywords_extract_prompt_input,
+ textbox_custom_related_information,
+ ) = create_rag_block()
with gr.Tab(label="3. Text2gremlin ⚙️"):
textbox_gremlin_inp, textbox_gremlin_schema,
textbox_gremlin_prompt = create_text2gremlin_block()
with gr.Tab(label="4. Graph Tools 🚧"):
@@ -105,33 +109,46 @@ def init_rag_ui() -> gr.Interface:
def refresh_ui_config_prompt() -> tuple:
# we can use its __init__() for in-place reload
# settings.from_env()
- huge_settings.__init__() # pylint: disable=C2801
+ huge_settings.__init__() # pylint: disable=C2801
prompt.ensure_yaml_file_exists()
return (
- huge_settings.graph_ip, huge_settings.graph_port,
huge_settings.graph_name, huge_settings.graph_user,
- huge_settings.graph_pwd, huge_settings.graph_space,
prompt.graph_schema, prompt.extract_graph_prompt,
- prompt.default_question, prompt.answer_prompt,
prompt.keywords_extract_prompt,
- prompt.custom_rerank_info, prompt.default_question,
huge_settings.graph_name,
- prompt.gremlin_generate_prompt
+ huge_settings.graph_ip,
+ huge_settings.graph_port,
+ huge_settings.graph_name,
+ huge_settings.graph_user,
+ huge_settings.graph_pwd,
+ huge_settings.graph_space,
+ prompt.graph_schema,
+ prompt.extract_graph_prompt,
+ prompt.default_question,
+ prompt.answer_prompt,
+ prompt.keywords_extract_prompt,
+ prompt.custom_rerank_info,
+ prompt.default_question,
+ huge_settings.graph_name,
+ prompt.gremlin_generate_prompt,
)
- hugegraph_llm_ui.load(fn=refresh_ui_config_prompt, outputs=[ #
pylint: disable=E1101
- textbox_array_graph_config[0],
- textbox_array_graph_config[1],
- textbox_array_graph_config[2],
- textbox_array_graph_config[3],
- textbox_array_graph_config[4],
- textbox_array_graph_config[5],
- textbox_input_schema,
- textbox_info_extract_template,
- textbox_inp,
- textbox_answer_prompt_input,
- textbox_keywords_extract_prompt_input,
- textbox_custom_related_information,
- textbox_gremlin_inp,
- textbox_gremlin_schema,
- textbox_gremlin_prompt
- ])
+ hugegraph_llm_ui.load( # pylint: disable=E1101
+ fn=refresh_ui_config_prompt,
+ outputs=[
+ textbox_array_graph_config[0],
+ textbox_array_graph_config[1],
+ textbox_array_graph_config[2],
+ textbox_array_graph_config[3],
+ textbox_array_graph_config[4],
+ textbox_array_graph_config[5],
+ textbox_input_schema,
+ textbox_info_extract_template,
+ textbox_inp,
+ textbox_answer_prompt_input,
+ textbox_keywords_extract_prompt_input,
+ textbox_custom_related_information,
+ textbox_gremlin_inp,
+ textbox_gremlin_schema,
+ textbox_gremlin_prompt,
+ ],
+ )
return hugegraph_llm_ui
@@ -153,15 +170,17 @@ if __name__ == "__main__":
hugegraph_llm = init_rag_ui()
- rag_http_api(api_auth, rag_answer, apply_graph_config, apply_llm_config,
apply_embedding_config,
- apply_reranker_config)
+ rag_http_api(
+ api_auth, rag_answer, apply_graph_config, apply_llm_config,
apply_embedding_config, apply_reranker_config
+ )
admin_http_api(api_auth, log_stream)
app.include_router(api_auth)
# TODO: support multi-user login when need
- app = gr.mount_gradio_app(app, hugegraph_llm, path="/",
- auth=("rag", admin_settings.user_token) if
auth_enabled else None)
+ app = gr.mount_gradio_app(
+ app, hugegraph_llm, path="/", auth=("rag", admin_settings.user_token)
if auth_enabled else None
+ )
# TODO: we can't use reload now due to the config 'app' of uvicorn.run
# ❎:f'{__name__}:app' / rag_web_demo:app /
hugegraph_llm.demo.rag_web_demo:app
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 3edaaf5..070f3b3 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -30,18 +30,20 @@ from hugegraph_llm.utils.log import log
def rag_answer(
- text: str,
- raw_answer: bool,
- vector_only_answer: bool,
- graph_only_answer: bool,
- graph_vector_answer: bool,
- with_gremlin_template: bool,
- graph_ratio: float,
- rerank_method: Literal["bleu", "reranker"],
- near_neighbor_first: bool,
- custom_related_information: str,
- answer_prompt: str,
- keywords_extract_prompt: str,
+ text: str,
+ raw_answer: bool,
+ vector_only_answer: bool,
+ graph_only_answer: bool,
+ graph_vector_answer: bool,
+ with_gremlin_template: bool,
+ graph_ratio: float,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str,
+ answer_prompt: str,
+ keywords_extract_prompt: str,
+ gremlin_tmpl_num: Optional[int] = 2,
+ gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
) -> Tuple:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -52,15 +54,17 @@ def rag_answer(
5. Run the pipeline and return the results.
"""
should_update_prompt = (
- prompt.default_question != text or
- prompt.answer_prompt != answer_prompt or
- prompt.keywords_extract_prompt != keywords_extract_prompt
+ prompt.default_question != text
+ or prompt.answer_prompt != answer_prompt
+ or prompt.keywords_extract_prompt != keywords_extract_prompt
+ or prompt.gremlin_generate_prompt != gremlin_prompt
)
if should_update_prompt or prompt.custom_rerank_info !=
custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.keywords_extract_prompt = keywords_extract_prompt
+ prompt.gremlin_generate_prompt = gremlin_prompt
prompt.update_yaml_file()
vector_search = vector_only_answer or graph_vector_answer
@@ -74,9 +78,18 @@ def rag_answer(
rag.query_vector_index()
if graph_search:
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
-
huge_settings.graph_name).query_graphdb(with_gremlin_template=with_gremlin_template)
+ huge_settings.graph_name
+ ).query_graphdb(
+ with_gremlin_template=with_gremlin_template,
+ num_gremlin_generate_example=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ )
# TODO: add more user-defined search strategies
- rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, )
+ rag.merge_dedup_rerank(
+ graph_ratio,
+ rerank_method,
+ near_neighbor_first,
+ )
rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer,
graph_vector_answer, answer_prompt)
try:
@@ -123,6 +136,7 @@ def create_rag_block():
graph_vector_radio = gr.Radio(choices=[True, False],
value=False, label="Graph-Vector Answer")
with gr.Row():
with_gremlin_template_radio = gr.Radio(choices=[True, False],
value=True, label="With Gremlin Template")
+
def toggle_slider(enable):
return gr.update(interactive=enable)
@@ -164,16 +178,18 @@ def create_rag_block():
near_neighbor_first,
custom_related_information,
answer_prompt_input,
- keywords_extract_prompt_input
+ keywords_extract_prompt_input,
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)
- gr.Markdown("""## 2. (Batch) Back-testing )
+ gr.Markdown(
+ """## 2. (Batch) Back-testing )
> 1. Download the template file & fill in the questions you want to test.
> 2. Upload the file & click the button to generate answers. (Preview
shows the first 40 lines)
> 3. The answer options are the same as the above RAG/Q&A frame
- """)
+ """
+ )
tests_df_headers = [
"Question",
"Expected Answer",
@@ -214,18 +230,19 @@ def create_rag_block():
return df
def several_rag_answer(
- is_raw_answer: bool,
- is_vector_only_answer: bool,
- is_graph_only_answer: bool,
- is_graph_vector_answer: bool,
- graph_ratio: float,
- rerank_method: Literal["bleu", "reranker"],
- near_neighbor_first: bool,
- custom_related_information: str,
- answer_prompt: str,
- keywords_extract_prompt: str,
- progress=gr.Progress(track_tqdm=True),
- answer_max_line_count: int = 1,
+ is_raw_answer: bool,
+ is_vector_only_answer: bool,
+ is_graph_only_answer: bool,
+ is_graph_vector_answer: bool,
+ graph_ratio: float,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ with_gremlin_template: bool,
+ custom_related_information: str,
+ answer_prompt: str,
+ keywords_extract_prompt: str,
+ answer_max_line_count: int = 1,
+ progress=gr.Progress(track_tqdm=True),
):
df = pd.read_excel(questions_path, dtype=str)
total_rows = len(df)
@@ -240,6 +257,7 @@ def create_rag_block():
graph_ratio,
rerank_method,
near_neighbor_first,
+ with_gremlin_template,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
@@ -273,6 +291,7 @@ def create_rag_block():
graph_ratio,
rerank_method,
near_neighbor_first,
+ with_gremlin_template_radio,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 8f5f81d..03ac9ae 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -32,6 +32,7 @@ from hugegraph_llm.operators.index_op.vector_index_query
import VectorIndexQuery
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
from hugegraph_llm.utils.decorators import log_time, log_operator_time,
record_qps
+from hugegraph_llm.config import prompt
class RAGPipeline:
@@ -65,11 +66,11 @@ class RAGPipeline:
return self
def extract_keywords(
- self,
- text: Optional[str] = None,
- max_keywords: int = 5,
- language: str = "english",
- extract_template: Optional[str] = None,
+ self,
+ text: Optional[str] = None,
+ max_keywords: int = 5,
+ language: str = "english",
+ extract_template: Optional[str] = None,
):
"""
Add a keyword extraction operator to the pipeline.
@@ -125,6 +126,8 @@ class RAGPipeline:
max_e_prop_len: int = 256,
prop_to_match: Optional[str] = None,
with_gremlin_template: bool = True,
+ num_gremlin_generate_example: int = 1,
+ gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
):
"""
Add a graph RAG query operator to the pipeline.
@@ -137,9 +140,16 @@ class RAGPipeline:
:return: Self-instance for chaining.
"""
self._operators.append(
- GraphRAGQuery(max_deep=max_deep, max_items=max_items,
max_v_prop_len=max_v_prop_len,
- max_e_prop_len=max_e_prop_len,
prop_to_match=prop_to_match,
- with_gremlin_template=with_gremlin_template)
+ GraphRAGQuery(
+ max_deep=max_deep,
+ max_items=max_items,
+ max_v_prop_len=max_v_prop_len,
+ max_e_prop_len=max_e_prop_len,
+ prop_to_match=prop_to_match,
+ with_gremlin_template=with_gremlin_template,
+ num_gremlin_generate_example=num_gremlin_generate_example,
+ gremlin_prompt=gremlin_prompt,
+ )
)
return self
@@ -151,7 +161,10 @@ class RAGPipeline:
:return: Self-instance for chaining.
"""
self._operators.append(
- VectorIndexQuery(embedding=self._embedding, topk=max_items, )
+ VectorIndexQuery(
+ embedding=self._embedding,
+ topk=max_items,
+ )
)
return self
@@ -179,12 +192,12 @@ class RAGPipeline:
return self
def synthesize_answer(
- self,
- raw_answer: bool = False,
- vector_only_answer: bool = True,
- graph_only_answer: bool = False,
- graph_vector_answer: bool = False,
- answer_prompt: Optional[str] = None,
+ self,
+ raw_answer: bool = False,
+ vector_only_answer: bool = True,
+ graph_only_answer: bool = False,
+ graph_vector_answer: bool = False,
+ answer_prompt: Optional[str] = None,
):
"""
Add an answer synthesis operator to the pipeline.
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
index dfbf085..95ce59f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
@@ -58,8 +58,9 @@ class GremlinGenerator:
self.operators.append(GremlinExampleIndexQuery(self.embedding,
num_examples))
return self
- def gremlin_generate_synthesize(self, schema, gremlin_prompt:
Optional[str] = None,
- vertices: Optional[List[str]] = None):
+ def gremlin_generate_synthesize(
+ self, schema, gremlin_prompt: Optional[str] = None, vertices:
Optional[List[str]] = None
+ ):
self.operators.append(GremlinGenerateSynthesize(self.llm, schema,
vertices, gremlin_prompt))
return self
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 33190b7..a4186c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -18,7 +18,7 @@
import json
from typing import Any, Dict, Optional, List, Set, Tuple
-from hugegraph_llm.config import huge_settings
+from hugegraph_llm.config import huge_settings, prompt
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
@@ -76,16 +76,17 @@ g.V().has('{prop}', within({keywords}))
class GraphRAGQuery:
def __init__(
- self,
- max_deep: int = 2,
- max_items: int = int(huge_settings.max_items),
- prop_to_match: Optional[str] = None,
- llm: Optional[BaseLLM] = None,
- embedding: Optional[BaseEmbedding] = None,
- max_v_prop_len: int = 2048,
- max_e_prop_len: int = 256,
- with_gremlin_template: bool = True,
- num_gremlin_generate_example: int = 1
+ self,
+ max_deep: int = 2,
+ max_items: int = int(huge_settings.max_items),
+ prop_to_match: Optional[str] = None,
+ llm: Optional[BaseLLM] = None,
+ embedding: Optional[BaseEmbedding] = None,
+ max_v_prop_len: int = 2048,
+ max_e_prop_len: int = 256,
+ with_gremlin_template: bool = True,
+ num_gremlin_generate_example: int = 1,
+ gremlin_prompt: Optional[str] = None,
):
self._client = PyHugeClient(
huge_settings.graph_ip,
@@ -108,6 +109,7 @@ class GraphRAGQuery:
)
self._num_gremlin_generate_example = num_gremlin_generate_example
self._with_gremlin_template = with_gremlin_template
+ self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
self._init_client(context)
@@ -134,12 +136,8 @@ class GraphRAGQuery:
self._gremlin_generator.clear()
self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
- context["simple_schema"],
- vertices=vertices,
- ).run(
- query=query,
- query_embedding=query_embedding
- )
+ context["simple_schema"], vertices=vertices,
gremlin_prompt=self._gremlin_prompt
+ ).run(query=query, query_embedding=query_embedding)
if self._with_gremlin_template:
gremlin = gremlin_response["result"]
else:
@@ -154,10 +152,9 @@ class GraphRAGQuery:
if context["graph_result"]:
context["graph_result_flag"] = 1
context["graph_context_head"] = (
- f"The following are graph query result "
- f"from gremlin query `{gremlin}`.\n"
+ f"The following are graph query result " f"from gremlin
query `{gremlin}`.\n"
)
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
log.error(e)
context["graph_result"] = ""
return context
@@ -285,8 +282,9 @@ class GraphRAGQuery:
return subgraph, vertex_degree_list, subgraph_with_degree
- def _process_path(self, path: Any, use_id_to_match: bool, v_cache:
Set[str],
- e_cache: Set[Tuple[str, str, str]]) -> Tuple[str,
List[str]]:
+ def _process_path(
+ self, path: Any, use_id_to_match: bool, v_cache: Set[str], e_cache:
Set[Tuple[str, str, str]]
+ ) -> Tuple[str, List[str]]:
flat_rel = ""
raw_flat_rel = path["objects"]
assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should
be odd."
@@ -300,8 +298,7 @@ class GraphRAGQuery:
if i % 2 == 0:
# Process each vertex
flat_rel, prior_edge_str_len, depth = self._process_vertex(
- item, flat_rel, node_cache, prior_edge_str_len, depth,
nodes_with_degree, use_id_to_match,
- v_cache
+ item, flat_rel, node_cache, prior_edge_str_len, depth,
nodes_with_degree, use_id_to_match, v_cache
)
else:
# Process each edge
@@ -311,17 +308,24 @@ class GraphRAGQuery:
return flat_rel, nodes_with_degree
- def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
- prior_edge_str_len: int, depth: int,
nodes_with_degree: List[str],
- use_id_to_match: bool, v_cache: Set[str]) ->
Tuple[str, int, int]:
+ def _process_vertex(
+ self,
+ item: Any,
+ flat_rel: str,
+ node_cache: Set[str],
+ prior_edge_str_len: int,
+ depth: int,
+ nodes_with_degree: List[str],
+ use_id_to_match: bool,
+ v_cache: Set[str],
+ ) -> Tuple[str, int, int]:
matched_str = item["id"] if use_id_to_match else
item["props"][self._prop_to_match]
if matched_str in node_cache:
flat_rel = flat_rel[:-prior_edge_str_len]
return flat_rel, prior_edge_str_len, depth
node_cache.add(matched_str)
- props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}"
- for k, v in item["props"].items() if v)
+ props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for
k, v in item["props"].items() if v)
# TODO: we may remove label id or replace with label name
if matched_str in v_cache:
@@ -335,20 +339,27 @@ class GraphRAGQuery:
depth += 1
return flat_rel, prior_edge_str_len, depth
- def _process_edge(self, item: Any, path_str: str, raw_flat_rel: List[Any],
i: int, use_id_to_match: bool,
- e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, int]:
- props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}"
- for k, v in item["props"].items() if v)
+ def _process_edge(
+ self,
+ item: Any,
+ path_str: str,
+ raw_flat_rel: List[Any],
+ i: int,
+ use_id_to_match: bool,
+ e_cache: Set[Tuple[str, str, str]],
+ ) -> Tuple[str, int]:
+ props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for
k, v in item["props"].items() if v)
props_str = f"{{{props_str}}}" if props_str else ""
- prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else (
- raw_flat_rel)[i - 1]["props"][self._prop_to_match]
+ prev_matched_str = (
+ raw_flat_rel[i - 1]["id"] if use_id_to_match else (raw_flat_rel)[i
- 1]["props"][self._prop_to_match]
+ )
- edge_key = (item['inV'], item['label'], item['outV'])
+ edge_key = (item["inV"], item["label"], item["outV"])
if edge_key not in e_cache:
e_cache.add(edge_key)
edge_label = f"{item['label']}{props_str}"
else:
- edge_label = item['label']
+ edge_label = item["label"]
edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str
else f"<--[{edge_label}]--"
path_str += edge_str
@@ -365,8 +376,8 @@ class GraphRAGQuery:
schema = self._get_graph_schema()
vertex_props_str, edge_props_str = schema.split("\n")[:2]
# TODO: rename to vertex (also need update in the schema)
- vertex_props_str = vertex_props_str[len("Vertex properties:
"):].strip("[").strip("]")
- edge_props_str = edge_props_str[len("Edge properties:
"):].strip("[").strip("]")
+ vertex_props_str = vertex_props_str[len("Vertex properties: ")
:].strip("[").strip("]")
+ edge_props_str = edge_props_str[len("Edge properties: ")
:].strip("[").strip("]")
vertex_labels = self._extract_label_names(vertex_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return vertex_labels, edge_labels