This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 7ae5d6f feat(llm): support async streaming output in RAG answer block (#190)
7ae5d6f is described below
commit 7ae5d6fcc1013bc39164672355a8270f078bff8d
Author: vichayturen <[email protected]>
AuthorDate: Thu Mar 6 18:21:42 2025 +0800
feat(llm): support async streaming output in RAG answer block (#190)
follow #172
To enable asynchronous streaming, we compromised by switching
`gremlin_generate_operator` to a synchronous generation mode. It can be switched
back to asynchronous mode once the subsequent agentization work makes the whole
pipeline fully asynchronous.
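As a rough sketch of the new consumption pattern (`stream_answer` and `llm` are
illustrative names only; `agenerate_streaming` is the BaseLLM method added in this
patch), the UI accumulates tokens and yields the partial answer as it grows:

    async def stream_answer(llm, question: str):
        # Accumulate streamed tokens and yield the partial answer after each one,
        # so a Gradio output box can re-render progressively on every yield.
        partial = ""
        async for token in llm.agenerate_streaming(prompt=question):
            partial += token
            yield partial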
---------
Co-authored-by: chenzihong <[email protected]>
Co-authored-by: chenzihong <[email protected]>
Co-authored-by: imbajin <[email protected]>
---
.../src/hugegraph_llm/demo/rag_demo/admin_block.py | 2 +-
.../src/hugegraph_llm/demo/rag_demo/rag_block.py | 164 ++++++++++++++++-----
.../src/hugegraph_llm/models/llms/base.py | 15 +-
.../src/hugegraph_llm/models/llms/litellm.py | 31 +++-
.../src/hugegraph_llm/models/llms/ollama.py | 47 ++++--
.../src/hugegraph_llm/models/llms/openai.py | 104 ++++++++++---
.../src/hugegraph_llm/models/llms/qianfan.py | 37 ++++-
.../operators/llm_op/answer_synthesize.py | 156 ++++++++++++++++----
.../operators/llm_op/gremlin_generate.py | 32 +++-
9 files changed, 484 insertions(+), 104 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
index b8c1852..2d5937a 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
@@ -50,7 +50,7 @@ async def log_stream(log_path: str, lines: int = 125):
def read_llm_server_log(lines=250):
log_path = "logs/llm-server.log"
try:
- with open(log_path, "r", encoding='utf-8') as f:
+ with open(log_path, "r", encoding='utf-8', errors="replace") as f:
return ''.join(deque(f, maxlen=lines))
except FileNotFoundError:
log.critical("Log file not found: %s", log_path)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 8261887..df82568 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -18,7 +18,7 @@
# pylint: disable=E1101
import os
-from typing import Tuple, Literal, Optional
+from typing import AsyncGenerator, Tuple, Literal, Optional
import gradio as gr
import pandas as pd
@@ -26,6 +26,7 @@ from gradio.utils import NamedString
from hugegraph_llm.config import resource_path, prompt, huge_settings, llm_settings
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.utils.log import log
@@ -56,25 +57,10 @@ def rag_answer(
4. Synthesize the final answer.
5. Run the pipeline and return the results.
"""
-
- gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
- should_update_prompt = (
- prompt.default_question != text
- or prompt.answer_prompt != answer_prompt
- or prompt.keywords_extract_prompt != keywords_extract_prompt
- or prompt.gremlin_generate_prompt != gremlin_prompt
- or prompt.custom_rerank_info != custom_related_information
- )
- if should_update_prompt:
- prompt.custom_rerank_info = custom_related_information
- prompt.default_question = text
- prompt.answer_prompt = answer_prompt
- prompt.keywords_extract_prompt = keywords_extract_prompt
- prompt.gremlin_generate_prompt = gremlin_prompt
- prompt.update_yaml_file()
-
- vector_search = vector_only_answer or graph_vector_answer
- graph_search = graph_only_answer or graph_vector_answer
+    graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
+                                                                    graph_only_answer, graph_vector_answer,
+                                                                    gremlin_prompt, keywords_extract_prompt, text,
+                                                                    vector_only_answer)
if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
@@ -121,6 +107,106 @@ def rag_answer(
raise gr.Error(f"An unexpected error occurred: {str(e)}")
+def update_ui_configs(answer_prompt, custom_related_information, graph_only_answer, graph_vector_answer, gremlin_prompt,
+ keywords_extract_prompt, text, vector_only_answer):
+ gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
+ should_update_prompt = (
+ prompt.default_question != text
+ or prompt.answer_prompt != answer_prompt
+ or prompt.keywords_extract_prompt != keywords_extract_prompt
+ or prompt.gremlin_generate_prompt != gremlin_prompt
+ or prompt.custom_rerank_info != custom_related_information
+ )
+ if should_update_prompt:
+ prompt.custom_rerank_info = custom_related_information
+ prompt.default_question = text
+ prompt.answer_prompt = answer_prompt
+ prompt.keywords_extract_prompt = keywords_extract_prompt
+ prompt.gremlin_generate_prompt = gremlin_prompt
+ prompt.update_yaml_file()
+ vector_search = vector_only_answer or graph_vector_answer
+ graph_search = graph_only_answer or graph_vector_answer
+ return graph_search, gremlin_prompt, vector_search
+
+
+async def rag_answer_streaming(
+ text: str,
+ raw_answer: bool,
+ vector_only_answer: bool,
+ graph_only_answer: bool,
+ graph_vector_answer: bool,
+ graph_ratio: float,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str,
+ answer_prompt: str,
+ keywords_extract_prompt: str,
+ gremlin_tmpl_num: Optional[int] = 2,
+ gremlin_prompt: Optional[str] = None,
+) -> AsyncGenerator[Tuple[str, str, str, str], None]:
+ """
+ Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+ 1. Initialize the RAGPipeline.
+ 2. Select vector search or graph search based on parameters.
+ 3. Merge, deduplicate, and rerank the results.
+ 4. Synthesize the final answer.
+ 5. Run the pipeline and return the results.
+ """
+
+    graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
+                                                                    graph_only_answer, graph_vector_answer,
+                                                                    gremlin_prompt, keywords_extract_prompt, text,
+                                                                    vector_only_answer)
+ if raw_answer is False and not vector_search and not graph_search:
+ gr.Warning("Please select at least one generate mode.")
+ yield "", "", "", ""
+ return
+
+ rag = RAGPipeline()
+ if vector_search:
+ rag.query_vector_index()
+ if graph_search:
+        rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
+ huge_settings.graph_name
+ ).query_graphdb(
+ num_gremlin_generate_example=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ )
+ rag.merge_dedup_rerank(
+ graph_ratio,
+ rerank_method,
+ near_neighbor_first,
+ )
+    # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
+
+ try:
+        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+ if context.get("switch_to_bleu"):
+            gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
+ answer_synthesize = AnswerSynthesize(
+ raw_answer=raw_answer,
+ vector_only_answer=vector_only_answer,
+ graph_only_answer=graph_only_answer,
+ graph_vector_answer=graph_vector_answer,
+ prompt_template=answer_prompt,
+ )
+ async for context in answer_synthesize.run_streaming(context):
+ if context.get("switch_to_bleu"):
+                gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
+ yield (
+ context.get("raw_answer", ""),
+ context.get("vector_only_answer", ""),
+ context.get("graph_only_answer", ""),
+ context.get("graph_vector_answer", ""),
+ )
+ except ValueError as e:
+ log.critical(e)
+ raise gr.Error(str(e))
+ except Exception as e:
+ log.critical(e)
+ raise gr.Error(f"An unexpected error occurred: {str(e)}")
+
+
def create_rag_block():
# pylint: disable=R0915 (too-many-statements),C0301
gr.Markdown("""## 1. HugeGraph RAG Query""")
@@ -130,13 +216,17 @@ def create_rag_block():
# TODO: Only support inline formula now. Should support block formula
gr.Markdown("Basic LLM Answer", elem_classes="output-box-label")
-            raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+            raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                  latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Vector-only Answer", elem_classes="output-box-label")
-            vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+            vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                          latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Graph-only Answer", elem_classes="output-box-label")
-            graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+            graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                         latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label")
-            graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+            graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                           latex_delimiters=[{"left": "$", "right": "$", "display": False}])
answer_prompt_input = gr.Textbox(
value=prompt.answer_prompt, label="Query Prompt",
show_copy_button=True, lines=7
@@ -184,7 +274,7 @@ def create_rag_block():
btn = gr.Button("Answer Question", variant="primary")
btn.click( # pylint: disable=no-member
- fn=rag_answer,
+ fn=rag_answer_streaming,
inputs=[
inp,
raw_radio,
@@ -254,13 +344,13 @@ def create_rag_block():
is_vector_only_answer: bool,
is_graph_only_answer: bool,
is_graph_vector_answer: bool,
- graph_ratio: float,
- rerank_method: Literal["bleu", "reranker"],
- near_neighbor_first: bool,
- custom_related_information: str,
+ graph_ratio_ui: float,
+ rerank_method_ui: Literal["bleu", "reranker"],
+ near_neighbor_first_ui: bool,
+ custom_related_information_ui: str,
answer_prompt: str,
keywords_extract_prompt: str,
- answer_max_line_count: int = 1,
+ answer_max_line_count_ui: int = 1,
progress=gr.Progress(track_tqdm=True),
):
df = pd.read_excel(questions_path, dtype=str)
@@ -273,10 +363,10 @@ def create_rag_block():
is_vector_only_answer,
is_graph_only_answer,
is_graph_vector_answer,
- graph_ratio,
- rerank_method,
- near_neighbor_first,
- custom_related_information,
+ graph_ratio_ui,
+ rerank_method_ui,
+ near_neighbor_first_ui,
+ custom_related_information_ui,
answer_prompt,
keywords_extract_prompt,
)
@@ -285,9 +375,9 @@ def create_rag_block():
df.at[index, "Graph-only Answer"] = graph_only_answer
df.at[index, "Graph-Vector Answer"] = graph_vector_answer
progress((index + 1, total_rows))
-        answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx")
- df.to_excel(answers_path, index=False)
- return df.head(answer_max_line_count), answers_path
+        answers_path_ui = os.path.join(resource_path, "demo", "questions_answers.xlsx")
+ df.to_excel(answers_path_ui, index=False)
+ return df.head(answer_max_line_count_ui), answers_path_ui
with gr.Row():
with gr.Column():
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
index f2dd234..c6bfa44 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
@@ -16,7 +16,7 @@
# under the License.
from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict
class BaseLLM(ABC):
@@ -43,8 +43,17 @@ class BaseLLM(ABC):
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
- on_token_callback: Callable = None,
- ) -> List[Any]:
+ on_token_callback: Optional[Callable] = None,
+ ) -> Generator[str, None, None]:
+ """Comment"""
+
+ @abstractmethod
+ async def agenerate_streaming(
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
+ on_token_callback: Optional[Callable] = None,
+ ) -> AsyncGenerator[str, None]:
"""Comment"""
@abstractmethod
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
index 23a1250..ca5ae60 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Callable, List, Optional, Dict, Any
+from typing import Callable, List, Optional, Dict, Any, AsyncGenerator
import tiktoken
from litellm import completion, acompletion
@@ -137,6 +137,35 @@ class LiteLLMClient(BaseLLM):
log.error("Error in streaming LiteLLM call: %s", e)
return f"Error: {str(e)}"
+ async def agenerate_streaming(
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
+ on_token_callback: Optional[Callable] = None,
+ ) -> AsyncGenerator[str, None]:
+        """Generate a response to the query messages/prompt in async streaming mode."""
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+ try:
+ response = await acompletion(
+ model=self.model,
+ messages=messages,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ api_key=self.api_key,
+ base_url=self.api_base,
+ stream=True,
+ )
+ async for chunk in response:
+ if chunk.choices[0].delta.content:
+ if on_token_callback:
+ on_token_callback(chunk)
+ yield chunk.choices[0].delta.content
+ except (RateLimitError, BudgetExceededError, APIError) as e:
+ log.error("Error in async streaming LiteLLM call: %s", e)
+ yield f"Error: {str(e)}"
+
def num_tokens_from_string(self, string: str) -> int:
"""Get token count from string."""
try:
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index 62f5ef2..58f063b 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -17,7 +17,7 @@
import json
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict
import ollama
from retry import retry
@@ -89,22 +89,49 @@ class OllamaClient(BaseLLM):
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
- on_token_callback: Callable = None,
- ) -> List[Any]:
+ on_token_callback: Optional[Callable] = None,
+ ) -> Generator[str, None, None]:
"""Comment"""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
- stream = self.client.chat(
+
+ for chunk in self.client.chat(
model=self.model,
messages=messages,
stream=True
- )
- chunks = []
- for chunk in stream:
- on_token_callback(chunk["message"]["content"])
- chunks.append(chunk)
- return chunks
+ ):
+ token = chunk["message"]["content"]
+ if on_token_callback:
+ on_token_callback(token)
+ yield token
+
+ async def agenerate_streaming(
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
+ on_token_callback: Optional[Callable] = None,
+ ) -> AsyncGenerator[str, None]:
+ """Comment"""
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+ async_generator = await self.async_client.chat(
+ model=self.model,
+ messages=messages,
+ stream=True
+ )
+ async for chunk in async_generator:
+ token = chunk.get("message", {}).get("content", "")
+ if on_token_callback:
+ on_token_callback(token)
+ yield token
+ except Exception as e:
+ print(f"Retrying LLM call {e}")
+ raise e
+
def num_tokens_from_string(
self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index a020067..45f6d7a 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Callable, List, Optional, Dict, Any
+from typing import Callable, List, Optional, Dict, Any, Generator, AsyncGenerator
import openai
import tiktoken
@@ -90,9 +90,9 @@ class OpenAIClient(BaseLLM):
retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def agenerate(
- self,
- messages: Optional[List[Dict[str, Any]]] = None,
- prompt: Optional[str] = None,
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
) -> str:
"""Generate a response to the query messages/prompt."""
if messages is None:
@@ -119,31 +119,91 @@ class OpenAIClient(BaseLLM):
log.error("Retrying LLM call %s", e)
raise e
+ @retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
+ )
def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
- on_token_callback: Callable = None,
- ) -> str:
-        """Generate a response to the query messages/prompt in streaming mode."""
+ on_token_callback: Optional[Callable[[str], None]] = None,
+ ) -> Generator[str, None, None]:
+ """Generate a response to the query messages/prompt in streaming mode.
+
+ Yields:
+ Accumulated response string after each new token.
+ """
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
- completions = self.client.chat.completions.create(
- model=self.model,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- messages=messages,
- stream=True,
- )
- result = ""
- for message in completions:
- # Process the streamed messages or perform any other desired action
- delta = message["choices"][0]["delta"]
- if "content" in delta:
- result += delta["content"]
- on_token_callback(message)
- return result
+
+ try:
+ completions = self.client.chat.completions.create(
+ model=self.model,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ messages=messages,
+ stream=True,
+ )
+
+ for chunk in completions:
+ delta = chunk.choices[0].delta
+ if delta.content:
+ token = delta.content
+ if on_token_callback:
+ on_token_callback(token)
+ yield token
+
+ except openai.BadRequestError as e:
+ log.critical("Fatal: %s", e)
+ yield str(f"Error: {e}")
+ except openai.AuthenticationError:
+ log.critical("The provided API key is invalid")
+ yield "Error: The provided API key is invalid"
+ except Exception as e:
+ log.error("Error in streaming: %s", e)
+ raise e
+
+ async def agenerate_streaming(
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
+ on_token_callback: Optional[Callable] = None,
+ ) -> AsyncGenerator[str, None]:
+ """Comment"""
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+ completions = await self.aclient.chat.completions.create(
+ model=self.model,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ messages=messages,
+ stream=True
+ )
+ async for chunk in completions:
+ delta = chunk.choices[0].delta
+ if delta.content:
+ token = delta.content
+ if on_token_callback:
+ on_token_callback(token)
+ yield token
+            # TODO: log.info("Token usage: %s", completions.usage.model_dump_json())
+ # catch context length / do not retry
+ except openai.BadRequestError as e:
+ log.critical("Fatal: %s", e)
+ yield str(f"Error: {e}")
+ # catch authorization errors / do not retry
+ except openai.AuthenticationError:
+ log.critical("The provided OpenAI API key is invalid")
+ yield "Error: The provided OpenAI API key is invalid"
+ except Exception as e:
+ log.error("Retrying LLM call %s", e)
+ raise e
def num_tokens_from_string(self, string: str) -> int:
"""Get token count from string."""
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
index 967c391..cbca691 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Optional, List, Dict, Any, Callable
+from typing import AsyncGenerator, Generator, Optional, List, Dict, Any, Callable
import qianfan
from retry import retry
@@ -74,9 +74,38 @@ class QianfanClient(BaseLLM):
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
- on_token_callback: Callable = None,
- ) -> str:
- return self.generate(messages, prompt)
+ on_token_callback: Optional[Callable] = None,
+ ) -> Generator[str, None, None]:
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+
+        for msg in self.chat_comp.do(messages=messages, model=self.chat_model, stream=True):
+ token = msg.body['result']
+ if on_token_callback:
+ on_token_callback(token)
+ yield token
+
+ async def agenerate_streaming(
+ self,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ prompt: Optional[str] = None,
+ on_token_callback: Optional[Callable] = None,
+ ) -> AsyncGenerator[str, None]:
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+            async_generator = await self.chat_comp.ado(messages=messages, model=self.chat_model, stream=True)
+ async for msg in async_generator:
+ chunk = msg.body['result']
+ if on_token_callback:
+ on_token_callback(chunk)
+ yield chunk
+ except Exception as e:
+ print(f"Retrying LLM call {e}")
+ raise e
def num_tokens_from_string(self, string: str) -> int:
return len(string)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 666ecf9..5c4ab5f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -18,7 +18,7 @@
# pylint: disable=W0621
import asyncio
-from typing import Any, Dict, Optional
+from typing import Any, AsyncGenerator, Dict, Optional
from hugegraph_llm.config import prompt
from hugegraph_llm.models.llms.base import BaseLLM
@@ -35,17 +35,17 @@ DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt
class AnswerSynthesize:
def __init__(
- self,
- llm: Optional[BaseLLM] = None,
- prompt_template: Optional[str] = None,
- question: Optional[str] = None,
- context_body: Optional[str] = None,
- context_head: Optional[str] = None,
- context_tail: Optional[str] = None,
- raw_answer: bool = False,
- vector_only_answer: bool = True,
- graph_only_answer: bool = False,
- graph_vector_answer: bool = False,
+ self,
+ llm: Optional[BaseLLM] = None,
+ prompt_template: Optional[str] = None,
+ question: Optional[str] = None,
+ context_body: Optional[str] = None,
+ context_head: Optional[str] = None,
+ context_tail: Optional[str] = None,
+ raw_answer: bool = False,
+ vector_only_answer: bool = True,
+ graph_only_answer: bool = False,
+ graph_vector_answer: bool = False,
):
self._llm = llm
self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
@@ -59,15 +59,7 @@ class AnswerSynthesize:
self._graph_vector_answer = graph_vector_answer
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- if self._llm is None:
- self._llm = LLMs().get_chat_llm()
-
- if self._question is None:
- self._question = context.get("query") or None
- assert self._question is not None, "No question for synthesizing."
-
-        context_head_str = context.get("synthesize_context_head") or self._context_head or ""
-        context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""
+ context_head_str, context_tail_str = self.init_llm(context)
if self._context_body is not None:
context_str = (f"{context_head_str}\n"
@@ -78,6 +70,22 @@ class AnswerSynthesize:
response = self._llm.generate(prompt=final_prompt)
return {"answer": response}
+        graph_result_context, vector_result_context = self.handle_vector_graph(context)
+        context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
+                                                  vector_result_context, graph_result_context))
+ return context
+
+ def init_llm(self, context):
+ if self._llm is None:
+ self._llm = LLMs().get_chat_llm()
+ if self._question is None:
+ self._question = context.get("query") or None
+ assert self._question is not None, "No question for synthesizing."
+        context_head_str = context.get("synthesize_context_head") or self._context_head or ""
+        context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""
+ return context_head_str, context_tail_str
+
+ def handle_vector_graph(self, context):
vector_result = context.get("vector_result")
if vector_result:
vector_result_context = "Phrases related to the query:\n" + "\n".join(
@@ -85,7 +93,6 @@ class AnswerSynthesize:
)
else:
vector_result_context = "No (vector)phrase related to the query."
-
graph_result = context.get("graph_result")
if graph_result:
graph_context_head = context.get("graph_context_head", "Knowledge from graphdb for the query:\n")
@@ -95,10 +102,31 @@ class AnswerSynthesize:
else:
graph_result_context = "No related graph data found for current query."
log.warning(graph_result_context)
+ return graph_result_context, vector_result_context
-        context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
-                                                  vector_result_context, graph_result_context))
- return context
+    async def run_streaming(self, context: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
+ context_head_str, context_tail_str = self.init_llm(context)
+
+ if self._context_body is not None:
+ context_str = (f"{context_head_str}\n"
+ f"{self._context_body}\n"
+ f"{context_tail_str}".strip("\n"))
+
+            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
+ response = self._llm.generate(prompt=final_prompt)
+ yield {"answer": response}
+ return
+
+        graph_result_context, vector_result_context = self.handle_vector_graph(context)
+
+ async for context in self.async_streaming_generate(
+ context,
+ context_head_str,
+ context_tail_str,
+ vector_result_context,
+ graph_result_context
+ ):
+ yield context
async def async_generate(self, context: Dict[str, Any], context_head_str: str,
context_tail_str: str, vector_result_context: str,
@@ -151,3 +179,81 @@ class AnswerSynthesize:
ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer])
context['call_count'] = context.get('call_count', 0) + ops
return context
+
+    async def async_streaming_generate(self, context: Dict[str, Any], context_head_str: str,
+                                       context_tail_str: str, vector_result_context: str,
+                                       graph_result_context: str) -> AsyncGenerator[Dict[str, Any], None]:
+ # async_tasks stores the async tasks for different answer types
+ async_generators = []
+ auto_id = 0
+ if self._raw_answer:
+ final_prompt = self._question
+ async_generators.append(
+                self.__llm_generate_with_meta_info(task_id=auto_id, target_key="raw_answer", prompt=final_prompt)
+ )
+ auto_id += 1
+ if self._vector_only_answer:
+ context_str = (f"{context_head_str}\n"
+ f"{vector_result_context}\n"
+ f"{context_tail_str}".strip("\n"))
+
+            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
+ async_generators.append(
+ self.__llm_generate_with_meta_info(
+ task_id=auto_id,
+ target_key="vector_only_answer",
+ prompt=final_prompt
+ )
+ )
+ auto_id += 1
+ if self._graph_only_answer:
+ context_str = (f"{context_head_str}\n"
+ f"{graph_result_context}\n"
+ f"{context_tail_str}".strip("\n"))
+
+            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
+ async_generators.append(
+                self.__llm_generate_with_meta_info(task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt)
+ )
+ auto_id += 1
+ if self._graph_vector_answer:
+            context_body_str = f"{vector_result_context}\n{graph_result_context}"
+ if context.get("graph_ratio", 0.5) < 0.5:
+                context_body_str = f"{graph_result_context}\n{vector_result_context}"
+ context_str = (f"{context_head_str}\n"
+ f"{context_body_str}\n"
+ f"{context_tail_str}".strip("\n"))
+
+            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
+ async_generators.append(
+ self.__llm_generate_with_meta_info(
+ task_id=auto_id,
+ target_key="graph_vector_answer",
+ prompt=final_prompt
+ )
+ )
+ auto_id += 1
+
+        ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer])
+ context['call_count'] = context.get('call_count', 0) + ops
+
+        async_tasks = [asyncio.create_task(anext(gen)) for gen in async_generators]
+ while True:
+            done, _ = await asyncio.wait(async_tasks, return_when=asyncio.FIRST_COMPLETED)
+ stop_task_num = 0
+ for task in done:
+ try:
+ task_id, target_key, token = task.result()
+ context[target_key] = context.get(target_key, "") + token
+ gen = async_generators[task_id]
+ async_tasks[task_id] = asyncio.create_task(anext(gen))
+ except StopAsyncIteration:
+ stop_task_num += 1
+ if stop_task_num == len(async_tasks):
+ break
+ yield context
+
+    async def __llm_generate_with_meta_info(self, task_id: int, target_key: str, prompt: str):
+        # FIXME: Expected type 'AsyncIterable', got 'Coroutine[Any, Any, AsyncGenerator[str, None]]' instead
+ async for token in self._llm.agenerate_streaming(prompt=prompt):
+ yield task_id, target_key, token
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 9694647..219a358 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -92,10 +92,40 @@ class GremlinGenerateSynthesize:
return context
+ def sync_generate(self, context: Dict[str, Any]):
+ query = context.get("query")
+        raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}]
+ raw_prompt = self.gremlin_prompt.format(
+ query=query,
+ schema=self.schema,
+ example=self._format_examples(examples=raw_example),
+ vertices=self._format_vertices(vertices=self.vertices)
+ )
+ raw_response = self.llm.generate(prompt=raw_prompt)
+
+ examples = context.get("match_result")
+ init_prompt = self.gremlin_prompt.format(
+ query=query,
+ schema=self.schema,
+ example=self._format_examples(examples=examples),
+ vertices=self._format_vertices(vertices=self.vertices)
+ )
+ initialized_response = self.llm.generate(prompt=init_prompt)
+
+        log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response)
+
+        context["result"] = self._extract_gremlin(response=initialized_response)
+ context["raw_result"] = self._extract_gremlin(response=raw_response)
+ context["call_count"] = context.get("call_count", 0) + 2
+
+ return context
+
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
query = context.get("query", "")
if not query:
raise ValueError("query is required")
- context = asyncio.run(self.async_generate(context))
+ # TODO: Update to async_generate again
+        # The best method may be changing all `operator.run(*arg)` to be async function
+ context = self.sync_generate(context)
return context