This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new aeb25bd fix(rag): rag api critical bug & rename some params (#114)
aeb25bd is described below
commit aeb25bde48cb240eae2208120cfa7a5e38bb5dd2
Author: Hongjun Li <[email protected]>
AuthorDate: Mon Nov 18 18:09:27 2024 +0800
fix(rag): rag api critical bug & rename some params (#114)
1. Added "Enable api authentication based on ENABLE_LOGIN in `.env` file".
(in file `hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py`)
2. Synchronize according to the arguments of the latest rag_answer
function. (in other changes files)
3. remove `Optional`, My reasons for removing optional:
The purpose of using optional is to make the value support None,
but since I have assigned it by default, None will never happen, so the
optional is effectively invalid, so it is removed
remove `Optional` References:
1. Fastapi official documentation
(https://fastapi.tiangolo.com/python-types/#using-union-or-optional)
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/README.md | 3 +-
.../src/hugegraph_llm/api/models/rag_requests.py | 25 +++++++++-------
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 22 +++++++++-----
.../src/hugegraph_llm/demo/rag_demo/app.py | 13 +++++---
.../operators/document_op/word_extract.py | 8 ++---
.../operators/hugegraph_op/graph_rag_query.py | 1 -
.../operators/index_op/vector_index_query.py | 9 ++----
.../operators/llm_op/answer_synthesize.py | 35 ++++++++--------------
.../operators/llm_op/keyword_extract.py | 7 ++---
9 files changed, 58 insertions(+), 65 deletions(-)
diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index e9d6f90..ae07627 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -31,8 +31,7 @@ graph systems and large language models.
3. Install [hugegraph-python-client](../hugegraph-python-client) and
[hugegraph_llm](src/hugegraph_llm)
```bash
cd ./incubator-hugegraph-ai # better to use virtualenv (source venv/bin/activate)
- pip install ./hugegraph-python-client
- pip install -r ./hugegraph-llm/requirements.txt
+ pip install ./hugegraph-python-client && pip install -r ./hugegraph-llm/requirements.txt
```
4. Enter the project directory
```bash
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 5ee5c05..d832859 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -21,23 +21,29 @@ from pydantic import BaseModel
class RAGRequest(BaseModel):
- query: str
- raw_llm: Optional[bool] = False
- vector_only: Optional[bool] = False
- graph_only: Optional[bool] = False
- graph_vector: Optional[bool] = False
+ query: str = ""
+ raw_answer: bool = False
+ vector_only: bool = False
+ graph_only: bool = False
+ graph_vector_answer: bool = False
graph_ratio: float = 0.5
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
- custom_related_information: str = None
- answer_prompt: Optional[str] = None
+ custom_priority_info: str = ""
+ answer_prompt: str = ""
class GraphRAGRequest(BaseModel):
- query: str
+ query: str = ""
+ raw_answer: bool = True
+ vector_only: bool = False
+ graph_only: bool = False
+ graph_vector_answer: bool = False
+ graph_ratio: float = 0.5
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
- custom_related_information: str = None
+ custom_priority_info: str = ""
+ answer_prompt: str = ""
class GraphConfigRequest(BaseModel):
@@ -74,4 +80,3 @@ class RerankerConfigRequest(BaseModel):
class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
log_file: Optional[str] = 'llm-server.log'
-
\ No newline at end of file
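For context on point 3 of the commit message: once a field has a concrete default, `Optional` no longer carries information, and dropping it means an explicit `null` is rejected at validation time. A minimal standalone Pydantic sketch (not code from this repo):
```python
from pydantic import BaseModel, ValidationError

class WithDefault(BaseModel):
    # A concrete default means the field is never None when omitted,
    # so an Optional[bool] annotation would add nothing here.
    vector_only: bool = False

print(WithDefault().vector_only)                  # False: default applied
print(WithDefault(vector_only=True).vector_only)  # True: explicit value

try:
    WithDefault(vector_only=None)  # an explicit null is now rejected
except ValidationError as err:
    print("None rejected:", err.errors()[0]["msg"])
```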
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index cf57095..26f42e5 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,12 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import os
import json
from typing import Literal
from fastapi import status, APIRouter, HTTPException
-from fastapi.responses import StreamingResponse
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
@@ -27,7 +25,6 @@ from hugegraph_llm.api.models.rag_requests import (
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest, GraphRAGRequest,
- LogStreamRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings
@@ -56,11 +53,21 @@ def rag_http_api(
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
- result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector,
- req.answer_prompt)
+ result = rag_answer_func(
+ req.query,
+ req.raw_answer,
+ req.vector_only,
+ req.graph_only,
+ req.graph_vector_answer,
+ req.graph_ratio,
+ req.rerank_method,
+ req.near_neighbor_first,
+ req.custom_priority_info,
+ req.answer_prompt
+ )
return {
key: value
- for key, value in zip(["raw_llm", "vector_only", "graph_only",
"graph_vector"], result)
+ for key, value in zip(["raw_answer", "vector_only", "graph_only",
"graph_vector_answer"], result)
if getattr(req, key)
}
@@ -71,7 +78,7 @@ def rag_http_api(
text=req.query,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
- custom_related_information=req.custom_related_information
+ custom_related_information=req.custom_priority_info
)
if isinstance(result, dict):
@@ -129,4 +136,3 @@ def rag_http_api(
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
-
\ No newline at end of file
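For illustration, a client call against the updated `/rag` endpoint might look like the sketch below; the host, port, and question text are assumptions, not values from this change:
```python
import requests  # standalone illustration, not a dependency of this patch

resp = requests.post(
    "http://127.0.0.1:8001/rag",  # hypothetical local deployment
    json={
        "query": "Tell me about HugeGraph.",
        "raw_answer": True,
        "graph_vector_answer": True,
    },
)
# The response contains only the answer types that were requested,
# keyed by the renamed fields (raw_answer, graph_vector_answer, ...).
print(resp.json())
```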
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 16cabf4..84a25b8 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -132,21 +132,26 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8001, help="port")
args = parser.parse_args()
app = FastAPI()
- api_auth = APIRouter(dependencies=[Depends(authenticate)])
settings.check_env()
prompt.update_yaml_file()
+ auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
+ log.info("(Status) Authentication is %s now.", "enabled" if auth_enabled
else "disabled")
+ api_auth = APIRouter(dependencies=[Depends(authenticate)] if auth_enabled
else [])
+
hugegraph_llm = init_rag_ui()
+
rag_http_api(api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config,
apply_reranker_config)
+
admin_http_api(api_auth, log_stream)
app.include_router(api_auth)
- auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
- log.info("(Status) Authentication is %s now.", "enabled" if auth_enabled
else "disabled")
+
# TODO: support multi-user login when need
- app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
+ app = gr.mount_gradio_app(app, hugegraph_llm, path="/",
+ auth=("rag", os.getenv("USER_TOKEN")) if
auth_enabled else None)
# TODO: we can't use reload now due to the config 'app' of uvicorn.run
# ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
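The conditional-auth pattern above hinges on `APIRouter` accepting an empty dependency list. A self-contained sketch, with a hypothetical `verify_token` standing in for the project's `authenticate` dependency:
```python
import os
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, status

def verify_token(token: str = Header(default="")):
    # Hypothetical check; the real authenticate() lives in this repo.
    if token != os.getenv("USER_TOKEN", ""):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
# With ENABLE_LOGIN unset or false, the dependency list is empty and
# the very same router serves requests without authentication.
api_auth = APIRouter(dependencies=[Depends(verify_token)] if auth_enabled else [])

app = FastAPI()
app.include_router(api_auth)
```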
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 546d561..0f585cb 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -57,12 +57,8 @@ class WordExtract:
keywords = self._filter_keywords(keywords, lowercase=False)
context["keywords"] = keywords
-
- verbose = context.get("verbose") or False
- if verbose:
- from hugegraph_llm.utils.log import log
- log.info("KEYWORDS: %s", context['keywords'])
-
+ from hugegraph_llm.utils.log import log
+ log.info("KEYWORDS: %s", context['keywords'])
return context
def _filter_keywords(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index d15a8c8..8ac2c9a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -173,7 +173,6 @@ class GraphRAGQuery:
# TODO: set color for ↓ "\033[93mKnowledge from Graph:\033[0m"
log.debug("Knowledge from Graph:")
log.debug("\n".join(context["graph_result"]))
-
return context
def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
index e845b61..8417a9a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
@@ -20,8 +20,9 @@ import os
from typing import Dict, Any
from hugegraph_llm.config import resource_path, settings
-from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.utils.log import log
class VectorIndexQuery:
@@ -37,9 +38,5 @@ class VectorIndexQuery:
results = self.vector_index.search(query_embedding, self.topk)
# TODO: check format results
context["vector_result"] = results
-
- verbose = context.get("verbose") or False
- if verbose:
- print("\033[93mKNOWLEDGE FROM VECTOR:")
- print("\n".join(rel for rel in context["vector_result"]) +
"\033[0m")
+ log.debug("KNOWLEDGE FROM VECTOR:\n%s", "\n".join(rel for rel in
context["vector_result"]))
return context
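The switch from a manual `verbose` flag plus `print` to `log.debug` moves the gating into the logging framework. A small standard-library sketch of the behavior (names are illustrative):
```python
import logging

logging.basicConfig(level=logging.INFO)  # DEBUG records are filtered out
log = logging.getLogger("example")

vector_result = ["rel-a", "rel-b"]
# The logger's level check replaces the old `verbose` flag: this call
# emits nothing at INFO level, and the %-style argument defers message
# interpolation until a handler actually formats the record.
log.debug("KNOWLEDGE FROM VECTOR:\n%s", "\n".join(vector_result))
```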
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 0f894cb..149b7ee 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -103,9 +103,6 @@ class AnswerSynthesize:
async def async_generate(self, context: Dict[str, Any], context_head_str: str,
context_tail_str: str, vector_result_context: str,
graph_result_context: str):
- # pylint: disable=R0912 (too-many-branches)
- verbose = context.get("verbose") or False
-
# async_tasks stores the async tasks for different answer types
async_tasks = {}
if self._raw_answer:
@@ -138,26 +135,18 @@ class AnswerSynthesize:
self._llm.agenerate(prompt=final_prompt)
)
- if async_tasks.get("raw_task"):
- response = await async_tasks["raw_task"]
- context["raw_answer"] = response
- if verbose:
- log.debug(f"ANSWER: {response}")
- if async_tasks.get("vector_only_task"):
- response = await async_tasks["vector_only_task"]
- context["vector_only_answer"] = response
- if verbose:
- log.debug(f"ANSWER: {response}")
- if async_tasks.get("graph_only_task"):
- response = await async_tasks["graph_only_task"]
- context["graph_only_answer"] = response
- if verbose:
- log.debug(f"ANSWER: {response}")
- if async_tasks.get("graph_vector_task"):
- response = await async_tasks["graph_vector_task"]
- context["graph_vector_answer"] = response
- if verbose:
- log.debug(f"ANSWER: {response}")
+ async_tasks_mapping = {
+ "raw_task": "raw_answer",
+ "vector_only_task": "vector_only_answer",
+ "graph_only_task": "graph_only_answer",
+ "graph_vector_task": "graph_vector_answer"
+ }
+
+ for task_key, context_key in async_tasks_mapping.items():
+ if async_tasks.get(task_key):
+ response = await async_tasks[task_key]
+ context[context_key] = response
+ log.debug("Query Answer: %s", response)
ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer])
context['call_count'] = context.get('call_count', 0) + ops
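The mapping-driven loop above collapses four duplicated await-and-store branches into one. A runnable sketch of the same pattern, with a fake coroutine standing in for `self._llm.agenerate(...)`:
```python
import asyncio

async def fake_llm(kind: str) -> str:
    # Illustrative coroutine; the real code awaits the LLM here.
    await asyncio.sleep(0)
    return f"{kind} answer"

async def main():
    # Only the requested answer types get a scheduled task.
    async_tasks = {"raw_task": asyncio.ensure_future(fake_llm("raw"))}
    mapping = {
        "raw_task": "raw_answer",
        "vector_only_task": "vector_only_answer",
        "graph_only_task": "graph_only_answer",
        "graph_vector_task": "graph_vector_answer",
    }
    context = {}
    # One loop replaces four near-identical if-blocks; tasks that were
    # never scheduled are skipped because dict.get returns None.
    for task_key, context_key in mapping.items():
        if async_tasks.get(task_key):
            context[context_key] = await async_tasks[task_key]
    print(context)  # {'raw_answer': 'raw answer'}

asyncio.run(main())
```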
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index cdfa6b5..828e394 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -22,6 +22,7 @@ from typing import Set, Dict, Any, Optional
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
+from hugegraph_llm.utils.log import log
KEYWORDS_EXTRACT_TPL = """Extract {max_keywords} keywords from the text:
{question}
@@ -85,11 +86,7 @@ class KeywordExtract:
keywords.union(self._expand_synonyms(keywords=keywords))
keywords = {k.replace("'", "") for k in keywords}
context["keywords"] = list(keywords)
-
- verbose = context.get("verbose") or False
- if verbose:
- from hugegraph_llm.utils.log import log
- log.info("KEYWORDS: %s", context['keywords'])
+ log.info("User Query: %s\nKeywords: %s", self._query,
context["keywords"])
# extracting keywords & expanding synonyms increase the call count by 2
context["call_count"] = context.get("call_count", 0) + 2