This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch graphspace
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit 794c963f4434b332a3afe4a0e5cc5c843fffbaa9
Author: imbajin <[email protected]>
AuthorDate: Thu Aug 15 11:35:55 2024 +0800

    feat: support basic auth in llm server (WIP)
---
 .gitignore                                         |  1 +
 .../src/hugegraph_llm/demo/rag_web_demo.py         | 56 ++++++++++++++--------
 .../operators/hugegraph_op/graph_rag_query.py      |  5 ++
 3 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/.gitignore b/.gitignore
index de24191..bb75d10 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+.DS_Store
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 59edeb1..1230cc9 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -16,41 +16,50 @@
 #  under the License.
 
-import json
 import argparse
+import json
 import os
 from typing import Optional
 
-import requests
-import uvicorn
 import docx
 import gradio as gr
-from fastapi import FastAPI
+import requests
+import uvicorn
+from dotenv import load_dotenv
+from fastapi import FastAPI, Depends
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
 from requests.auth import HTTPBasicAuth
 
-from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.config import settings, resource_path
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.operators.graph_rag_task import GraphRAG
 from hugegraph_llm.operators.kg_construction_task import KgBuilder
-from hugegraph_llm.config import settings, resource_path
 from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
+from hugegraph_llm.utils.hugegraph_utils import get_hg_client
 from hugegraph_llm.utils.hugegraph_utils import (
     init_hg_test_data,
     run_gremlin_query,
     clean_hg_data
 )
 from hugegraph_llm.utils.log import log
-from hugegraph_llm.utils.hugegraph_utils import get_hg_client
 from hugegraph_llm.utils.vector_index_utils import clean_vector_index
 
+load_dotenv()
+
+sec = HTTPBearer()
 
-def convert_bool_str(string):
-    if string == "true":
-        return True
-    if string == "false":
-        return False
-    raise gr.Error(f"Invalid boolean string: {string}")
+
+def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
+    correct_token = os.getenv("TOKEN")
+    if credentials.credentials != correct_token:
+        from fastapi import HTTPException
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid token, please contact the admin",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
 
 
 # TODO: enhance/distinguish the "graph_rag" name to avoid confusion
@@ -59,9 +68,10 @@ def graph_rag(text: str, raw_answer: bool, vector_only_answer: bool,
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
 
-    if raw_answer == False and not vector_search and not graph_search:
+    if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
+
     searcher = GraphRAG()
     if vector_search:
         searcher.query_vector_index_for_rag()
@@ -159,7 +169,12 @@ if __name__ == "__main__":
     args = parser.parse_args()
     app = FastAPI()
-    with gr.Blocks() as hugegraph_llm:
+    with gr.Blocks(
+        theme='default',
+        title="HugeGraph RAG Platform",
+        css="footer {visibility: hidden}"
+    ) as hugegraph_llm:
+
         gr.Markdown(
             """# HugeGraph LLM RAG Demo
         1. Set up the HugeGraph server."""
         )
@@ -171,7 +186,7 @@ if __name__ == "__main__":
             gr.Textbox(value=settings.graph_name, label="graph"),
             gr.Textbox(value=settings.graph_user, label="user"),
            gr.Textbox(value=settings.graph_pwd, label="pwd", type="password"),
-            gr.Textbox(value=settings.graph_space, label="graphspace (None)"),
+            gr.Textbox(value=settings.graph_space, label="graphspace (Optional)"),
         ]
         graph_config_button = gr.Button("apply configuration")
 
@@ -455,25 +470,26 @@ if __name__ == "__main__":
 
     @app.get("/rag/{query}")
-    def graph_rag_api(query: str):
+    def graph_rag_api(query: str, _: HTTPAuthorizationCredentials = Depends(authenticate)):
         result = graph_rag(query, True, True, True, True)
         return {"raw_answer": result[0], "vector_only_answer": result[1],
                 "graph_only_answer": result[2], "graph_vector_answer": result[3]}
 
     @app.post("/rag")
-    def graph_rag_api(req: RAGRequest):
+    def graph_rag_api(req: RAGRequest, _: HTTPAuthorizationCredentials = Depends(authenticate)):
         result = graph_rag(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
         return {key: value for key, value in zip(
             ["raw_llm", "vector_only", "graph_only", "graph_vector"], result) if getattr(req, key)}
 
     @app.get("/rag/graph/{query}")
-    def graph_rag_api(query: str):
+    def graph_rag_api(query: str, credentials: HTTPAuthorizationCredentials = Depends(authenticate)):
         result = graph_rag(query, False, False, True, False)
-        log.debug(result)
+        log.debug(f'credentials: {credentials}')
         return {"graph_only_answer": result[2]}
 
+    hugegraph_llm.launch(share=False, auth=("rag", os.getenv("TOKEN")), server_name="0.0.0.0", server_port=8001)
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
 
     # Note: set reload to False in production environment
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index ed61eff..5c9fdae 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -27,6 +27,11 @@ class GraphRAGQuery:
     VERTEX_GREMLIN_QUERY_TEMPL = (
         "g.V().hasId({keywords}).as('subj').toList()"
     )
+    # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
+    # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
+    # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
+    # ).by(unfold()))).limit({max_items}).toList()"
+    # TODO: we could use a simpler query (like kneighbor-api to get the edges)
     ID_RAG_GREMLIN_QUERY_TEMPL = """
         g.V().hasId({keywords}).as('subj')
