This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 4f6414d feat(llm): support basic rag api(v1) (#64)
4f6414d is described below
commit 4f6414d420b94fd539957a02751f43463d248097
Author: chenzihong <[email protected]>
AuthorDate: Mon Aug 19 02:24:53 2024 +0800
feat(llm): support basic rag api(v1) (#64)
* refact: support test wenxin/ollama conn
---------
Co-authored-by: Hongjun Li <[email protected]>
Co-authored-by: imbajin <[email protected]>
---
.gitignore | 1 +
.../api/{rag_api.py => exceptions/__init__.py} | 0
.../{rag_api.py => exceptions/rag_exceptions.py} | 21 ++
.../api/{rag_api.py => models/__init__.py} | 0
.../src/hugegraph_llm/api/models/rag_requests.py | 52 +++
.../api/{rag_api.py => models/rag_response.py} | 7 +
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 50 +++
hugegraph-llm/src/hugegraph_llm/config/config.py | 2 +-
.../src/hugegraph_llm/demo/rag_web_demo.py | 383 +++++++++++----------
.../{api/rag_api.py => enums/build_mode.py} | 11 +
10 files changed, 350 insertions(+), 177 deletions(-)
diff --git a/.gitignore b/.gitignore
index de24191..786b5e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+*.DS_Store
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/__init__.py
similarity index 100%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/exceptions/__init__.py
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
similarity index 50%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
index 13a8339..24ef7c1 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
@@ -14,3 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from fastapi import HTTPException
+from hugegraph_llm.api.models.rag_response import RAGResponse
+
+
+class ExternalException(HTTPException):
+ def __init__(self):
+ super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.")
+
+
+class ConnectionFailedException(HTTPException):
+ def __init__(self, status_code: int, message: str):
+ super().__init__(status_code=status_code, detail=message)
+
+
+def generate_response(response: RAGResponse) -> dict:
+ if response.status_code == -1:
+ raise ExternalException()
+ elif not (200 <= response.status_code < 300):
+ raise ConnectionFailedException(response.status_code, response.message)
+ return {"message": "Connection successful. Configured finished."}
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/models/__init__.py
similarity index 100%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/models/__init__.py
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
new file mode 100644
index 0000000..d12a1b8
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pydantic import BaseModel
+from typing import Optional
+
+
+class RAGRequest(BaseModel):
+ query: str
+ raw_llm: Optional[bool] = True
+ vector_only: Optional[bool] = False
+ graph_only: Optional[bool] = False
+ graph_vector: Optional[bool] = False
+
+
+class GraphConfigRequest(BaseModel):
+ ip: str = "127.0.0.1"
+ port: str = "8080"
+ name: str = "hugegraph"
+ user: str = "xxx"
+ pwd: str = "xxx"
+ gs: Optional[str] = None
+
+
+class LLMConfigRequest(BaseModel):
+ llm_type: str
+ # The common parameters shared by OpenAI, Qianfan Wenxin,
+ # and OLLAMA platforms.
+ api_key: str
+ api_base: str
+ language_model: str
+ # OpenAI-only properties
+ max_tokens: Optional[str] = None
+ # qianfan-wenxin-only properties
+ secret_key: Optional[str] = None
+ # ollama-only properties
+ host: Optional[str] = None
+ port: Optional[str] = None
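
As a hypothetical usage sketch (not part of this commit), the Pydantic models above validate the JSON bodies of the endpoints added in rag_api.py below; unset optional flags fall back to their declared defaults:

    from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest

    req = RAGRequest(query="Tell me about Al Pacino.")
    print(req.raw_llm, req.vector_only)  # True False (defaults)

    conf = GraphConfigRequest()  # every field has a default placeholder
    print(conf.ip, conf.port, conf.name)  # 127.0.0.1 8080 hugegraph
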
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
similarity index 87%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
index 13a8339..fe139ee 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
@@ -14,3 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from pydantic import BaseModel
+
+
+class RAGResponse(BaseModel):
+ status_code: int = -1
+ message: str = ""
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 13a8339..a9c834c 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,3 +14,53 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from fastapi import FastAPI, status
+
+from hugegraph_llm.api.models.rag_response import RAGResponse
+from hugegraph_llm.config import settings
+from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest, LLMConfigRequest
+from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
+
+
+def rag_http_api(app: FastAPI, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
+ @app.post("/rag", status_code=status.HTTP_200_OK)
+ def rag_answer_api(req: RAGRequest):
+ result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
+ return {
+ key: value
+ for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result)
+ if getattr(req, key)
+ }
+
+ @app.post("/config/graph", status_code=status.HTTP_201_CREATED)
+ def graph_config_api(req: GraphConfigRequest):
+ # Accept status code
+ res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
+ return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+ @app.post("/config/llm", status_code=status.HTTP_201_CREATED)
+ def llm_config_api(req: LLMConfigRequest):
+ settings.llm_type = req.llm_type
+
+ if req.llm_type == "openai":
+ res = apply_llm_conf(
+ req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http"
+ )
+ elif req.llm_type == "qianfan_wenxin":
+ res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
+ else:
+ res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
+ return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+ @app.post("/config/embedding", status_code=status.HTTP_201_CREATED)
+ def embedding_config_api(req: LLMConfigRequest):
+ settings.embedding_type = req.llm_type
+
+ if req.llm_type == "openai":
+ res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
+ elif req.llm_type == "qianfan_wenxin":
+ res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
+ else:
+ res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
+ return generate_response(RAGResponse(status_code=res, message="Missing Value"))
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index 62d41d4..3659cc1 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -67,7 +67,7 @@ class Config:
"""HugeGraph settings"""
graph_ip: Optional[str] = "127.0.0.1"
graph_port: Optional[int] = 8080
- # graph_space: Optional[str] = "DEFAULT"
+ graph_space: Optional[str] = None
graph_name: Optional[str] = "hugegraph"
graph_user: Optional[str] = "admin"
graph_pwd: Optional[str] = "xxx"
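
The newly enabled graph_space setting feeds the connection check added in rag_web_demo.py below: when a graphspace is set, the schema URL gains a /graphspaces/{gs} segment. A minimal sketch of that selection logic (hypothetical helper name, mirroring apply_graph_config further down):

    def schema_url(ip: str, port: str, name: str, gs: str = None) -> str:
        # With a graphspace, HugeGraph exposes the graph under /graphspaces/{gs}/
        if gs and gs.strip():
            return f"http://{ip}:{port}/graphspaces/{gs}/graphs/{name}/schema"
        return f"http://{ip}:{port}/graphs/{name}/schema"

    print(schema_url("127.0.0.1", "8080", "hugegraph"))             # default layout
    print(schema_url("127.0.0.1", "8080", "hugegraph", "DEFAULT"))  # graphspace layout
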
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 01bc85c..756cb1c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -16,47 +16,38 @@
# under the License.
-import json
import argparse
+import json
import os
-import requests
-import uvicorn
import docx
import gradio as gr
+import requests
+import uvicorn
from fastapi import FastAPI
+from requests.auth import HTTPBasicAuth
-from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.api.rag_api import rag_http_api
+from hugegraph_llm.config import settings, resource_path
+from hugegraph_llm.enums.build_mode import BuildMode
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.graph_rag_task import GraphRAG
from hugegraph_llm.operators.kg_construction_task import KgBuilder
-from hugegraph_llm.config import settings, resource_path
from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
-from hugegraph_llm.utils.hugegraph_utils import (
- init_hg_test_data,
- run_gremlin_query,
- clean_hg_data
-)
-from hugegraph_llm.utils.log import log
from hugegraph_llm.utils.hugegraph_utils import get_hg_client
+from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, clean_hg_data
+from hugegraph_llm.utils.log import log
from hugegraph_llm.utils.vector_index_utils import clean_vector_index
-def convert_bool_str(string):
- if string == "true":
- return True
- if string == "false":
- return False
- raise gr.Error(f"Invalid boolean string: {string}")
-
+def rag_answer(
+ text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool
+) -> tuple:
+ vector_search = vector_only_answer or graph_vector_answer
+ graph_search = graph_only_answer or graph_vector_answer
-# TODO: enhance/distinguish the "graph_rag" name to avoid confusion
-def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
- graph_only_answer: str, graph_vector_answer):
- vector_search = convert_bool_str(vector_only_answer) or convert_bool_str(graph_vector_answer)
- graph_search = convert_bool_str(graph_only_answer) or convert_bool_str(graph_vector_answer)
-
- if raw_answer == "false" and not vector_search and not graph_search:
+ if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
searcher = GraphRAG()
@@ -65,10 +56,10 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
if graph_search:
searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
searcher.merge_dedup_rerank().synthesize_answer(
- raw_answer=convert_bool_str(raw_answer),
- vector_only_answer=convert_bool_str(vector_only_answer),
- graph_only_answer=convert_bool_str(graph_only_answer),
- graph_vector_answer=convert_bool_str(graph_vector_answer)
+ raw_answer=raw_answer,
+ vector_only_answer=vector_only_answer,
+ graph_only_answer=graph_only_answer,
+ graph_vector_answer=graph_vector_answer,
).run(verbose=True, query=text)
try:
@@ -77,7 +68,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
context.get("raw_answer", ""),
context.get("vector_only_answer", ""),
context.get("graph_only_answer", ""),
- context.get("graph_vector_answer", "")
+ context.get("graph_vector_answer", ""),
)
except ValueError as e:
log.error(e)
@@ -87,7 +78,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
raise gr.Error(f"An unexpected error occurred: {str(e)}")
-def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-many-branches
+def build_kg(file, schema, example_prompt, build_mode) -> str: # pylint: disable=too-many-branches
full_path = file.name
if full_path.endswith(".txt"):
with open(full_path, "r", encoding="utf-8") as f:
@@ -99,12 +90,13 @@ def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-m
text += para.text
text += "\n"
elif full_path.endswith(".pdf"):
+ # TODO: support PDF file
raise gr.Error("PDF will be supported later! Try to upload text/docx
now")
else:
raise gr.Error("Please input txt or docx file.")
builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
- if build_mode != "Rebuild vertex index":
+ if build_mode != BuildMode.REBUILD_VERTEX_INDEX.value:
if schema:
try:
schema = json.loads(schema.strip())
@@ -116,39 +108,146 @@ def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-m
return "ERROR: please input schema."
builder.chunk_split(text, "paragraph", "zh")
- # TODO: avoid hardcoding the "build_mode" strings (use var/constant instead)
- if build_mode == "Rebuild Vector":
+ if build_mode == BuildMode.REBUILD_VECTOR.value:
builder.fetch_graph_data()
else:
builder.extract_info(example_prompt, "property_graph")
# "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
- if build_mode != "Test Mode":
- if build_mode in ("Clear and Import", "Rebuild Vector"):
+ if build_mode != BuildMode.TEST_MODE.value:
+ if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.REBUILD_VECTOR.value):
clean_vector_index()
builder.build_vector_index()
- if build_mode == "Clear and Import":
+ if build_mode == BuildMode.CLEAR_AND_IMPORT.value:
clean_hg_data()
- if build_mode in ("Clear and Import", "Import Mode"):
+ if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.IMPORT_MODE.value):
builder.commit_to_hugegraph()
- if build_mode != "Test Mode":
+ if build_mode != BuildMode.TEST_MODE.value:
builder.build_vertex_id_semantic_index()
log.debug(builder.operators)
try:
context = builder.run()
- return context
+ return str(context)
except Exception as e: # pylint: disable=broad-exception-caught
log.error(e)
raise gr.Error(str(e))
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--host", type=str, default="0.0.0.0", help="host")
- parser.add_argument("--port", type=int, default=8001, help="port")
- args = parser.parse_args()
- app = FastAPI()
-
- with gr.Blocks() as hugegraph_llm:
+def test_api_connection(url, method="GET",
+ headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
+ # TODO: use fastapi.request / starlette instead?
+ log.debug("Request URL: %s", url)
+ try:
+ if method.upper() == "GET":
+ resp = requests.get(url, headers=headers, params=params, timeout=5, auth=auth)
+ elif method.upper() == "POST":
+ resp = requests.post(url, headers=headers, params=params, json=body, timeout=5, auth=auth)
+ else:
+ raise ValueError("Unsupported HTTP method, please use GET/POST instead")
+ except requests.exceptions.RequestException as e:
+ msg = f"Connection failed: {e}"
+ log.error(msg)
+ if origin_call is None:
+ raise gr.Error(msg)
+ return -1 # Error code
+
+ if 200 <= resp.status_code < 300:
+ msg = "Test connection successful~"
+ log.info(msg)
+ gr.Info(msg)
+ else:
+ msg = f"Connection failed with status code: {resp.status_code}, error:
{resp.text}"
+ log.error(msg)
+ # TODO: Only the message returned by rag can be processed, and the other return values can't be processed
+ if origin_call is None:
+ raise gr.Error(json.loads(resp.text).get("message", msg))
+ return resp.status_code
+
+
+def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int:
+ settings.qianfan_api_key = arg1
+ settings.qianfan_secret_key = arg2
+ settings.qianfan_language_model = arg3
+ params = {
+ "grant_type": "client_credentials",
+ "client_id": arg1,
+ "client_secret": arg2
+ }
+ status_code = test_api_connection("https://aip.baidubce.com/oauth/2.0/token", "POST", params=params,
+ origin_call=origin_call)
+ return status_code
+
+
+def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
+ status_code = -1
+ embedding_option = settings.embedding_type
+ if embedding_option == "openai":
+ settings.openai_api_key = arg1
+ settings.openai_api_base = arg2
+ settings.openai_embedding_model = arg3
+ test_url = settings.openai_api_base + "/models"
+ headers = {"Authorization": f"Bearer {arg1}"}
+ status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call)
+ elif embedding_option == "qianfan_wenxin":
+ status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call)
+ settings.qianfan_embedding_model = arg3
+ elif embedding_option == "ollama":
+ settings.ollama_host = arg1
+ settings.ollama_port = int(arg2)
+ settings.ollama_embedding_model = arg3
+ # TODO: right way to test ollama conn?
+ status_code = test_api_connection(f"http://{arg1}:{arg2}/status",
origin_call=origin_call)
+ settings.update_env()
+ gr.Info("Configured!")
+ return status_code
+
+
+def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int:
+ settings.graph_ip = ip
+ settings.graph_port = int(port)
+ settings.graph_name = name
+ settings.graph_user = user
+ settings.graph_pwd = pwd
+ settings.graph_space = gs
+ # Test graph connection (Auth)
+ if gs and gs.strip():
+ test_url = f"http://{ip}:{port}/graphspaces/{gs}/graphs/{name}/schema"
+ else:
+ test_url = f"http://{ip}:{port}/graphs/{name}/schema"
+ auth = HTTPBasicAuth(user, pwd)
+ # for http api return status
+ response = test_api_connection(test_url, auth=auth, origin_call=origin_call)
+ settings.update_env()
+ return response
+
+
+# Different llm models have different parameters,
+# so no meaningful argument names are given here
+def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int:
+ llm_option = settings.llm_type
+ status_code = -1
+ if llm_option == "openai":
+ settings.openai_api_key = arg1
+ settings.openai_api_base = arg2
+ settings.openai_language_model = arg3
+ settings.openai_max_tokens = int(arg4)
+ test_url = settings.openai_api_base + "/models"
+ headers = {"Authorization": f"Bearer {arg1}"}
+ status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call)
+ elif llm_option == "qianfan_wenxin":
+ status_code = config_qianfan_model(arg1, arg2, arg3, origin_call)
+ elif llm_option == "ollama":
+ settings.ollama_host = arg1
+ settings.ollama_port = int(arg2)
+ settings.ollama_language_model = arg3
+ # TODO: right way to test ollama conn?
+ status_code = test_api_connection(f"http://{arg1}:{arg2}/status",
origin_call=origin_call)
+ gr.Info("Configured!")
+ settings.update_env()
+ return status_code
+
+
+def init_rag_ui() -> gr.Interface:
+ with gr.Blocks() as hugegraph_llm_ui:
gr.Markdown(
"""# HugeGraph LLM RAG Demo
1. Set up the HugeGraph server."""
@@ -159,51 +258,17 @@ if __name__ == "__main__":
gr.Textbox(value=str(settings.graph_port), label="port"),
gr.Textbox(value=settings.graph_name, label="graph"),
gr.Textbox(value=settings.graph_user, label="user"),
- gr.Textbox(value=settings.graph_pwd, label="pwd")
+ gr.Textbox(value=settings.graph_pwd, label="pwd", type="password"),
+ # gr.Textbox(value=settings.graph_space, label="graphspace (None)"),
+ # wip: graph_space issue pending
+ gr.Textbox(value="", label="graphspace (None)"),
]
graph_config_button = gr.Button("apply configuration")
-
- def test_api_connection(url, method="GET", ak=None, sk=None, headers=None, body=None):
- # TODO: use fastapi.request / starlette instead? (Also add a try-catch here)
- log.debug("Request URL: %s", url)
- if method.upper() == "GET":
- response = requests.get(url, headers=headers, timeout=5)
- elif method.upper() == "POST":
- response = requests.post(url, headers=headers, json=body, timeout=5)
- else:
- log.error("Unsupported method: %s", method)
- return
-
- if 200 <= response.status_code < 300:
- log.info("Connection successful. Configured finished.")
- gr.Info("Connection successful. Configured finished.")
- else:
- log.error("Connection failed with status code: %s",
response.status_code)
- # pylint: disable=pointless-exception-statement
- gr.Error(f"Connection failed with status code:
{response.status_code}")
-
-
- def apply_graph_configuration(ip, port, name, user, pwd):
- settings.graph_ip = ip
- settings.graph_port = int(port)
- settings.graph_name = name
- settings.graph_user = user
- settings.graph_pwd = pwd
- test_url = f"http://{ip}:{port}/graphs/{name}/schema"
- test_api_connection(test_url)
- settings.update_env()
-
-
- graph_config_button.click(apply_graph_configuration, inputs=graph_config_input) # pylint: disable=no-member
+ graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member
gr.Markdown("2. Set up the LLM.")
- llm_dropdown = gr.Dropdown(
- choices=["openai", "qianfan_wenxin", "ollama"],
- value=settings.llm_type,
- label="LLM"
- )
-
+ llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM")
@gr.render(inputs=[llm_dropdown])
def llm_settings(llm_type):
@@ -222,58 +287,28 @@ if __name__ == "__main__":
gr.Textbox(value=settings.ollama_host, label="host"),
gr.Textbox(value=str(settings.ollama_port), label="port"),
gr.Textbox(value=settings.ollama_language_model, label="model_name"),
- gr.Textbox(value="", visible=False)
+ gr.Textbox(value="", visible=False),
]
elif llm_type == "qianfan_wenxin":
with gr.Row():
llm_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key",
- type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key",
- type="password"),
+ gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
- gr.Textbox(value="", visible=False)
+ gr.Textbox(value="", visible=False),
]
log.debug(llm_config_input)
else:
llm_config_input = []
llm_config_button = gr.Button("apply configuration")
- def apply_llm_configuration(arg1, arg2, arg3, arg4):
- llm_option = settings.llm_type
-
- if llm_option == "openai":
- settings.openai_api_key = arg1
- settings.openai_api_base = arg2
- settings.openai_language_model = arg3
- settings.openai_max_tokens = int(arg4)
- test_url = settings.openai_api_base + "/models"
- headers = {"Authorization": f"Bearer {arg1}"}
- test_api_connection(test_url, headers=headers, ak=arg1)
- elif llm_option == "qianfan_wenxin":
- settings.qianfan_api_key = arg1
- settings.qianfan_secret_key = arg2
- settings.qianfan_language_model = arg3
- # TODO: test the connection
- # test_url = "https://aip.baidubce.com/oauth/2.0/token" # POST
- elif llm_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
- settings.ollama_language_model = arg3
- gr.Info("configured!")
- settings.update_env()
-
- llm_config_button.click(apply_llm_configuration, inputs=llm_config_input) # pylint: disable=no-member
-
+ llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member
gr.Markdown("3. Set up the Embedding.")
embedding_dropdown = gr.Dropdown(
- choices=["openai", "ollama", "qianfan_wenxin"],
- value=settings.embedding_type,
- label="Embedding"
+ choices=["openai", "qianfan_wenxin", "ollama"],
value=settings.embedding_type, label="Embedding"
)
-
@gr.render(inputs=[embedding_dropdown])
def embedding_settings(embedding_type):
settings.embedding_type = embedding_type
@@ -282,15 +317,13 @@ if __name__ == "__main__":
embedding_config_input = [
gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
gr.Textbox(value=settings.openai_api_base, label="api_base"),
- gr.Textbox(value=settings.openai_embedding_model, label="model_name")
+ gr.Textbox(value=settings.openai_embedding_model, label="model_name"),
]
elif embedding_type == "qianfan_wenxin":
with gr.Row():
embedding_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key",
- type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key",
- type="password"),
+ gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"),
]
elif embedding_type == "ollama":
@@ -302,31 +335,13 @@ if __name__ == "__main__":
]
else:
embedding_config_input = []
- embedding_config_button = gr.Button("apply configuration")
- def apply_embedding_configuration(arg1, arg2, arg3):
- embedding_option = settings.embedding_type
- if embedding_option == "openai":
- settings.openai_api_key = arg1
- settings.openai_api_base = arg2
- settings.openai_embedding_model = arg3
- test_url = settings.openai_api_base + "/models"
- headers = {"Authorization": f"Bearer {arg1}"}
- test_api_connection(test_url, headers=headers, ak=arg1)
- elif embedding_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
- settings.ollama_embedding_model = arg3
- elif embedding_option == "qianfan_wenxin":
- settings.qianfan_access_token = arg1
- settings.qianfan_embed_url = arg2
- settings.update_env()
-
- gr.Info("configured!")
-
- embedding_config_button.click(apply_embedding_configuration, # pylint: disable=no-member
- inputs=embedding_config_input)
+ embedding_config_button = gr.Button("apply configuration")
+ # Call the separate apply_embedding_configuration function here
+ embedding_config_button.click(
+ apply_embedding_config, inputs=embedding_config_input # pylint: disable=no-member
+ )
gr.Markdown(
"""## 1. Build vector/graph RAG (💡)
@@ -344,7 +359,7 @@ if __name__ == "__main__":
"""
)
- SCHEMA = """{
+ schema = """{
"vertexlabels": [
{
"id":1,
@@ -380,21 +395,20 @@ if __name__ == "__main__":
}"""
with gr.Row():
- input_file = gr.File(value=os.path.join(resource_path, "demo", "test.txt"),
- label="Document")
- input_schema = gr.Textbox(value=SCHEMA, label="Schema")
- info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT,
- label="Info extract head")
+ input_file = gr.File(value=os.path.join(resource_path, "demo", "test.txt"), label="Document")
+ input_schema = gr.Textbox(value=schema, label="Schema")
+ info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
with gr.Column():
- mode = gr.Radio(choices=["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"],
- value="Test Mode", label="Build mode")
+ mode = gr.Radio(
+ choices=["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"],
+ value="Test Mode",
+ label="Build mode",
+ )
btn = gr.Button("Build Vector/Graph RAG")
with gr.Row():
out = gr.Textbox(label="Output", show_copy_button=True)
btn.click( # pylint: disable=no-member
- fn=build_kg,
- inputs=[input_file, input_schema, info_extract_template, mode],
- outputs=out
+ fn=build_kg, inputs=[input_file, input_schema, info_extract_template, mode], outputs=out
)
gr.Markdown("""## 2. RAG with HugeGraph 📖""")
@@ -406,33 +420,50 @@ if __name__ == "__main__":
graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True)
graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)
with gr.Column(scale=1):
- raw_radio = gr.Radio(choices=["true", "false"], value="false",
- label="Basic LLM Answer")
- vector_only_radio = gr.Radio(choices=["true", "false"], value="true",
- label="Vector-only Answer")
- graph_only_radio = gr.Radio(choices=["true", "false"], value="false",
- label="Graph-only Answer")
- graph_vector_radio = gr.Radio(choices=["true", "false"], value="false",
- label="Graph-Vector Answer")
+ raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer")
+ vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
+ graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
+ graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
btn = gr.Button("Answer Question")
- btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member
- graph_vector_radio],
- outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out])
+ btn.click(
+ fn=rag_answer,
+ inputs=[
+ inp,
+ raw_radio,
+ vector_only_radio,
+ graph_only_radio, # pylint: disable=no-member
+ graph_vector_radio,
+ ],
+ outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
+ )
gr.Markdown("""## 3. Others (🚧) """)
with gr.Row():
with gr.Column():
inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True)
- format = gr.Checkbox(label="Format JSON", value=True)
+ fmt = gr.Checkbox(label="Format JSON", value=True)
out = gr.Textbox(label="Output", show_copy_button=True)
btn = gr.Button("Run gremlin query on HugeGraph")
- btn.click(fn=run_gremlin_query, inputs=[inp, format], outputs=out) # pylint: disable=no-member
+ btn.click(fn=run_gremlin_query, inputs=[inp, fmt], outputs=out) # pylint: disable=no-member
with gr.Row():
inp = []
out = gr.Textbox(label="Output", show_copy_button=True)
btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)")
btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member
+ return hugegraph_llm_ui
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="host")
+ parser.add_argument("--port", type=int, default=8001, help="port")
+ args = parser.parse_args()
+ app = FastAPI()
+
+ hugegraph_llm = init_rag_ui()
+
+ rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config)
app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
# Note: set reload to False in production environment
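
The hunk above is cut off before the server actually starts; a minimal sketch of the usual tail for such a setup (an assumption for illustration, not the commit's exact code):

    # Hypothetical continuation: serve the FastAPI app that the Gradio UI was mounted onto.
    uvicorn.run(app, host=args.host, port=args.port)
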
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
similarity index 76%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
index 13a8339..50db4c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
@@ -14,3 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+
+from enum import Enum
+
+
+class BuildMode(Enum):
+ REBUILD_VECTOR = "Rebuild Vector"
+ TEST_MODE = "Test Mode"
+ IMPORT_MODE = "Import Mode"
+ CLEAR_AND_IMPORT = "Clear and Import"
+ REBUILD_VERTEX_INDEX = "Rebuild vertex index"
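
For completeness, a small hypothetical check (illustrative only) showing how the enum values line up with the Gradio radio choices that build_kg receives:

    from hugegraph_llm.enums.build_mode import BuildMode

    ui_choices = ["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"]
    # Every UI label should round-trip through the enum used in build_kg.
    assert all(BuildMode(choice).value == choice for choice in ui_choices)
    print([mode.name for mode in BuildMode])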