This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 6dd9cd7e refactor(llm): replace QianFan by OpenAI-compatible format (#285)
6dd9cd7e is described below
commit 6dd9cd7e14cc96a6b155ed6ee0611c706e1ad50b
Author: day0n <[email protected]>
AuthorDate: Fri Jul 25 17:42:13 2025 +0800
refactor(llm): replace QianFan by OpenAI-compatible format (#285)
Co-authored-by: imbajin <[email protected]>
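Migration note: models previously configured through the removed "qianfan_wenxin" type can still be reached via any OpenAI-compatible endpoint by reusing the existing openai_* settings. A minimal sketch of the idea (the endpoint URL and model name below are illustrative assumptions, not values shipped by this commit):

    # Hypothetical migration sketch: point the OpenAI-compatible settings at your provider.
    from hugegraph_llm.config import llm_settings
    from hugegraph_llm.models.llms.init_llm import LLMs

    llm_settings.chat_llm_type = "openai"
    llm_settings.openai_chat_api_base = "https://example-provider.com/v1"  # assumed OpenAI-style endpoint
    llm_settings.openai_chat_api_key = "YOUR_API_KEY"                      # placeholder credential
    llm_settings.openai_chat_language_model = "your-model-name"            # placeholder model id

    # For the "openai" type, LLMs().get_chat_llm() returns an OpenAIClient bound to these settings.
    chat_llm = LLMs().get_chat_llm()
    print(chat_llm.generate(prompt="What is the capital of China?"))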
---
hugegraph-llm/pyproject.toml | 5 +-
.../src/hugegraph_llm/api/models/rag_requests.py | 4 +-
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 5 -
.../src/hugegraph_llm/config/llm_config.py | 32 +-----
.../src/hugegraph_llm/demo/rag_demo/app.py | 2 -
.../hugegraph_llm/demo/rag_demo/configs_block.py | 102 ++++-------------
.../models/embeddings/init_embedding.py | 7 --
.../src/hugegraph_llm/models/embeddings/qianfan.py | 68 -----------
.../src/hugegraph_llm/models/llms/init_llm.py | 19 ----
.../src/hugegraph_llm/models/llms/qianfan.py | 124 ---------------------
pyproject.toml | 1 -
11 files changed, 30 insertions(+), 339 deletions(-)
diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index dba36ca8..2ed43896 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -34,16 +34,15 @@ dependencies = [
"setuptools",
"urllib3",
"rich",
-
+
# Data processing dependencies
"numpy",
"pandas",
"pydantic",
-
+
# LLM specific dependencies
"openai",
"ollama",
- "qianfan",
"retry",
"tiktoken",
"nltk",
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 3170e702..89bdd9bc 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -91,15 +91,13 @@ class GraphRAGRequest(BaseModel):
class LLMConfigRequest(BaseModel):
llm_type: str
- # The common parameters shared by OpenAI, Qianfan Wenxin,
+ # The common parameters shared by OpenAI, LiteLLM,
# and OLLAMA platforms.
api_key: str
api_base: str
language_model: str
# Openai-only properties
max_tokens: str = None
- # qianfan-wenxin-only properties
- secret_key: str = None
# ollama-only properties
# TODO: replace to url later
host: str = None
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 2621220c..b7e8a6f7 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -124,7 +124,6 @@ def rag_http_api(
]
user_result = {key: result[key] for key in params if key in result}
return {"graph_recall": user_result}
- # Note: Maybe only for qianfan/wenxin
return {"graph_recall": json.dumps(result)}
except TypeError as e:
@@ -149,8 +148,6 @@ def rag_http_api(
if req.llm_type == "openai":
res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
- elif req.llm_type == "qianfan_wenxin":
- res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
else:
res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
@@ -161,8 +158,6 @@ def rag_http_api(
if req.llm_type == "openai":
res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
- elif req.llm_type == "qianfan_wenxin":
- res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
else:
res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
diff --git a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
index a9b4b2d4..ff738454 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
@@ -25,10 +25,10 @@ from .models import BaseConfig
class LLMConfig(BaseConfig):
"""LLM settings"""
- chat_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
- extract_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
- text2gql_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
- embedding_type: Optional[Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"]] = "openai"
+ chat_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai"
+ extract_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai"
+ text2gql_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai"
+ embedding_type: Optional[Literal["openai", "litellm", "ollama/local"]] = "openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
# 1. OpenAI settings
openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
@@ -63,29 +63,7 @@ class LLMConfig(BaseConfig):
ollama_embedding_host: Optional[str] = "127.0.0.1"
ollama_embedding_port: Optional[int] = 11434
ollama_embedding_model: Optional[str] = None
- # 4. QianFan/WenXin settings
- # TODO: update to one token key mode
- qianfan_chat_api_key: Optional[str] = None
- qianfan_chat_secret_key: Optional[str] = None
- qianfan_chat_access_token: Optional[str] = None
- qianfan_extract_api_key: Optional[str] = None
- qianfan_extract_secret_key: Optional[str] = None
- qianfan_extract_access_token: Optional[str] = None
- qianfan_text2gql_api_key: Optional[str] = None
- qianfan_text2gql_secret_key: Optional[str] = None
- qianfan_text2gql_access_token: Optional[str] = None
- qianfan_embedding_api_key: Optional[str] = None
- qianfan_embedding_secret_key: Optional[str] = None
- # 4.1 URL settings
- qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
- qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
- qianfan_chat_language_model: Optional[str] = "ERNIE-Speed-128K"
- qianfan_extract_language_model: Optional[str] = "ERNIE-Speed-128K"
- qianfan_text2gql_language_model: Optional[str] = "ERNIE-Speed-128K"
- qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
- # refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details
- qianfan_embedding_model: Optional[str] = "embedding-v1"
- # 5. LiteLLM settings
+ # 4. LiteLLM settings
litellm_chat_api_key: Optional[str] = None
litellm_chat_api_base: Optional[str] = None
litellm_chat_language_model: Optional[str] = "openai/gpt-4.1-mini"
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 880ac406..c451c65c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -72,13 +72,11 @@ def init_rag_ui() -> gr.Interface:
llm_config_input = textbox_array_llm_config
= if settings.llm_type == openai [settings.openai_api_key, settings.openai_api_base, settings.openai_language_model, settings.openai_max_tokens]
= else if settings.llm_type == ollama [settings.ollama_host, settings.ollama_port, settings.ollama_language_model, ""]
- = else if settings.llm_type == qianfan_wenxin [settings.qianfan_api_key, settings.qianfan_secret_key, settings.qianfan_language_model, ""]
= else ["","","", ""]
embedding_config_input = textbox_array_embedding_config
= if settings.embedding_type == openai [settings.openai_api_key, settings.openai_api_base, settings.openai_embedding_model]
= else if settings.embedding_type == ollama [settings.ollama_host, settings.ollama_port, settings.ollama_embedding_model]
- = else if settings.embedding_type == qianfan_wenxin [settings.qianfan_api_key, settings.qianfan_secret_key, settings.qianfan_embedding_model]
= else ["","",""]
reranker_config_input = textbox_array_reranker_config
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
index cb367770..f872ff52 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
@@ -97,22 +97,6 @@ def test_api_connection(url, method="GET", headers=None, params=None, body=None,
return resp.status_code
-def config_qianfan_model(arg1, arg2, arg3=None, settings_prefix=None, origin_call=None) -> int:
- setattr(llm_settings, f"qianfan_{settings_prefix}_api_key", arg1)
- setattr(llm_settings, f"qianfan_{settings_prefix}_secret_key", arg2)
- if arg3:
- setattr(llm_settings, f"qianfan_{settings_prefix}_language_model", arg3)
- params = {
- "grant_type": "client_credentials",
- "client_id": arg1,
- "client_secret": arg2,
- }
- status_code = test_api_connection(
- "https://aip.baidubce.com/oauth/2.0/token", "POST", params=params, origin_call=origin_call
- )
- return status_code
-
-
def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
status_code = -1
embedding_option = llm_settings.embedding_type
@@ -124,9 +108,6 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
headers = {"Authorization": f"Bearer {arg1}"}
data = {"model": arg3, "input": "test"}
status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call)
- elif embedding_option == "qianfan_wenxin":
- status_code = config_qianfan_model(arg1, arg2, settings_prefix="embedding", origin_call=origin_call)
- llm_settings.qianfan_embedding_model = arg3
elif embedding_option == "ollama/local":
llm_settings.ollama_embedding_host = arg1
llm_settings.ollama_embedding_port = int(arg2)
@@ -181,7 +162,7 @@ def apply_reranker_config(
def apply_graph_config(url, name, user, pwd, gs, origin_call=None) -> int:
- # Add URL prefix automatically to improve user experience
+ # Add URL prefix automatically to improve the user experience
if url and not (url.startswith('http://') or url.startswith('https://')):
url = f"http://{url}"
@@ -202,45 +183,41 @@ def apply_graph_config(url, name, user, pwd, gs, origin_call=None) -> int:
return response
-# Different llm models have different parameters, so no meaningful argument names are given here
-def apply_llm_config(current_llm_config, arg1, arg2, arg3, arg4, origin_call=None) -> int:
+def apply_llm_config(current_llm_config, api_key_or_host, api_base_or_port, model_name, max_tokens,
+ origin_call=None) -> int:
log.debug("current llm in apply_llm_config is %s", current_llm_config)
llm_option = getattr(llm_settings, f"{current_llm_config}_llm_type")
log.debug("llm option in apply_llm_config is %s", llm_option)
status_code = -1
if llm_option == "openai":
- setattr(llm_settings, f"openai_{current_llm_config}_api_key", arg1)
- setattr(llm_settings, f"openai_{current_llm_config}_api_base", arg2)
- setattr(llm_settings, f"openai_{current_llm_config}_language_model", arg3)
- setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(arg4))
+ setattr(llm_settings, f"openai_{current_llm_config}_api_key", api_key_or_host)
+ setattr(llm_settings, f"openai_{current_llm_config}_api_base", api_base_or_port)
+ setattr(llm_settings, f"openai_{current_llm_config}_language_model", model_name)
+ setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(max_tokens))
test_url = getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions"
data = {
- "model": arg3,
+ "model": model_name,
"temperature": 0.01,
"messages": [{"role": "user", "content": "test"}],
}
- headers = {"Authorization": f"Bearer {arg1}"}
+ headers = {"Authorization": f"Bearer {api_key_or_host}"}
status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call)
- elif llm_option == "qianfan_wenxin":
- status_code = config_qianfan_model(arg1, arg2, arg3, settings_prefix=current_llm_config,
- origin_call=origin_call) # pylint: disable=C0301
-
elif llm_option == "ollama/local":
- setattr(llm_settings, f"ollama_{current_llm_config}_host", arg1)
- setattr(llm_settings, f"ollama_{current_llm_config}_port", int(arg2))
- setattr(llm_settings, f"ollama_{current_llm_config}_language_model", arg3)
- status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call)
+ setattr(llm_settings, f"ollama_{current_llm_config}_host", api_key_or_host)
+ setattr(llm_settings, f"ollama_{current_llm_config}_port", int(api_base_or_port))
+ setattr(llm_settings, f"ollama_{current_llm_config}_language_model", model_name)
+ status_code = test_api_connection(f"http://{api_key_or_host}:{api_base_or_port}", origin_call=origin_call)
elif llm_option == "litellm":
- setattr(llm_settings, f"litellm_{current_llm_config}_api_key", arg1)
- setattr(llm_settings, f"litellm_{current_llm_config}_api_base", arg2)
- setattr(llm_settings, f"litellm_{current_llm_config}_language_model", arg3)
- setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(arg4))
+ setattr(llm_settings, f"litellm_{current_llm_config}_api_key", api_key_or_host)
+ setattr(llm_settings, f"litellm_{current_llm_config}_api_base", api_base_or_port)
+ setattr(llm_settings, f"litellm_{current_llm_config}_language_model", model_name)
+ setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(max_tokens))
- status_code = test_litellm_chat(arg1, arg2, arg3, int(arg4))
+ status_code = test_litellm_chat(api_key_or_host, api_base_or_port, model_name, int(max_tokens))
gr.Info("Configured!")
llm_settings.update_env()
@@ -273,7 +250,7 @@ def create_configs_block() -> list:
gr.Markdown("> Tips: The OpenAI option also support openai style api from other providers. "
"**Refresh the page** to load the **latest configs** in __UI__.")
with gr.Tab(label='chat'):
- chat_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
+ chat_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "ollama/local"],
value=getattr(llm_settings, "chat_llm_type"), label="type")
apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
@@ -295,15 +272,6 @@ def create_configs_block() -> list:
gr.Textbox(value=getattr(llm_settings, "ollama_chat_language_model"), label="model_name"),
gr.Textbox(value="", visible=False),
]
- elif llm_type == "qianfan_wenxin":
- llm_config_input = [
- gr.Textbox(value=getattr(llm_settings, "qianfan_chat_api_key"), label="api_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_chat_secret_key"), label="secret_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_chat_language_model"), label="model_name"),
- gr.Textbox(value="", visible=False),
- ]
elif llm_type == "litellm":
llm_config_input = [
gr.Textbox(value=getattr(llm_settings, "litellm_chat_api_key"), label="api_key",
@@ -328,7 +296,7 @@ def create_configs_block() -> list:
if not api_text2sql_key:
llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
with gr.Tab(label='mini_tasks'):
- extract_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
+ extract_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "ollama/local"],
value=getattr(llm_settings, "extract_llm_type"), label="type")
apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
@@ -350,15 +318,6 @@ def create_configs_block() -> list:
gr.Textbox(value=getattr(llm_settings, "ollama_extract_language_model"), label="model_name"),
gr.Textbox(value="", visible=False),
]
- elif llm_type == "qianfan_wenxin":
- llm_config_input = [
- gr.Textbox(value=getattr(llm_settings, "qianfan_extract_api_key"), label="api_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_extract_secret_key"), label="secret_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_extract_language_model"), label="model_name"),
- gr.Textbox(value="", visible=False),
- ]
elif llm_type == "litellm":
llm_config_input = [
gr.Textbox(value=getattr(llm_settings, "litellm_extract_api_key"), label="api_key",
@@ -374,7 +333,7 @@ def create_configs_block() -> list:
llm_config_button = gr.Button("Apply configuration")
llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
with gr.Tab(label='text2gql'):
- text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
+ text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "ollama/local"],
value=getattr(llm_settings, "text2gql_llm_type"), label="type")
apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql")
@@ -396,15 +355,6 @@ def create_configs_block() -> list:
gr.Textbox(value=getattr(llm_settings, "ollama_text2gql_language_model"), label="model_name"),
gr.Textbox(value="", visible=False),
]
- elif llm_type == "qianfan_wenxin":
- llm_config_input = [
- gr.Textbox(value=getattr(llm_settings, "qianfan_text2gql_api_key"), label="api_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_text2gql_secret_key"), label="secret_key",
- type="password"),
- gr.Textbox(value=getattr(llm_settings, "qianfan_text2gql_language_model"), label="model_name"),
- gr.Textbox(value="", visible=False),
- ]
elif llm_type == "litellm":
llm_config_input = [
gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_api_key"), label="api_key",
@@ -422,7 +372,7 @@ def create_configs_block() -> list:
with gr.Accordion("3. Set up the Embedding.", open=False):
embedding_dropdown = gr.Dropdown(
- choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"], value=llm_settings.embedding_type,
+ choices=["openai", "litellm", "ollama/local"], value=llm_settings.embedding_type,
label="Embedding"
)
@@ -443,14 +393,6 @@ def create_configs_block() -> list:
gr.Textbox(value=str(llm_settings.ollama_embedding_port), label="port"),
gr.Textbox(value=llm_settings.ollama_embedding_model, label="model_name"),
]
- elif embedding_type == "qianfan_wenxin":
- with gr.Row():
- embedding_config_input = [
- gr.Textbox(value=llm_settings.qianfan_embedding_api_key, label="api_key", type="password"),
- gr.Textbox(value=llm_settings.qianfan_embedding_secret_key, label="secret_key",
- type="password"),
- gr.Textbox(value=llm_settings.qianfan_embedding_model, label="model_name"),
- ]
elif embedding_type == "litellm":
with gr.Row():
embedding_config_input = [
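For reference, the renamed apply_llm_config arguments keep the old positional overloading: with OpenAI-style and LiteLLM backends the first two slots carry api_key/api_base, while with ollama/local they carry host/port. A rough usage sketch with placeholder values (in the demo these calls are normally wired to the Gradio buttons via functools.partial):

    # Hypothetical direct calls against the new signature in configs_block.py.
    from functools import partial
    from hugegraph_llm.demo.rag_demo.configs_block import apply_llm_config

    apply_chat = partial(apply_llm_config, "chat")
    # OpenAI-compatible provider: api_key, api_base, model_name, max_tokens
    status = apply_chat("YOUR_API_KEY", "https://api.openai.com/v1", "gpt-4.1-mini", 4096)
    # ollama/local: the same slots carry host and port; max_tokens is ignored on this path
    status = apply_chat("127.0.0.1", 11434, "llama3", 0)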
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
index 4d9a14fd..b7d37996 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
@@ -18,7 +18,6 @@
from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding
from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding
-from hugegraph_llm.models.embeddings.qianfan import QianFanEmbedding
from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding
from hugegraph_llm.config import llm_settings
@@ -40,12 +39,6 @@ class Embeddings:
host=llm_settings.ollama_embedding_host,
port=llm_settings.ollama_embedding_port
)
- if self.embedding_type == "qianfan_wenxin":
- return QianFanEmbedding(
- model_name=llm_settings.qianfan_embedding_model,
- api_key=llm_settings.qianfan_embedding_api_key,
- secret_key=llm_settings.qianfan_embedding_secret_key
- )
if self.embedding_type == "litellm":
return LiteLLMEmbedding(
model_name=llm_settings.litellm_embedding_model,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
deleted file mode 100644
index 99eeb591..00000000
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-from typing import Optional, List
-
-import qianfan
-
-from hugegraph_llm.config import llm_settings
-
-"""
-"QianFan" platform can be understood as a unified LLM platform that encompasses the
-WenXin large model along with other
-common open-source models.
-
-It enables the invocation and switching between WenXin and these open-source models.
-"""
-
-
-class QianFanEmbedding:
- def __init__(
- self,
- model_name: str = "embedding-v1",
- api_key: Optional[str] = None,
- secret_key: Optional[str] = None
- ):
- qianfan.get_config().AK = api_key or llm_settings.qianfan_embedding_api_key
- qianfan.get_config().SK = secret_key or llm_settings.qianfan_embedding_secret_key
- self.embedding_model_name = model_name
- self.client = qianfan.Embedding()
-
- def get_text_embedding(self, text: str) -> List[float]:
- """ Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
- response = self.client.do(
- model=self.embedding_model_name,
- texts=[text]
- )
- return response["body"]["data"][0]["embedding"]
-
- def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
- """ Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
- response = self.client.do(
- model=self.embedding_model_name,
- texts=texts
- )
- return [data["embedding"] for data in response["body"]["data"]]
-
- async def async_get_text_embedding(self, text: str) -> List[float]:
- """ Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
- response = await self.client.ado(
- model=self.embedding_model_name,
- texts=[text]
- )
- return response["body"]["data"][0]["embedding"]
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
index 58eb4799..e70b0d9d 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -18,7 +18,6 @@
from hugegraph_llm.models.llms.ollama import OllamaClient
from hugegraph_llm.models.llms.openai import OpenAIClient
-from hugegraph_llm.models.llms.qianfan import QianfanClient
from hugegraph_llm.models.llms.litellm import LiteLLMClient
from hugegraph_llm.config import llm_settings
@@ -30,12 +29,6 @@ class LLMs:
self.text2gql_llm_type = llm_settings.text2gql_llm_type
def get_chat_llm(self):
- if self.chat_llm_type == "qianfan_wenxin":
- return QianfanClient(
- model_name=llm_settings.qianfan_chat_language_model,
- api_key=llm_settings.qianfan_chat_api_key,
- secret_key=llm_settings.qianfan_chat_secret_key
- )
if self.chat_llm_type == "openai":
return OpenAIClient(
api_key=llm_settings.openai_chat_api_key,
@@ -59,12 +52,6 @@ class LLMs:
raise Exception("chat llm type is not supported !")
def get_extract_llm(self):
- if self.extract_llm_type == "qianfan_wenxin":
- return QianfanClient(
- model_name=llm_settings.qianfan_extract_language_model,
- api_key=llm_settings.qianfan_extract_api_key,
- secret_key=llm_settings.qianfan_extract_secret_key
- )
if self.extract_llm_type == "openai":
return OpenAIClient(
api_key=llm_settings.openai_extract_api_key,
@@ -88,12 +75,6 @@ class LLMs:
raise Exception("extract llm type is not supported !")
def get_text2gql_llm(self):
- if self.text2gql_llm_type == "qianfan_wenxin":
- return QianfanClient(
- model_name=llm_settings.qianfan_text2gql_language_model,
- api_key=llm_settings.qianfan_text2gql_api_key,
- secret_key=llm_settings.qianfan_text2gql_secret_key
- )
if self.text2gql_llm_type == "openai":
return OpenAIClient(
api_key=llm_settings.openai_text2gql_api_key,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
deleted file mode 100644
index 2d306ac4..00000000
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-from typing import AsyncGenerator, Generator, Optional, List, Dict, Any, Callable
-
-import qianfan
-from retry import retry
-
-from hugegraph_llm.config import llm_settings
-from hugegraph_llm.models.llms.base import BaseLLM
-from hugegraph_llm.utils.log import log
-
-
-class QianfanClient(BaseLLM):
- def __init__(self, model_name: Optional[str] = "ernie-4.5-8k-preview",
- api_key: Optional[str] = None, secret_key: Optional[str] = None):
- qianfan.get_config().AK = api_key or llm_settings.qianfan_chat_api_key
- qianfan.get_config().SK = secret_key or llm_settings.qianfan_chat_secret_key
- self.chat_model = model_name
- self.chat_comp = qianfan.ChatCompletion()
-
- @retry(tries=3, delay=1)
- def generate(
- self,
- messages: Optional[List[Dict[str, Any]]] = None,
- prompt: Optional[str] = None,
- ) -> str:
- if messages is None:
- assert prompt is not None, "Messages or prompt must be provided."
- messages = [{"role": "user", "content": prompt}]
-
- response = self.chat_comp.do(model=self.chat_model, messages=messages)
- if response.code != 200:
- raise Exception(
- f"Request failed with code {response.code}, message: {response.body['error_msg']}"
- )
- log.info("Token usage: %s", json.dumps(response.body["usage"]))
- return response.body["result"]
-
- @retry(tries=3, delay=1)
- async def agenerate(
- self,
- messages: Optional[List[Dict[str, Any]]] = None,
- prompt: Optional[str] = None,
- ) -> str:
- if messages is None:
- assert prompt is not None, "Messages or prompt must be provided."
- messages = [{"role": "user", "content": prompt}]
-
- response = await self.chat_comp.ado(model=self.chat_model, messages=messages)
- if response.code != 200:
- raise Exception(
- f"Request failed with code {response.code}, message: {response.body['error_msg']}"
- )
- log.info("Token usage: %s", json.dumps(response.body["usage"]))
- return response.body["result"]
-
- def generate_streaming(
- self,
- messages: Optional[List[Dict[str, Any]]] = None,
- prompt: Optional[str] = None,
- on_token_callback: Optional[Callable] = None,
- ) -> Generator[str, None, None]:
- if messages is None:
- assert prompt is not None, "Messages or prompt must be provided."
- messages = [{"role": "user", "content": prompt}]
-
- for msg in self.chat_comp.do(messages=messages, model=self.chat_model, stream=True):
- token = msg.body['result']
- if on_token_callback:
- on_token_callback(token)
- yield token
-
- async def agenerate_streaming(
- self,
- messages: Optional[List[Dict[str, Any]]] = None,
- prompt: Optional[str] = None,
- on_token_callback: Optional[Callable] = None,
- ) -> AsyncGenerator[str, None]:
- if messages is None:
- assert prompt is not None, "Messages or prompt must be provided."
- messages = [{"role": "user", "content": prompt}]
-
- try:
- async_generator = await self.chat_comp.ado(messages=messages, model=self.chat_model, stream=True)
- async for msg in async_generator:
- chunk = msg.body['result']
- if on_token_callback:
- on_token_callback(chunk)
- yield chunk
- except Exception as e:
- print(f"Retrying LLM call {e}")
- raise e
-
- def num_tokens_from_string(self, string: str) -> int:
- return len(string)
-
- def max_allowed_token_length(self) -> int:
- # TODO: replace with config way
- return 6000
-
- def get_llm_type(self) -> str:
- return "qianfan_wenxin"
-
-
-if __name__ == "__main__":
- client = QianfanClient()
- print(client.generate(prompt="What is the capital of China?"))
- print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
diff --git a/pyproject.toml b/pyproject.toml
index 012a1607..9e1624b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,7 +108,6 @@ constraint-dependencies = [
# LLM dependencies
"openai~=1.61.0",
"ollama~=0.4.8",
- "qianfan~=0.3.18", # TODO: remove it
"retry~=0.9.2",
"tiktoken~=0.7.0",
"nltk~=3.9.1",