This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch search-template
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit 919377587a90728af4879fcffb25a5910d668077
Author: imbajin <[email protected]>
AuthorDate: Wed Aug 21 17:45:57 2024 +0800

    feat(llm): support user-defined search template
---
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  19 +--
 .../operators/common_op/merge_dedup_rerank.py      |   8 +-
 .../src/hugegraph_llm/operators/graph_rag_task.py  | 136 +++++++++++++++------
 .../operators/llm_op/answer_synthesize.py          |  33 ++---
 4 files changed, 137 insertions(+), 59 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 3151e23..b532ac1 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -35,7 +35,7 @@ from hugegraph_llm.config import settings, resource_path
 from hugegraph_llm.enums.build_mode import BuildMode
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
 from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.operators.graph_rag_task import GraphRAG
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
 from hugegraph_llm.operators.kg_construction_task import KgBuilder
 from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
 from hugegraph_llm.utils.hugegraph_utils import get_hg_client
@@ -58,24 +58,26 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):


 def rag_answer(
-    text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool
-) -> tuple:
+    text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool,
+    graph_vector_answer: bool, answer_prompt: str) -> tuple:
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
-    searcher = GraphRAG()
+    searcher = RAGPipeline()
     if vector_search:
         searcher.query_vector_index_for_rag()
     if graph_search:
         searcher.extract_word().match_keyword_to_id().query_graph_for_rag()
+    # TODO: add more user-defined search strategies
     searcher.merge_dedup_rerank().synthesize_answer(
         raw_answer=raw_answer,
         vector_only_answer=vector_only_answer,
         graph_only_answer=graph_only_answer,
         graph_vector_answer=graph_vector_answer,
+        answer_prompt=answer_prompt
     )

     try:
@@ -449,6 +451,9 @@ def init_rag_ui() -> gr.Interface:
             vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
             graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
             graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
+            from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
+            answer_prompt_input = gr.Textbox(value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt",
+                                             show_copy_button=True)
             btn = gr.Button("Answer Question")
             btn.click(  # pylint: disable=no-member
                 fn=rag_answer,
@@ -458,6 +463,7 @@ def init_rag_ui() -> gr.Interface:
                     vector_only_radio,
                     graph_only_radio,
                     graph_vector_radio,
+                    answer_prompt_input,
                 ],
                 outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
             )
@@ -496,7 +502,6 @@ if __name__ == "__main__":
     # TODO: support multi-user login when need
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/",
                               auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
-    # Note: set reload to False in production environment
-    uvicorn.run(app, host=args.host, port=args.port)
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
-    # uvicorn.run("rag_web_demo:app", host="0.0.0.0", port=8001, reload=True)
+    # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
+    uvicorn.run(app, host=args.host, port=args.port, reload=False)
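
For readers following along, the reworked rag_answer above assembles the renamed pipeline roughly as follows. This is a minimal sketch: the question text, the flag values, and the `query` keyword passed to run() are illustrative assumptions, not code from this commit (run(**kwargs) seeds the operator context, which downstream operators read via context.get("query")):

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    searcher = RAGPipeline()
    # Graph + vector recall, mirroring the graph_vector_answer branch above.
    searcher.query_vector_index_for_rag()
    searcher.extract_word().match_keyword_to_id().query_graph_for_rag()
    searcher.merge_dedup_rerank().synthesize_answer(
        raw_answer=False,
        vector_only_answer=False,
        graph_only_answer=False,
        graph_vector_answer=True,
        answer_prompt=None,  # None falls back to DEFAULT_ANSWER_TEMPLATE
    )
    context = searcher.run(query="Tell me about the movie 'The Godfather'.")
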
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index e012479..2187096 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -34,16 +34,16 @@ class MergeDedupRerank:
         self,
         embedding: BaseEmbedding,
         topk: int = 10,
-        policy: Literal["bleu", "priority"] = "bleu"
+        strategy: Literal["bleu", "priority"] = "bleu"
     ):
         self.embedding = embedding
         self.topk = topk
-        if policy == "bleu":
+        if strategy == "bleu":
             self.rerank_func = self._bleu_rerank
-        elif policy == "priority":
+        elif strategy == "priority":
             self.rerank_func = self._priority_rerank
         else:
-            raise ValueError(f"Unimplemented policy {policy}.")
+            raise ValueError(f"Unimplemented rank strategy {strategy}.")

     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
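
The rename from `policy` to `strategy` is API-visible to anyone constructing the operator directly. A minimal sketch of the two bundled modes; the embedding comes from the project's factory, and the behavior comments are inferred from the method names (_bleu_rerank / _priority_rerank), not spelled out in this diff:

    from hugegraph_llm.models.embeddings.init_embedding import Embeddings
    from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank

    embedding = Embeddings().get_embedding()

    # Default: rerank candidates by BLEU text similarity against the query.
    reranker = MergeDedupRerank(embedding=embedding, topk=10, strategy="bleu")

    # Alternative: rerank by recall priority (routes to _priority_rerank).
    reranker = MergeDedupRerank(embedding=embedding, topk=10, strategy="priority")

    # Any other value raises ValueError("Unimplemented rank strategy ...").
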
+ """ self._llm = llm or LLMs().get_llm() self._embedding = embedding or Embeddings().get_embedding() self._operators: List[Any] = [] def extract_word( - self, - text: Optional[str] = None, - language: str = "english", + self, + text: Optional[str] = None, + language: str = "english", ): - self._operators.append( - WordExtract( - text=text, - language=language, - ) - ) + """ + Add a word extraction operator to the pipeline. + + :param text: Text to extract words from. + :param language: Language of the text. + :return: Self-instance for chaining. + """ + self._operators.append(WordExtract(text=text, language=language)) return self def extract_keyword( - self, - text: Optional[str] = None, - max_keywords: int = 5, - language: str = "english", - extract_template: Optional[str] = None, - expand_template: Optional[str] = None, + self, + text: Optional[str] = None, + max_keywords: int = 5, + language: str = "english", + extract_template: Optional[str] = None, + expand_template: Optional[str] = None, ): + """ + Add a keyword extraction operator to the pipeline. + + :param text: Text to extract keywords from. + :param max_keywords: Maximum number of keywords to extract. + :param language: Language of the text. + :param extract_template: Template for keyword extraction. + :param expand_template: Template for keyword expansion. + :return: Self-instance for chaining. + """ self._operators.append( KeywordExtract( text=text, @@ -73,6 +96,12 @@ class GraphRAG: return self def match_keyword_to_id(self, topk_per_keyword: int = 1): + """ + Add a semantic ID query operator to the pipeline. + + :param topk_per_keyword: Top K results per keyword. + :return: Self-instance for chaining. + """ self._operators.append( SemanticIdQuery( embedding=self._embedding, @@ -87,6 +116,14 @@ class GraphRAG: max_items: int = 30, prop_to_match: Optional[str] = None, ): + """ + Add a graph RAG query operator to the pipeline. + + :param max_deep: Maximum depth for the graph query. + :param max_items: Maximum number of items to retrieve. + :param prop_to_match: Property to match in the graph. + :return: Self-instance for chaining. + """ self._operators.append( GraphRAGQuery( max_deep=max_deep, @@ -100,6 +137,12 @@ class GraphRAG: self, max_items: int = 3 ): + """ + Add a vector index query operator to the pipeline. + + :param max_items: Maximum number of items to retrieve. + :return: Self-instance for chaining. + """ self._operators.append( VectorIndexQuery( embedding=self._embedding, @@ -109,37 +152,62 @@ class GraphRAG: return self def merge_dedup_rerank(self): - self._operators.append( - MergeDedupRerank( - embedding=self._embedding, - ) + """ + Add a merge, deduplication, and rerank operator to the pipeline. + + :return: Self-instance for chaining. + """ + self._operators.append(MergeDedupRerank( + embedding=self._embedding, + ) ) return self def synthesize_answer( - self, - raw_answer: bool = False, - vector_only_answer: bool = True, - graph_only_answer: bool = False, - graph_vector_answer: bool = False, - prompt_template: Optional[str] = None, + self, + raw_answer: bool = False, + vector_only_answer: bool = True, + graph_only_answer: bool = False, + graph_vector_answer: bool = False, + answer_prompt: Optional[str] = None, ): + """ + Add an answer synthesis operator to the pipeline. + + :param raw_answer: Whether to return raw answers. + :param vector_only_answer: Whether to return vector-only answers. + :param graph_only_answer: Whether to return graph-only answers. 
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index f3803c7..6e050d5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -23,19 +23,24 @@ from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs

 # TODO: we need enhance the template to answer the question
-DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL = (
-    "Context information is below.\n"
-    "---------------------\n"
-    "{context_str}\n"
-    "---------------------\n"
-    "You need to refer to the context based on the following priority:\n"
-    "1. Graph recall > vector recall\n"
-    "2. Exact recall > Fuzzy recall\n"
-    "3. Independent vertex > 1-depth neighbor> 2-depth neighbors\n"
-    "Given the context information and not prior knowledge, answer the query.\n"
-    "Query: {query_str}\n"
-    "Answer: "
-)
+DEFAULT_ANSWER_TEMPLATE = f"""
+You are an expert in knowledge graphs and natural language processing.
+Your task is to provide a precise and accurate answer based on the given context.
+
+Context information is below.
+---------------------
+{{context_str}}
+---------------------
+Please refer to the context based on the following priority:
+1. Graph recall > Vector recall
+2. Exact recall > Fuzzy recall
+3. Independent vertex > 1-depth neighbor > 2-depth neighbors
+
+Given the context information and without using prior knowledge,
+answer the following query in a concise and professional manner.
+Query: {{query_str}}
+Answer:
+"""


 class AnswerSynthesize:
@@ -53,7 +58,7 @@ class AnswerSynthesize:
         graph_vector_answer: bool = False,
     ):
         self._llm = llm
-        self._prompt_template = prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
+        self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
         self._question = question
         self._context_body = context_body
         self._context_head = context_head
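
With answer_prompt now threaded through synthesize_answer into AnswerSynthesize(prompt_template=...), a user-defined template only needs to keep the {context_str} and {query_str} placeholders that the default template exposes. A hedged sketch; the template wording below is invented for illustration:

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    CUSTOM_ANSWER_TEMPLATE = """
    Answer strictly from the context below; say "I don't know" if it is insufficient.
    ---------------------
    {context_str}
    ---------------------
    Query: {query_str}
    Answer:
    """

    searcher = RAGPipeline()
    searcher.extract_word().match_keyword_to_id().query_graph_for_rag()
    searcher.merge_dedup_rerank().synthesize_answer(
        vector_only_answer=False,
        graph_only_answer=True,
        answer_prompt=CUSTOM_ANSWER_TEMPLATE,
    )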
