This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new d2fcfdb fix(llm): update prompt to fit prefix cache (#137)
d2fcfdb is described below
commit d2fcfdb2348baea1d5ebf526d6e687271d5cda03
Author: HaoJin Yang <[email protected]>
AuthorDate: Mon Dec 23 00:32:45 2024 +0800
fix(llm): update prompt to fit prefix cache (#137)
* fix: vertex IDs (vid) were not readable by the LLM in the Gremlin prompt
---------
Co-authored-by: imbajin <[email protected]>
---
.../src/hugegraph_llm/api/models/rag_requests.py | 2 +-
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 88 ++++++++----------
.../src/hugegraph_llm/config/prompt_config.py | 79 +++++++++-------
.../src/hugegraph_llm/demo/rag_demo/app.py | 10 +-
.../demo/rag_demo/text2gremlin_block.py | 102 ++++++++++++++-------
.../operators/llm_op/gremlin_generate.py | 14 +--
6 files changed, 166 insertions(+), 129 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 1eaa1e2..489ce0a 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -46,11 +46,11 @@ class RAGRequest(BaseModel):
)
+# TODO: import the default value of prompt.* dynamically
class GraphRAGRequest(BaseModel):
query: str = Query("", description="Query you want to ask")
gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates
to use.")
with_gremlin_tmpl: bool = Query(True, description="Use example template in
text2gremlin")
- answer_prompt: Optional[str] = Query(prompt.answer_prompt,
description="Prompt to guide the answer generation.")
rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near
neighbors in the search results.")
custom_priority_info: str = Query("", description="Custom information to
prioritize certain results.")
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 8a59c8b..4036496 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -16,7 +16,6 @@
# under the License.
import json
-from typing import Literal
from fastapi import status, APIRouter, HTTPException
@@ -29,72 +28,52 @@ from hugegraph_llm.api.models.rag_requests import (
GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
-from hugegraph_llm.config import llm_settings, huge_settings, prompt
+from hugegraph_llm.config import llm_settings, prompt
from hugegraph_llm.utils.log import log
-def graph_rag_recall(
- query: str,
- gremlin_tmpl_num: int,
- with_gremlin_tmpl: bool,
- answer_prompt: str, # FIXME: should be used in the query
- rerank_method: Literal["bleu", "reranker"],
- near_neighbor_first: bool,
- custom_related_information: str,
- gremlin_prompt: str,
-) -> dict:
- from hugegraph_llm.operators.graph_rag_task import RAGPipeline
-
- rag = RAGPipeline()
-
-
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
- with_gremlin_template=with_gremlin_tmpl,
- num_gremlin_generate_example=gremlin_tmpl_num,
- gremlin_prompt=gremlin_prompt,
- ).merge_dedup_rerank(
- rerank_method=rerank_method,
- near_neighbor_first=near_neighbor_first,
- custom_related_information=custom_related_information,
- )
- context = rag.run(verbose=True, query=query, graph_search=True)
- return context
-
-
def rag_http_api(
- router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf,
apply_embedding_conf, apply_reranker_conf
+ router: APIRouter,
+ rag_answer_func,
+ graph_rag_recall_func,
+ apply_graph_conf,
+ apply_llm_conf,
+ apply_embedding_conf,
+ apply_reranker_conf,
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(
- req.query,
- req.raw_answer,
- req.vector_only,
- req.graph_only,
- req.graph_vector_answer,
- req.with_gremlin_tmpl,
- req.graph_ratio,
- req.rerank_method,
- req.near_neighbor_first,
- req.custom_priority_info,
- req.answer_prompt or prompt.answer_prompt,
- req.keywords_extract_prompt or prompt.keywords_extract_prompt,
- req.gremlin_tmpl_num,
- req.gremlin_prompt or prompt.gremlin_generate_prompt,
+ text=req.query,
+ raw_answer=req.raw_answer,
+ vector_only_answer=req.vector_only,
+ graph_only_answer=req.graph_only,
+ graph_vector_answer=req.graph_vector_answer,
+ with_gremlin_template=req.with_gremlin_tmpl,
+ graph_ratio=req.graph_ratio,
+ rerank_method=req.rerank_method,
+ near_neighbor_first=req.near_neighbor_first,
+ custom_related_information=req.custom_priority_info,
+ answer_prompt=req.answer_prompt or prompt.answer_prompt,
+ keywords_extract_prompt=req.keywords_extract_prompt or
prompt.keywords_extract_prompt,
+ gremlin_tmpl_num=req.gremlin_tmpl_num,
+ gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
)
+ # TODO: we need more info in the response for users to understand the
query logic
return {
- key: value
- for key, value in zip(["raw_answer", "vector_only", "graph_only",
"graph_vector_answer"], result)
- if getattr(req, key)
+ "query": req.query,
+ **{key: value
+ for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
+ if getattr(req, key)}
}
@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
try:
- result = graph_rag_recall(
+ result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
with_gremlin_tmpl=req.with_gremlin_tmpl,
- answer_prompt=req.answer_prompt or prompt.answer_prompt,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
@@ -102,8 +81,15 @@ def rag_http_api(
)
if isinstance(result, dict):
- params = ["query", "keywords", "match_vids",
"graph_result_flag", "gremlin", "graph_result",
- "vertex_degree_list"]
+ params = [
+ "query",
+ "keywords",
+ "match_vids",
+ "graph_result_flag",
+ "gremlin",
+ "graph_result",
+ "vertex_degree_list",
+ ]
user_result = {key: result[key] for key in params if key in
result}
return {"graph_recall": user_result}
# Note: Maybe only for qianfan/wenxin
diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
index 4d37ced..f6aeef5 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
@@ -18,18 +18,19 @@
from hugegraph_llm.config.models.base_prompt_config import BasePromptConfig
+
class PromptConfig(BasePromptConfig):
# Data is detached from llm_op/answer_synthesize.py
answer_prompt: str = """You are an expert in knowledge graphs and natural
language processing.
Your task is to provide a precise and accurate answer based on the given
context.
+Given the context information and without using fictive knowledge,
+answer the following query in a concise and professional manner.
+
Context information is below.
---------------------
{context_str}
---------------------
-
-Given the context information and without using fictive knowledge,
-answer the following query in a concise and professional manner.
Query: {query_str}
Answer:
"""
@@ -131,7 +132,7 @@ Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a
keywords_extract_prompt: str = """指令:
请对以下文本执行以下任务:
1. 从文本中提取关键词:
- - 最少 0 个,最多 {max_keywords} 个。
+ - 最少 0 个,最多 MAX_KEYWORDS 个。
- 关键词应为具有完整语义的词语或短语,确保信息完整。
2. 识别需改写的关键词:
- 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
@@ -151,47 +152,55 @@ Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a
- 仅输出一行内容, 以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
- 格式示例:
KEYWORDS:关键词1,关键词2,...,关键词n
+
+MAX_KEYWORDS: {max_keywords}
文本:
{question}
"""
-#pylint: disable=C0301
+ # pylint: disable=C0301
# keywords_extract_prompt_EN = """
-# Instruction:
-# Please perform the following tasks on the text below:
-# 1. Extract Keywords and Generate Synonyms from text:
-# - At least 0, at most {max_keywords} keywords.
-# - For each keyword, generate its synonyms or possible variant forms.
-# Requirements:
-# - Keywords should be meaningful and specific entities; avoid using
meaningless or overly broad terms (e.g., “object,” “the,” “he”).
-# - Prioritize extracting subjects, verbs, and objects; avoid extracting
function words or auxiliary words.
-# - Do not expand into unrelated generalized categories.
-# Note:
-# - Only consider semantic synonyms and other words with similar meanings in
the given context.
-# Output Format:
-# - Output only one line, prefixed with KEYWORDS:, followed by all keywords
and synonyms, separated by commas.No spaces or empty characters are allowed in
the extracted keywords.
-# - Format example:
-# KEYWORDS: keyword1, keyword2, ..., keywordn, synonym1, synonym2, ...,
synonymn
-# Text:
-# {question}
-# """
-
- gremlin_generate_prompt = """\
-Given the example query-gremlin pairs:
-{example}
-
-Given the graph schema:
+ # Instruction:
+ # Please perform the following tasks on the text below:
+ # 1. Extract Keywords and Generate Synonyms from the text:
+ # - At least 0, at most {max_keywords} keywords.
+ # - For each keyword, generate its synonyms or possible variant forms.
+ # Requirements:
+ # - Keywords should be meaningful and specific entities; avoid using
meaningless or overly broad terms (e.g., “object,” “the,” “he”).
+ # - Prioritize extracting subjects, verbs, and objects; avoid extracting
function words or auxiliary words.
+ # - Do not expand into unrelated generalized categories.
+ # Note:
+ # - Only consider semantic synonyms and other words with similar meanings
in the given context.
+ # Output Format:
+ # - Output only one line, prefixed with KEYWORDS:, followed by all
keywords and synonyms, separated by commas.No spaces or empty characters are
allowed in the extracted keywords.
+ # - Format example:
+ # KEYWORDS: keyword1, keyword2, ..., keywordN, synonym1, synonym2, ...,
synonymN
+ # Text:
+ # {question}
+ # """
+
+ gremlin_generate_prompt = """
+You are an expert in graph query language(Gremlin), your role is to understand
the schema of the graph and generate
+accurate Gremlin code based on the given instructions.
+
+# Graph Schema:
```json
{schema}
```
+# Rule:
+1. Could use the vertex ID directly if it's given in the context.
+2. The output format must be like:
+```gremlin
+g.V().limit(10)
+```
-Given the extracted vertex vid:
+# Extracted vertex vid:
{vertices}
-Generate gremlin from the following user query.
+# Given the example query-gremlin pairs:
+{example}
+
+# Generate gremlin from the following user query.
{query}
-The output format must be like:
-```gremlin
-g.V().limit(10)
-```
+
The generated gremlin is:
"""
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 3fe6c0f..700e60b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -35,7 +35,7 @@ from hugegraph_llm.demo.rag_demo.configs_block import (
apply_graph_config,
)
from hugegraph_llm.demo.rag_demo.other_block import create_other_block
-from hugegraph_llm.demo.rag_demo.text2gremlin_block import
create_text2gremlin_block
+from hugegraph_llm.demo.rag_demo.text2gremlin_block import
create_text2gremlin_block, graph_rag_recall
from hugegraph_llm.demo.rag_demo.rag_block import create_rag_block, rag_answer
from hugegraph_llm.demo.rag_demo.vector_graph_block import
create_vector_graph_block
from hugegraph_llm.resources.demo.css import CSS
@@ -171,7 +171,13 @@ if __name__ == "__main__":
hugegraph_llm = init_rag_ui()
rag_http_api(
- api_auth, rag_answer, apply_graph_config, apply_llm_config,
apply_embedding_config, apply_reranker_config
+ api_auth,
+ rag_answer,
+ graph_rag_recall,
+ apply_graph_config,
+ apply_llm_config,
+ apply_embedding_config,
+ apply_reranker_config,
)
admin_http_api(api_auth, log_stream)
diff --git
a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index eaa37bc..c47ce7c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -17,14 +17,15 @@
import json
import os
-from typing import Any, Tuple, Dict, Union
+from typing import Any, Tuple, Dict, Union, Literal
import gradio as gr
import pandas as pd
-from hugegraph_llm.config import prompt, resource_path
+from hugegraph_llm.config import prompt, resource_path, huge_settings
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query
@@ -32,8 +33,11 @@ from hugegraph_llm.utils.log import log
def store_schema(schema, question, gremlin_prompt):
- if (prompt.text2gql_graph_schema != schema or prompt.default_question !=
question or
- prompt.gremlin_generate_prompt != gremlin_prompt):
+ if (
+ prompt.text2gql_graph_schema != schema or
+ prompt.default_question != question or
+ prompt.gremlin_generate_prompt != gremlin_prompt
+ ):
prompt.text2gql_graph_schema = schema
prompt.default_question = question
prompt.gremlin_generate_prompt = gremlin_prompt
@@ -49,7 +53,7 @@ def build_example_vector_index(temp_file) -> dict:
with open(full_path, "r", encoding="utf-8") as f:
examples = json.load(f)
elif full_path.endswith(".csv"):
- examples = pd.read_csv(full_path).to_dict('records')
+ examples = pd.read_csv(full_path).to_dict("records")
else:
log.critical("Unsupported file format. Please input a JSON or CSV
file.")
return {"error": "Unsupported file format. Please input a JSON or CSV
file."}
@@ -60,8 +64,9 @@ def build_example_vector_index(temp_file) -> dict:
return builder.example_index_build(examples).run()
-def gremlin_generate(inp, example_num, schema, gremlin_prompt) -> Union[
- tuple[str, str], tuple[str, Any, Any, Any, Any]]:
+def gremlin_generate(
+ inp, example_num, schema, gremlin_prompt
+) -> Union[tuple[str, str], tuple[str, Any, Any, Any, Any]]:
generator = GremlinGenerator(llm=LLMs().get_text2gql_llm(),
embedding=Embeddings().get_embedding())
sm = SchemaManager(graph_name=schema)
short_schema = False
@@ -83,19 +88,28 @@ def gremlin_generate(inp, example_num, schema,
gremlin_prompt) -> Union[
return "Invalid JSON schema, please check the format
carefully.", ""
# FIXME: schema is not used in gremlin_generate() step, no context for it
(enhance the logic here)
updated_schema = sm.simple_schema(schema) if short_schema else schema
- context =
generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
-
gremlin_prompt).run(query=inp)
+ store_schema(str(updated_schema), inp, gremlin_prompt)
+ context = (
+
generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
gremlin_prompt)
+ .run(query=inp)
+ )
try:
context["template_exec_res"] =
run_gremlin_query(query=context["result"])
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
context["template_exec_res"] = f"{e}"
try:
context["raw_exec_res"] =
run_gremlin_query(query=context["raw_result"])
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
context["raw_exec_res"] = f"{e}"
match_result = json.dumps(context.get("match_result", "No Results"),
ensure_ascii=False, indent=2)
- return match_result, context["result"], context["raw_result"],
context["template_exec_res"], context["raw_exec_res"]
+ return (
+ match_result,
+ context["result"],
+ context["raw_result"],
+ context["template_exec_res"],
+ context["raw_exec_res"],
+ )
def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
@@ -112,24 +126,24 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str,
Any]:
if "edgelabels" in schema:
mini_schema["edgelabels"] = []
for edge in schema["edgelabels"]:
- new_edge = {key: edge[key] for key in
- ["name", "source_label", "target_label", "properties"]
if key in edge}
+ new_edge = {key: edge[key] for key in ["name", "source_label",
"target_label", "properties"] if key in edge}
mini_schema["edgelabels"].append(new_edge)
return mini_schema
def create_text2gremlin_block() -> Tuple:
- gr.Markdown("""## Build Vector Template Index (Optional)
+ gr.Markdown(
+ """## Build Vector Template Index (Optional)
> Uploaded CSV file should be in `query,gremlin` format below:
> e.g. `who is peter?`,`g.V().has('name', 'peter')`
> JSON file should be in format below:
> e.g. `[{"query":"who is peter", "gremlin":"g.V().has('name', 'peter')"}]`
- """)
+ """
+ )
with gr.Row():
file = gr.File(
- value=os.path.join(resource_path, "demo", "text2gremlin.csv"),
- label="Upload Text-Gremlin Pairs File"
+ value=os.path.join(resource_path, "demo", "text2gremlin.csv"),
label="Upload Text-Gremlin Pairs File"
)
out = gr.Textbox(label="Result Message")
with gr.Row():
@@ -143,27 +157,49 @@ def create_text2gremlin_block() -> Tuple:
match = gr.Code(label="Similar Template (TopN)",
language="javascript", elem_classes="code-container-show")
initialized_out = gr.Textbox(label="Gremlin With Template",
show_copy_button=True)
raw_out = gr.Textbox(label="Gremlin Without Template",
show_copy_button=True)
- tmpl_exec_out = gr.Code(label="Query With Template Output",
language="json",
- elem_classes="code-container-show")
- raw_exec_out = gr.Code(label="Query Without Template Output",
language="json",
- elem_classes="code-container-show")
+ tmpl_exec_out = gr.Code(
+ label="Query With Template Output", language="json",
elem_classes="code-container-show"
+ )
+ raw_exec_out = gr.Code(
+ label="Query Without Template Output", language="json",
elem_classes="code-container-show"
+ )
with gr.Column(scale=1):
- example_num_slider = gr.Slider(
- minimum=0,
- maximum=10,
- step=1,
- value=2,
- label="Number of refer examples"
- )
+ example_num_slider = gr.Slider(minimum=0, maximum=10, step=1,
value=2, label="Number of refer examples")
schema_box = gr.Textbox(value=prompt.text2gql_graph_schema,
label="Schema", lines=2, show_copy_button=True)
- prompt_box = gr.Textbox(value=prompt.gremlin_generate_prompt,
label="Prompt", lines=2,
- show_copy_button=True)
+ prompt_box = gr.Textbox(
+ value=prompt.gremlin_generate_prompt, label="Prompt",
lines=20, show_copy_button=True
+ )
btn = gr.Button("Text2Gremlin", variant="primary")
btn.click( # pylint: disable=no-member
fn=gremlin_generate,
inputs=[input_box, example_num_slider, schema_box, prompt_box],
- outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out]
- ).then(store_schema, inputs=[schema_box, input_box, prompt_box], )
+ outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out],
+ )
return input_box, schema_box, prompt_box
+
+
+def graph_rag_recall(
+ query: str,
+ gremlin_tmpl_num: int,
+ with_gremlin_tmpl: bool,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str,
+ gremlin_prompt: str,
+) -> dict:
+ store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt)
+ rag = RAGPipeline()
+
+
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
+ with_gremlin_template=with_gremlin_tmpl,
+ num_gremlin_generate_example=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ ).merge_dedup_rerank(
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ )
+ context = rag.run(verbose=True, query=query, graph_search=True)
+ return context
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 955fc9e..9694647 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -20,19 +20,19 @@ import json
import re
from typing import Optional, List, Dict, Any, Union
+from hugegraph_llm.config import prompt
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.utils.log import log
-from hugegraph_llm.config import prompt
class GremlinGenerateSynthesize:
def __init__(
- self,
- llm: BaseLLM = None,
- schema: Optional[Union[dict, str]] = None,
- vertices: Optional[List[str]] = None,
- gremlin_prompt: Optional[str] = None
+ self,
+ llm: BaseLLM = None,
+ schema: Optional[Union[dict, str]] = None,
+ vertices: Optional[List[str]] = None,
+ gremlin_prompt: Optional[str] = None
) -> None:
self.llm = llm or LLMs().get_text2gql_llm()
if isinstance(schema, dict):
@@ -59,7 +59,7 @@ class GremlinGenerateSynthesize:
def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]:
if not vertices:
return None
- return "\n".join([f"- {vid}" for vid in vertices])
+ return "\n".join([f"- '{vid}'" for vid in vertices])
async def async_generate(self, context: Dict[str, Any]):
async_tasks = {}