This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 39e8d48 refactor(llm): remove enable_gql logic in api & rag block (#148)
39e8d48 is described below
commit 39e8d486b31f6a2dd99f8a78dbd6402e91e7ce18
Author: HaoJin Yang <[email protected]>
AuthorDate: Fri Dec 27 16:53:37 2024 +0800
refactor(llm): remove enable_gql logic in api & rag block (#148)
* feat(llm): support choose num of examples in rag
---------
Co-authored-by: imbajin <[email protected]>
---
.../src/hugegraph_llm/api/models/rag_requests.py | 6 +++---
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 10 +++++-----
.../src/hugegraph_llm/demo/rag_demo/rag_block.py | 15 ++++++---------
.../src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py | 11 +++++------
.../src/hugegraph_llm/operators/graph_rag_task.py | 4 +---
.../operators/hugegraph_op/graph_rag_query.py | 10 ++++------
.../operators/index_op/gremlin_example_index_query.py | 11 ++++++-----
7 files changed, 30 insertions(+), 37 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 489ce0a..de47aa0 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -32,7 +32,6 @@ class RAGRequest(BaseModel):
graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans &
vector ans")
rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near
neighbors in the search results.")
- with_gremlin_tmpl: bool = Query(True, description="Use example template in
text2gremlin")
custom_priority_info: str = Query("", description="Custom information to
prioritize certain results.")
answer_prompt: Optional[str] = Query(prompt.answer_prompt,
description="Prompt to guide the answer generation.")
keywords_extract_prompt: Optional[str] = Query(
@@ -49,8 +48,9 @@ class RAGRequest(BaseModel):
# TODO: import the default value of prompt.* dynamically
class GraphRAGRequest(BaseModel):
query: str = Query("", description="Query you want to ask")
- gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates
to use.")
- with_gremlin_tmpl: bool = Query(True, description="Use example template in
text2gremlin")
+ gremlin_tmpl_num: int = Query(
+ 1, description="Number of Gremlin templates to use. If num <=0 means
template is not provided"
+ )
rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near
neighbors in the search results.")
custom_priority_info: str = Query("", description="Custom information to
prioritize certain results.")
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 4036496..d851fd1 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -49,7 +49,6 @@ def rag_http_api(
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
- with_gremlin_template=req.with_gremlin_tmpl,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
@@ -62,9 +61,11 @@ def rag_http_api(
# TODO: we need more info in the response for users to understand the
query logic
return {
"query": req.query,
- **{key: value
- for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
- if getattr(req, key)}
+ **{
+ key: value
+ for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
+ if getattr(req, key)
+ },
}
@router.post("/rag/graph", status_code=status.HTTP_200_OK)
@@ -73,7 +74,6 @@ def rag_http_api(
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
- with_gremlin_tmpl=req.with_gremlin_tmpl,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 070f3b3..c10f84b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -35,7 +35,6 @@ def rag_answer(
vector_only_answer: bool,
graph_only_answer: bool,
graph_vector_answer: bool,
- with_gremlin_template: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
@@ -80,7 +79,6 @@ def rag_answer(
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
huge_settings.graph_name
).query_graphdb(
- with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)
@@ -125,7 +123,10 @@ def create_rag_block():
value=prompt.answer_prompt, label="Query Prompt",
show_copy_button=True, lines=7
)
keywords_extract_prompt_input = gr.Textbox(
- value=prompt.keywords_extract_prompt, label="Keywords
Extraction Prompt", show_copy_button=True, lines=7
+ value=prompt.keywords_extract_prompt,
+ label="Keywords Extraction Prompt",
+ show_copy_button=True,
+ lines=7,
)
with gr.Column(scale=1):
with gr.Row():
@@ -134,8 +135,6 @@ def create_rag_block():
with gr.Row():
graph_only_radio = gr.Radio(choices=[True, False], value=True,
label="Graph-only Answer")
graph_vector_radio = gr.Radio(choices=[True, False],
value=False, label="Graph-Vector Answer")
- with gr.Row():
- with_gremlin_template_radio = gr.Radio(choices=[True, False],
value=True, label="With Gremlin Template")
def toggle_slider(enable):
return gr.update(interactive=enable)
@@ -148,6 +147,7 @@ def create_rag_block():
value="reranker" if online_rerank else "bleu",
label="Rerank method",
)
+ example_num = gr.Number(value=2, label="Template Num (0 to
disable it) ", precision=0)
graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio",
step=0.1, interactive=False)
graph_vector_radio.change(
@@ -172,13 +172,13 @@ def create_rag_block():
vector_only_radio,
graph_only_radio,
graph_vector_radio,
- with_gremlin_template_radio,
graph_ratio,
rerank_method,
near_neighbor_first,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
+ example_num,
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)
@@ -237,7 +237,6 @@ def create_rag_block():
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
- with_gremlin_template: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
@@ -257,7 +256,6 @@ def create_rag_block():
graph_ratio,
rerank_method,
near_neighbor_first,
- with_gremlin_template,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
@@ -291,7 +289,6 @@ def create_rag_block():
graph_ratio,
rerank_method,
near_neighbor_first,
- with_gremlin_template_radio,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index c47ce7c..46e2e9e 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -34,9 +34,9 @@ from hugegraph_llm.utils.log import log
def store_schema(schema, question, gremlin_prompt):
if (
- prompt.text2gql_graph_schema != schema or
- prompt.default_question != question or
- prompt.gremlin_generate_prompt != gremlin_prompt
+ prompt.text2gql_graph_schema != schema
+ or prompt.default_question != question
+ or prompt.gremlin_generate_prompt != gremlin_prompt
):
prompt.text2gql_graph_schema = schema
prompt.default_question = question
@@ -90,7 +90,8 @@ def gremlin_generate(
updated_schema = sm.simple_schema(schema) if short_schema else schema
store_schema(str(updated_schema), inp, gremlin_prompt)
context = (
-
generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
gremlin_prompt)
+ generator.example_index_query(example_num)
+ .gremlin_generate_synthesize(updated_schema, gremlin_prompt)
.run(query=inp)
)
try:
@@ -183,7 +184,6 @@ def create_text2gremlin_block() -> Tuple:
def graph_rag_recall(
query: str,
gremlin_tmpl_num: int,
- with_gremlin_tmpl: bool,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
@@ -193,7 +193,6 @@ def graph_rag_recall(
rag = RAGPipeline()
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
- with_gremlin_template=with_gremlin_tmpl,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
).merge_dedup_rerank(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 03ac9ae..399864a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -125,8 +125,7 @@ class RAGPipeline:
max_v_prop_len: int = 2048,
max_e_prop_len: int = 256,
prop_to_match: Optional[str] = None,
- with_gremlin_template: bool = True,
- num_gremlin_generate_example: int = 1,
+ num_gremlin_generate_example: Optional[int] = 1,
gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
):
"""
@@ -146,7 +145,6 @@ class RAGPipeline:
max_v_prop_len=max_v_prop_len,
max_e_prop_len=max_e_prop_len,
prop_to_match=prop_to_match,
- with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=num_gremlin_generate_example,
gremlin_prompt=gremlin_prompt,
)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index a4186c8..e213c37 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -82,10 +82,9 @@ class GraphRAGQuery:
prop_to_match: Optional[str] = None,
llm: Optional[BaseLLM] = None,
embedding: Optional[BaseEmbedding] = None,
- max_v_prop_len: int = 2048,
- max_e_prop_len: int = 256,
- with_gremlin_template: bool = True,
- num_gremlin_generate_example: int = 1,
+ max_v_prop_len: Optional[int] = 2048,
+ max_e_prop_len: Optional[int] = 256,
+ num_gremlin_generate_example: Optional[int] = 1,
gremlin_prompt: Optional[str] = None,
):
self._client = PyHugeClient(
@@ -108,7 +107,6 @@ class GraphRAGQuery:
embedding=embedding,
)
self._num_gremlin_generate_example = num_gremlin_generate_example
- self._with_gremlin_template = with_gremlin_template
self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
@@ -138,7 +136,7 @@ class GraphRAGQuery:
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
context["simple_schema"], vertices=vertices,
gremlin_prompt=self._gremlin_prompt
).run(query=query, query_embedding=query_embedding)
- if self._with_gremlin_template:
+ if self._num_gremlin_generate_example > 0:
gremlin = gremlin_response["result"]
else:
gremlin = gremlin_response["raw_result"]
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index a95e7da..4029995 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -38,13 +38,15 @@ class GremlinExampleIndexQuery:
self.vector_index = VectorIndex.from_index_file(self.index_dir)
def _ensure_index_exists(self):
- if not (os.path.exists(os.path.join(self.index_dir, "index.faiss"))
- and os.path.exists(os.path.join(self.index_dir,
"properties.pkl"))):
+ if not (
+ os.path.exists(os.path.join(self.index_dir, "index.faiss"))
+ and os.path.exists(os.path.join(self.index_dir, "properties.pkl"))
+ ):
log.warning("No gremlin example index found, will generate one.")
self._build_default_example_index()
def _get_match_result(self, context: Dict[str, Any], query: str) ->
List[Dict[str, Any]]:
- if self.num_examples == 0:
+ if self.num_examples <= 0:
return []
query_embedding = context.get("query_embedding")
@@ -53,8 +55,7 @@ class GremlinExampleIndexQuery:
return self.vector_index.search(query_embedding, self.num_examples,
dis_threshold=1.8)
def _build_default_example_index(self):
- properties = pd.read_csv(os.path.join(resource_path, "demo",
-
"text2gremlin.csv")).to_dict(orient="records")
+ properties = pd.read_csv(os.path.join(resource_path, "demo",
"text2gremlin.csv")).to_dict(orient="records")
embeddings = [self.embedding.get_text_embedding(row["query"]) for row
in tqdm(properties)]
vector_index = VectorIndex(len(embeddings[0]))
vector_index.add(embeddings, properties)