This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 265d43c feat(llm): separate multi llm configs/models (#112)
265d43c is described below
commit 265d43c6ac83810ed31a2c5af3431644c85d8f5c
Author: HaoJin Yang <[email protected]>
AuthorDate: Thu Nov 21 17:42:20 2024 +0800
feat(llm): separate multi llm configs/models (#112)
* update qianfan default model
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/README.md | 2 +-
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 1 +
.../src/hugegraph_llm/config/config_data.py | 66 +++++--
.../demo/gremlin_generate_web_demo.py | 76 ++++-----
.../hugegraph_llm/demo/rag_demo/configs_block.py | 190 +++++++++++++++------
.../models/embeddings/init_embedding.py | 10 +-
.../src/hugegraph_llm/models/llms/init_llm.py | 80 +++++++--
.../src/hugegraph_llm/models/llms/ollama.py | 2 +-
.../operators/document_op/word_extract.py | 2 +-
.../src/hugegraph_llm/operators/graph_rag_task.py | 4 +-
.../operators/llm_op/answer_synthesize.py | 2 +-
.../operators/llm_op/keyword_extract.py | 2 +-
.../src/hugegraph_llm/utils/graph_index_utils.py | 8 +-
.../src/hugegraph_llm/utils/vector_index_utils.py | 2 +-
14 files changed, 306 insertions(+), 141 deletions(-)
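In short, this change splits the single LLMs().get_llm() entry point into one client per task (chat, extract, text2gql), each driven by its own *_llm_type setting. A minimal usage sketch, based only on the new init_llm.py shown below (the surrounding config setup is assumed):

    from hugegraph_llm.models.llms.init_llm import LLMs

    llms = LLMs()
    chat_llm = llms.get_chat_llm()          # answer synthesis / RAG chat
    extract_llm = llms.get_extract_llm()    # keyword & info extraction
    text2gql_llm = llms.get_text2gql_llm()  # Gremlin (text2gql) generation

Each getter reads its own settings block, so the three tasks can point at different providers or models.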
diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index ae07627..a8a4957 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -98,7 +98,7 @@ This can be obtained from the `LLMs` class.
from hugegraph_llm.operators.kg_construction_task import KgBuilder
TEXT = ""
- builder = KgBuilder(LLMs().get_llm())
+ builder = KgBuilder(LLMs().get_chat_llm())
(
builder
.import_schema(from_hugegraph="talent_graph").print_result()
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index e5c4375..4db2116 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -101,6 +101,7 @@ def rag_http_api(
res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+ # TODO: restructure the implementation of llm into three types, like "/config/chat_llm"
@router.post("/config/llm", status_code=status.HTTP_201_CREATED)
def llm_config_api(req: LLMConfigRequest):
settings.llm_type = req.llm_type
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config_data.py b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
index 52b41dd..3e5711c 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
@@ -26,42 +26,76 @@ class ConfigData:
"""LLM settings"""
# env_path: Optional[str] = ".env"
- llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
- embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
+ chat_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
+ extract_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
+ text2gql_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
+ embedding_type: Optional[Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"]] = "openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
# 1. OpenAI settings
- openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
- openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
- openai_language_model: Optional[str] = "gpt-4o-mini"
+ openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
+ openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+ openai_chat_language_model: Optional[str] = "gpt-4o-mini"
+ openai_extract_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
+ openai_extract_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+ openai_extract_language_model: Optional[str] = "gpt-4o-mini"
+ openai_text2gql_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
+ openai_text2gql_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+ openai_text2gql_language_model: Optional[str] = "gpt-4o-mini"
openai_embedding_api_base: Optional[str] = os.environ.get("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
openai_embedding_api_key: Optional[str] = os.environ.get("OPENAI_EMBEDDING_API_KEY")
openai_embedding_model: Optional[str] = "text-embedding-3-small"
- openai_max_tokens: int = 4096
+ openai_chat_tokens: int = 4096
+ openai_extract_tokens: int = 4096
+ openai_text2gql_tokens: int = 4096
# 2. Rerank settings
cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
reranker_api_key: Optional[str] = None
reranker_model: Optional[str] = None
# 3. Ollama settings
- ollama_host: Optional[str] = "127.0.0.1"
- ollama_port: Optional[int] = 11434
- ollama_language_model: Optional[str] = None
+ ollama_chat_host: Optional[str] = "127.0.0.1"
+ ollama_chat_port: Optional[int] = 11434
+ ollama_chat_language_model: Optional[str] = None
+ ollama_extract_host: Optional[str] = "127.0.0.1"
+ ollama_extract_port: Optional[int] = 11434
+ ollama_extract_language_model: Optional[str] = None
+ ollama_text2gql_host: Optional[str] = "127.0.0.1"
+ ollama_text2gql_port: Optional[int] = 11434
+ ollama_text2gql_language_model: Optional[str] = None
+ ollama_embedding_host: Optional[str] = "127.0.0.1"
+ ollama_embedding_port: Optional[int] = 11434
ollama_embedding_model: Optional[str] = None
# 4. QianFan/WenXin settings
- qianfan_api_key: Optional[str] = None
- qianfan_secret_key: Optional[str] = None
- qianfan_access_token: Optional[str] = None
+ qianfan_chat_api_key: Optional[str] = None
+ qianfan_chat_secret_key: Optional[str] = None
+ qianfan_chat_access_token: Optional[str] = None
+ qianfan_extract_api_key: Optional[str] = None
+ qianfan_extract_secret_key: Optional[str] = None
+ qianfan_extract_access_token: Optional[str] = None
+ qianfan_text2gql_api_key: Optional[str] = None
+ qianfan_text2gql_secret_key: Optional[str] = None
+ qianfan_text2gql_access_token: Optional[str] = None
+ qianfan_embedding_api_key: Optional[str] = None
+ qianfan_embedding_secret_key: Optional[str] = None
# 4.1 URL settings
qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
- qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
+ qianfan_chat_language_model: Optional[str] = "ERNIE-Speed-128K"
+ qianfan_extract_language_model: Optional[str] = "ERNIE-Speed-128K"
+ qianfan_text2gql_language_model: Optional[str] = "ERNIE-Speed-128K"
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
# refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details
qianfan_embedding_model: Optional[str] = "embedding-v1"
# TODO: To be confirmed, whether to configure
# 5. ZhiPu(GLM) settings
- zhipu_api_key: Optional[str] = None
- zhipu_language_model: Optional[str] = "glm-4"
- zhipu_embedding_model: Optional[str] = "embedding-2"
+ zhipu_chat_api_key: Optional[str] = None
+ zhipu_chat_language_model: Optional[str] = "glm-4"
+ zhipu_chat_embedding_model: Optional[str] = "embedding-2"
+ zhipu_extract_api_key: Optional[str] = None
+ zhipu_extract_language_model: Optional[str] = "glm-4"
+ zhipu_extract_embedding_model: Optional[str] = "embedding-2"
+ zhipu_text2gql_api_key: Optional[str] = None
+ zhipu_text2gql_language_model: Optional[str] = "glm-4"
+ zhipu_text2gql_embedding_model: Optional[str] = "embedding-2"
"""HugeGraph settings"""
graph_ip: Optional[str] = "127.0.0.1"
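Note the strict <provider>_<task>_<field> naming scheme above; it is what lets configs_block.py (further below) address the fields dynamically with getattr/setattr. A minimal sketch of that pattern (the openai_conf helper is hypothetical, for illustration only):

    def openai_conf(settings, task: str) -> dict:
        # task is one of "chat", "extract", "text2gql"
        return {
            "api_key": getattr(settings, f"openai_{task}_api_key"),
            "api_base": getattr(settings, f"openai_{task}_api_base"),
            "model_name": getattr(settings, f"openai_{task}_language_model"),
            "max_tokens": getattr(settings, f"openai_{task}_tokens"),
        }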
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
index 6166321..d21a54b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
@@ -34,7 +34,7 @@ def build_example_vector_index(temp_file):
else:
return "ERROR: please input json file."
builder = GremlinGenerator(
- llm=LLMs().get_llm(),
+ llm=LLMs().get_text2gql_llm(),
embedding=Embeddings().get_embedding(),
)
return builder.example_index_build(examples).run()
@@ -42,7 +42,7 @@ def build_example_vector_index(temp_file):
def gremlin_generate(inp, use_schema, use_example, example_num, schema):
generator = GremlinGenerator(
- llm=LLMs().get_llm(),
+ llm=LLMs().get_text2gql_llm(),
embedding=Embeddings().get_embedding(),
)
if use_example == "true":
@@ -58,35 +58,35 @@ if __name__ == '__main__':
"""# HugeGraph LLM Text2Gremlin Demo"""
)
gr.Markdown("## Set up the LLM")
- llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type,
+ llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama/local"], value=settings.text2gql_llm_type,
label="LLM")
@gr.render(inputs=[llm_dropdown])
def llm_settings(llm_type):
- settings.llm_type = llm_type
+ settings.text2gql_llm_type = llm_type
if llm_type == "openai":
with gr.Row():
llm_config_input = [
- gr.Textbox(value=settings.openai_api_key, label="api_key"),
- gr.Textbox(value=settings.openai_api_base, label="api_base"),
- gr.Textbox(value=settings.openai_language_model, label="model_name"),
- gr.Textbox(value=str(settings.openai_max_tokens), label="max_token"),
+ gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
+ gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
+ gr.Textbox(value=settings.openai_text2gql_language_model, label="model_name"),
+ gr.Textbox(value=str(settings.openai_text2gql_tokens), label="max_token"),
]
elif llm_type == "qianfan_wenxin":
with gr.Row():
llm_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key"),
+ gr.Textbox(value=settings.qianfan_text2gql_api_key, label="api_key"),
+ gr.Textbox(value=settings.qianfan_text2gql_secret_key, label="secret_key"),
gr.Textbox(value=settings.qianfan_chat_url, label="chat_url"),
- gr.Textbox(value=settings.qianfan_language_model, label="model_name")
+ gr.Textbox(value=settings.qianfan_text2gql_language_model, label="model_name")
]
- elif llm_type == "ollama":
+ elif llm_type == "ollama/local":
with gr.Row():
llm_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
- gr.Textbox(value=settings.ollama_language_model, label="model_name"),
+ gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
+ gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
+ gr.Textbox(value=settings.ollama_text2gql_language_model, label="model_name"),
gr.Textbox(value="", visible=False)
]
else:
@@ -94,28 +94,28 @@ if __name__ == '__main__':
llm_config_button = gr.Button("Apply Configuration")
def apply_configuration(arg1, arg2, arg3, arg4):
- llm_option = settings.llm_type
+ llm_option = settings.text2gql_llm_type
if llm_option == "openai":
- settings.openai_api_key = arg1
- settings.openai_api_base = arg2
- settings.openai_language_model = arg3
- settings.openai_max_tokens = int(arg4)
+ settings.openai_text2gql_api_key = arg1
+ settings.openai_text2gql_api_base = arg2
+ settings.openai_text2gql_language_model = arg3
+ settings.openai_text2gql_tokens = int(arg4)
elif llm_option == "qianfan_wenxin":
- settings.qianfan_api_key = arg1
- settings.qianfan_secret_key = arg2
+ settings.qianfan_text2gql_api_key = arg1
+ settings.qianfan_text2gql_secret_key = arg2
settings.qianfan_chat_url = arg3
- settings.qianfan_language_model = arg4
- elif llm_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
- settings.ollama_language_model = arg3
+ settings.qianfan_text2gql_language_model = arg4
+ elif llm_option == "ollam/local":
+ settings.ollama_text2gql_host = arg1
+ settings.ollama_text2gql_port = int(arg2)
+ settings.ollama_text2gql_language_model = arg3
gr.Info("configured!")
llm_config_button.click(apply_configuration, inputs=llm_config_input)  # pylint: disable=no-member
gr.Markdown("## Set up the Embedding")
embedding_dropdown = gr.Dropdown(
- choices=["openai", "ollama"],
+ choices=["openai", "ollama/local"],
value=settings.embedding_type,
label="Embedding"
)
@@ -126,15 +126,15 @@ if __name__ == '__main__':
if embedding_type == "openai":
with gr.Row():
embedding_config_input = [
- gr.Textbox(value=settings.openai_api_key, label="api_key"),
- gr.Textbox(value=settings.openai_api_base, label="api_base"),
+ gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
+ gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
gr.Textbox(value=settings.openai_embedding_model, label="model_name")
]
- elif embedding_type == "ollama":
+ elif embedding_type == "ollama/local":
with gr.Row():
embedding_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
+ gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
+ gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
]
else:
@@ -144,12 +144,12 @@ if __name__ == '__main__':
def apply_configuration(arg1, arg2, arg3):
embedding_option = settings.embedding_type
if embedding_option == "openai":
- settings.openai_api_key = arg1
- settings.openai_api_base = arg2
+ settings.openai_text2gql_api_key = arg1
+ settings.openai_text2gql_api_base = arg2
settings.openai_embedding_model = arg3
- elif embedding_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
+ elif embedding_option == "ollama/local":
+ settings.ollama_text2gql_host = arg1
+ settings.ollama_text2gql_port = int(arg2)
settings.ollama_embedding_model = arg3
gr.Info("configured!")
# pylint: disable=no-member
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
index cf3e139..39b036b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
@@ -25,6 +25,9 @@ from requests.auth import HTTPBasicAuth
from hugegraph_llm.config import settings
from hugegraph_llm.utils.log import log
+from functools import partial
+
+current_llm = "chat"
def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
@@ -60,11 +63,11 @@ def test_api_connection(url, method="GET", headers=None, params=None, body=None,
return resp.status_code
-def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int:
- settings.qianfan_api_key = arg1
- settings.qianfan_secret_key = arg2
+def config_qianfan_model(arg1, arg2, arg3=None, settings_prefix=None, origin_call=None) -> int:
+ setattr(settings, f"qianfan_{settings_prefix}_api_key", arg1)
+ setattr(settings, f"qianfan_{settings_prefix}_secret_key", arg2)
if arg3:
- settings.qianfan_language_model = arg3
+ setattr(settings, f"qianfan_{settings_prefix}_language_model", arg3)
params = {
"grant_type": "client_credentials",
"client_id": arg1,
@@ -88,11 +91,11 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
data = {"model": arg3, "input": "test"}
status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call)
elif embedding_option == "qianfan_wenxin":
- status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call)
+ status_code = config_qianfan_model(arg1, arg2, settings_prefix="embedding", origin_call=origin_call)
settings.qianfan_embedding_model = arg3
- elif embedding_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
+ elif embedding_option == "ollama/local":
+ settings.ollama_embedding_host = arg1
+ settings.ollama_embedding_port = int(arg2)
settings.ollama_embedding_model = arg3
status_code = test_api_connection(f"http://{arg1}:{arg2}",
origin_call=origin_call)
settings.update_env()
@@ -158,15 +161,20 @@ def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int:
# Different llm models have different parameters, so no meaningful argument names are given here
-def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int:
- llm_option = settings.llm_type
+def apply_llm_config(current_llm, arg1, arg2, arg3, arg4, origin_call=None) -> int:
+ log.debug("current llm in apply_llm_config is %s", current_llm)
+ llm_option = getattr(settings, f"{current_llm}_llm_type")
+ log.debug("llm option in apply_llm_config is %s", llm_option)
status_code = -1
+
if llm_option == "openai":
- settings.openai_api_key = arg1
- settings.openai_api_base = arg2
- settings.openai_language_model = arg3
- settings.openai_max_tokens = int(arg4)
- test_url = settings.openai_api_base + "/chat/completions"
+ setattr(settings, f"openai_{current_llm}_api_key", arg1)
+ setattr(settings, f"openai_{current_llm}_api_base", arg2)
+ setattr(settings, f"openai_{current_llm}_language_model", arg3)
+ setattr(settings, f"openai_{current_llm}_tokens", int(arg4))
+
+ test_url = getattr(settings, f"openai_{current_llm}_api_base") + "/chat/completions"
+ log.debug(f"Type of openai {current_llm} max token is %s", type(arg4))
data = {
"model": arg3,
"temperature": 0.0,
@@ -174,17 +182,22 @@ def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int:
}
headers = {"Authorization": f"Bearer {arg1}"}
status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call)
+
elif llm_option == "qianfan_wenxin":
- status_code = config_qianfan_model(arg1, arg2, arg3, origin_call)
- elif llm_option == "ollama":
- settings.ollama_host = arg1
- settings.ollama_port = int(arg2)
- settings.ollama_language_model = arg3
+ status_code = config_qianfan_model(arg1, arg2, arg3, settings_prefix=current_llm, origin_call=origin_call)
+
+ elif llm_option == "ollama/local":
+ setattr(settings, f"ollama_{current_llm}_host", arg1)
+ setattr(settings, f"ollama_{current_llm}_port", int(arg2))
+ setattr(settings, f"ollama_{current_llm}_language_model", arg3)
status_code = test_api_connection(f"http://{arg1}:{arg2}",
origin_call=origin_call)
+
gr.Info("Configured!")
settings.update_env()
+
return status_code
+
# TODO: refactor the function to reduce the number of statements & separate the logic
def create_configs_block() -> list:
# pylint: disable=R0915 (too-many-statements)
@@ -201,51 +214,116 @@ def create_configs_block() -> list:
graph_config_button = gr.Button("Apply Configuration")
graph_config_button.click(apply_graph_config, inputs=graph_config_input)  # pylint: disable=no-member
+ # TODO: use OOP to restructure
with gr.Accordion("2. Set up the LLM.", open=False):
- gr.Markdown("> Tips: the openai sdk also support openai style api from
other providers.")
- llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin",
"ollama"], value=settings.llm_type, label="LLM")
+ gr.Markdown("> Tips: the openai option also support openai style api
from other providers.")
+ with gr.Tab(label='chat'):
+ chat_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+ value=getattr(settings, f"chat_llm_type"), label=f"type")
+ apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
+ @gr.render(inputs=[chat_llm_dropdown])
+ def chat_llm_settings(llm_type):
+ settings.chat_llm_type = llm_type
+ llm_config_input = []
+ if llm_type == "openai":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"openai_chat_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"openai_chat_api_base"), label="api_base"),
+ gr.Textbox(value=getattr(settings, f"openai_chat_language_model"), label="model_name"),
+ gr.Textbox(value=getattr(settings, f"openai_chat_tokens"), label="max_token"),
+ ]
+ elif llm_type == "ollama/local":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"ollama_chat_host"), label="host"),
+ gr.Textbox(value=str(getattr(settings, f"ollama_chat_port")), label="port"),
+ gr.Textbox(value=getattr(settings, f"ollama_chat_language_model"), label="model_name"),
+ gr.Textbox(value="", visible=False),
+ ]
+ elif llm_type == "qianfan_wenxin":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"qianfan_chat_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_chat_secret_key"), label="secret_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_chat_language_model"), label="model_name"),
+ gr.Textbox(value="", visible=False),
+ ]
+ else:
+ llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
+ llm_config_button = gr.Button("Apply configuration")
+ llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input)
- @gr.render(inputs=[llm_dropdown])
- def llm_settings(llm_type):
- settings.llm_type = llm_type
- if llm_type == "openai":
- with gr.Row():
+ with gr.Tab(label='extract'):
+ extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+ value=getattr(settings, f"extract_llm_type"), label=f"type")
+ apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
+
+ @gr.render(inputs=[extract_llm_dropdown])
+ def extract_llm_settings(llm_type):
+ settings.extract_llm_type = llm_type
+ llm_config_input = []
+ if llm_type == "openai":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"openai_extract_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"openai_extract_api_base"), label="api_base"),
+ gr.Textbox(value=getattr(settings, f"openai_extract_language_model"), label="model_name"),
+ gr.Textbox(value=getattr(settings, f"openai_extract_tokens"), label="max_token"),
+ ]
+ elif llm_type == "ollama/local":
llm_config_input = [
- gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.openai_api_base, label="api_base"),
- gr.Textbox(value=settings.openai_language_model, label="model_name"),
- gr.Textbox(value=settings.openai_max_tokens, label="max_token"),
+ gr.Textbox(value=getattr(settings, f"ollama_extract_host"), label="host"),
+ gr.Textbox(value=str(getattr(settings, f"ollama_extract_port")), label="port"),
+ gr.Textbox(value=getattr(settings, f"ollama_extract_language_model"), label="model_name"),
+ gr.Textbox(value="", visible=False),
]
- elif llm_type == "ollama":
- with gr.Row():
+ elif llm_type == "qianfan_wenxin":
llm_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
- gr.Textbox(value=settings.ollama_language_model, label="model_name"),
+ gr.Textbox(value=getattr(settings, f"qianfan_extract_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_extract_secret_key"), label="secret_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_extract_language_model"), label="model_name"),
gr.Textbox(value="", visible=False),
]
- elif llm_type == "qianfan_wenxin":
- with gr.Row():
+ else:
+ llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
+ llm_config_button = gr.Button("Apply configuration")
+ llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
+ with gr.Tab(label='text2gql'):
+ text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+ value=getattr(settings, f"text2gql_llm_type"), label=f"type")
+ apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql")
+
+ @gr.render(inputs=[text2gql_llm_dropdown])
+ def text2gql_llm_settings(llm_type):
+ settings.text2gql_llm_type = llm_type
+ llm_config_input = []
+ if llm_type == "openai":
llm_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
- gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
+ gr.Textbox(value=getattr(settings, f"openai_text2gql_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"openai_text2gql_api_base"), label="api_base"),
+ gr.Textbox(value=getattr(settings, f"openai_text2gql_language_model"), label="model_name"),
+ gr.Textbox(value=getattr(settings, f"openai_text2gql_tokens"), label="max_token"),
+ ]
+ elif llm_type == "ollama/local":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"ollama_text2gql_host"), label="host"),
+ gr.Textbox(value=str(getattr(settings, f"ollama_text2gql_port")), label="port"),
+ gr.Textbox(value=getattr(settings, f"ollama_text2gql_language_model"), label="model_name"),
gr.Textbox(value="", visible=False),
]
- else:
- llm_config_input = [
- gr.Textbox(value="", visible=False),
- gr.Textbox(value="", visible=False),
- gr.Textbox(value="", visible=False),
- gr.Textbox(value="", visible=False),
- ]
- llm_config_button = gr.Button("Apply configuration")
- llm_config_button.click(apply_llm_config, inputs=llm_config_input)  # pylint: disable=no-member
+ elif llm_type == "qianfan_wenxin":
+ llm_config_input = [
+ gr.Textbox(value=getattr(settings, f"qianfan_text2gql_api_key"), label="api_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_text2gql_secret_key"), label="secret_key", type="password"),
+ gr.Textbox(value=getattr(settings, f"qianfan_text2gql_language_model"), label="model_name"),
+ gr.Textbox(value="", visible=False),
+ ]
+ else:
+ llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
+ llm_config_button = gr.Button("Apply configuration")
+ llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input)
with gr.Accordion("3. Set up the Embedding.", open=False):
embedding_dropdown = gr.Dropdown(
- choices=["openai", "qianfan_wenxin", "ollama"],
value=settings.embedding_type, label="Embedding"
+ choices=["openai", "qianfan_wenxin", "ollama/local"],
value=settings.embedding_type, label="Embedding"
)
@gr.render(inputs=[embedding_dropdown])
@@ -258,18 +336,18 @@ def create_configs_block() -> list:
gr.Textbox(value=settings.openai_embedding_api_base, label="api_base"),
gr.Textbox(value=settings.openai_embedding_model, label="model_name"),
]
- elif embedding_type == "ollama":
+ elif embedding_type == "ollama/local":
with gr.Row():
embedding_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
+ gr.Textbox(value=settings.ollama_embedding_host, label="host"),
+ gr.Textbox(value=str(settings.ollama_embedding_port), label="port"),
gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
]
elif embedding_type == "qianfan_wenxin":
with gr.Row():
embedding_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
+ gr.Textbox(value=settings.qianfan_embedding_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.qianfan_embedding_secret_key, label="secret_key", type="password"),
gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"),
]
else:
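Each tab above binds its role name into the shared handler via functools.partial, so one apply_llm_config serves all three tabs. A minimal sketch of the binding pattern (simplified from the diff; Gradio passes only the textbox values, the role is pre-bound):

    from functools import partial

    apply_chat_op = partial(apply_llm_config, "chat")
    apply_extract_op = partial(apply_llm_config, "extract")
    apply_text2gql_op = partial(apply_llm_config, "text2gql")
    # e.g. apply_chat_op(api_key, api_base, model_name, max_tokens)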
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
index ded48af..63ea7ab 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
@@ -33,17 +33,17 @@ class Embeddings:
api_key=settings.openai_embedding_api_key,
api_base=settings.openai_embedding_api_base
)
- if self.embedding_type == "ollama":
+ if self.embedding_type == "ollama/local":
return OllamaEmbedding(
model=settings.ollama_embedding_model,
- host=settings.ollama_host,
- port=settings.ollama_port
+ host=settings.ollama_embedding_host,
+ port=settings.ollama_embedding_port
)
if self.embedding_type == "qianfan_wenxin":
return QianFanEmbedding(
model_name=settings.qianfan_embedding_model,
- api_key=settings.qianfan_api_key,
- secret_key=settings.qianfan_secret_key
+ api_key=settings.qianfan_embedding_api_key,
+ secret_key=settings.qianfan_embedding_secret_key
)
raise Exception("embedding type is not supported !")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
index 2c90748..cb7e73d 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -24,28 +24,78 @@ from hugegraph_llm.config import settings
class LLMs:
def __init__(self):
- self.llm_type = settings.llm_type
+ self.chat_llm_type = settings.chat_llm_type
+ self.extract_llm_type = settings.extract_llm_type
+ self.text2gql_llm_type = settings.text2gql_llm_type
- def get_llm(self):
- if self.llm_type == "qianfan_wenxin":
+ def get_chat_llm(self):
+ if self.chat_llm_type == "qianfan_wenxin":
return QianfanClient(
- model_name=settings.qianfan_language_model,
- api_key=settings.qianfan_api_key,
- secret_key=settings.qianfan_secret_key
+ model_name=settings.qianfan_chat_language_model,
+ api_key=settings.qianfan_chat_api_key,
+ secret_key=settings.qianfan_chat_secret_key
)
- if self.llm_type == "openai":
+ if self.chat_llm_type == "openai":
return OpenAIClient(
- api_key=settings.openai_api_key,
- api_base=settings.openai_api_base,
- model_name=settings.openai_language_model,
- max_tokens=settings.openai_max_tokens,
+ api_key=settings.openai_chat_api_key,
+ api_base=settings.openai_chat_api_base,
+ model_name=settings.openai_chat_language_model,
+ max_tokens=settings.openai_chat_tokens,
)
- if self.llm_type == "ollama":
- return OllamaClient(model=settings.ollama_language_model)
- raise Exception("llm type is not supported !")
+ if self.chat_llm_type == "ollama/local":
+ return OllamaClient(
+ model=settings.ollama_chat_language_model,
+ host=settings.ollama_chat_host,
+ port=settings.ollama_chat_port,
+ )
+ raise Exception("chat llm type is not supported !")
+
+ def get_extract_llm(self):
+ if self.extract_llm_type == "qianfan_wenxin":
+ return QianfanClient(
+ model_name=settings.qianfan_extract_language_model,
+ api_key=settings.qianfan_extract_api_key,
+ secret_key=settings.qianfan_extract_secret_key
+ )
+ if self.extract_llm_type == "openai":
+ return OpenAIClient(
+ api_key=settings.openai_extract_api_key,
+ api_base=settings.openai_extract_api_base,
+ model_name=settings.openai_extract_language_model,
+ max_tokens=settings.openai_extract_tokens,
+ )
+ if self.extract_llm_type == "ollama/local":
+ return OllamaClient(
+ model=settings.ollama_extract_language_model,
+ host=settings.ollama_extract_host,
+ port=settings.ollama_extract_port,
+ )
+ raise Exception("extract llm type is not supported !")
+
+ def get_text2gql_llm(self):
+ if self.text2gql_llm_type == "qianfan_wenxin":
+ return QianfanClient(
+ model_name=settings.qianfan_text2gql_language_model,
+ api_key=settings.qianfan_text2gql_api_key,
+ secret_key=settings.qianfan_text2gql_secret_key
+ )
+ if self.text2gql_llm_type == "openai":
+ return OpenAIClient(
+ api_key=settings.openai_text2gql_api_key,
+ api_base=settings.openai_text2gql_api_base,
+ model_name=settings.openai_text2gql_language_model,
+ max_tokens=settings.openai_text2gql_tokens,
+ )
+ if self.text2gql_llm_type == "ollama/local":
+ return OllamaClient(
+ model=settings.ollama_text2gql_language_model,
+ host=settings.ollama_text2gql_host,
+ port=settings.ollama_text2gql_port,
+ )
+ raise Exception("text2gql llm type is not supported !")
if __name__ == "__main__":
- client = LLMs().get_llm()
+ client = LLMs().get_chat_llm()
print(client.generate(prompt="What is the capital of China?"))
print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index b7b0148..62f5ef2 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -121,4 +121,4 @@ class OllamaClient(BaseLLM):
def get_llm_type(self) -> str:
"""Returns the type of the LLM"""
- return "ollama"
+ return "ollama/local"
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 0f585cb..895a379 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -45,7 +45,7 @@ class WordExtract:
context["query"] = self._query
if self._llm is None:
- self._llm = LLMs().get_llm()
+ self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
if isinstance(context.get("language"), str):
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 789ec20..0dd4d7d 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -46,7 +46,9 @@ class RAGPipeline:
:param llm: Optional LLM model to use.
:param embedding: Optional embedding model to use.
"""
- self._llm = llm or LLMs().get_llm()
+ self._chat_llm = llm or LLMs().get_chat_llm()
+ self._extract_llm = llm or LLMs().get_extract_llm()
+ self._text2gql_llm = llm or LLMs().get_text2gql_llm()
self._embedding = embedding or Embeddings().get_embedding()
self._operators: List[Any] = []
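A minimal construction sketch for the updated pipeline (assuming the constructor shown above): with no explicit llm argument each role falls back to its own getter, while passing one overrides all three roles at once.

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    pipeline = RAGPipeline()              # per-role defaults from LLMs()
    # pipeline = RAGPipeline(llm=my_llm)  # one model for chat/extract/text2gql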
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 149b7ee..666ecf9 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -60,7 +60,7 @@ class AnswerSynthesize:
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._llm is None:
- self._llm = LLMs().get_llm()
+ self._llm = LLMs().get_chat_llm()
if self._question is None:
self._question = context.get("query") or None
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index 828e394..c47a79d 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -66,7 +66,7 @@ class KeywordExtract:
context["query"] = self._query
if self._llm is None:
- self._llm = LLMs().get_llm()
+ self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
if isinstance(context.get("language"), str):
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
index a8ea115..73d7057 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -34,7 +34,7 @@ from ..operators.kg_construction_task import KgBuilder
def get_graph_index_info():
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
context = builder.fetch_graph_data().run()
vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, settings.graph_name, "graph_vids")))
context["vid_index"] = {
@@ -54,7 +54,7 @@ def clean_all_graph_index():
def extract_graph(input_file, input_text, schema, example_prompt) -> str:
texts = read_documents(input_file, input_text)
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
if schema:
try:
@@ -77,7 +77,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str:
def fit_vid_index():
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
builder.fetch_graph_data().build_vertex_id_semantic_index()
log.debug("Operators: %s", builder.operators)
try:
@@ -94,7 +94,7 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]:
try:
data_json = json.loads(data.strip())
log.debug("Import graph data: %s", data)
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
if schema:
try:
schema = json.loads(schema.strip())
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
index a7afdf8..e955aac 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
@@ -71,6 +71,6 @@ def clean_vector_index():
def build_vector_index(input_file, input_text):
texts = read_documents(input_file, input_text)
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
context = builder.chunk_split(texts, "paragraph", "zh").build_vector_index().run()
return json.dumps(context, ensure_ascii=False, indent=2)