github-advanced-security[bot] commented on code in PR #179:
URL:
https://github.com/apache/incubator-hugegraph-ai/pull/179#discussion_r1970897554
##########
hugegraph-llm/src/hugegraph_llm/api/rag_api.py:
##########
@@ -33,76 +36,243 @@
def rag_http_api(
- router: APIRouter,
- rag_answer_func,
- graph_rag_recall_func,
- apply_graph_conf,
- apply_llm_conf,
- apply_embedding_conf,
- apply_reranker_conf,
+ router: APIRouter,
+ rag_answer_func,
+ graph_rag_recall_func,
+ apply_graph_conf,
+ apply_llm_conf,
+ apply_embedding_conf,
+ apply_reranker_conf,
+ rag_answer_stream_func=None,
+ graph_rag_recall_stream_func=None,
):
- @router.post("/rag", status_code=status.HTTP_200_OK)
- def rag_answer_api(req: RAGRequest):
- result = rag_answer_func(
- text=req.query,
- raw_answer=req.raw_answer,
- vector_only_answer=req.vector_only,
- graph_only_answer=req.graph_only,
- graph_vector_answer=req.graph_vector_answer,
- graph_ratio=req.graph_ratio,
- rerank_method=req.rerank_method,
- near_neighbor_first=req.near_neighbor_first,
- custom_related_information=req.custom_priority_info,
- answer_prompt=req.answer_prompt or prompt.answer_prompt,
- keywords_extract_prompt=req.keywords_extract_prompt or
prompt.keywords_extract_prompt,
- gremlin_tmpl_num=req.gremlin_tmpl_num,
- gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
- )
- # TODO: we need more info in the response for users to understand the
query logic
- return {
- "query": req.query,
- **{
- key: value
- for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
- if getattr(req, key)
- },
- }
+ async def stream_rag_answer(
+ text,
+ raw_answer,
+ vector_only_answer,
+ graph_only_answer,
+ graph_vector_answer,
+ graph_ratio,
+ rerank_method,
+ near_neighbor_first,
+ custom_related_information,
+ answer_prompt,
+ keywords_extract_prompt,
+ gremlin_tmpl_num,
+ gremlin_prompt,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Stream the RAG answer results
+ """
+ if rag_answer_stream_func:
+ # If a streaming-specific function exists, use it
+ async for chunk in rag_answer_stream_func(
+ text=text,
+ raw_answer=raw_answer,
+ vector_only_answer=vector_only_answer,
+ graph_only_answer=graph_only_answer,
+ graph_vector_answer=graph_vector_answer,
+ graph_ratio=graph_ratio,
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ answer_prompt=answer_prompt,
+ keywords_extract_prompt=keywords_extract_prompt,
+ gremlin_tmpl_num=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ ):
+ yield f"data: {json.dumps({'chunk': chunk})}\n\n"
+ else:
+ # Otherwise, use the normal function but adapt it for streaming
+ # by sending the entire result at once
+ result = rag_answer_func(
+ text=text,
+ raw_answer=raw_answer,
+ vector_only_answer=vector_only_answer,
+ graph_only_answer=graph_only_answer,
+ graph_vector_answer=graph_vector_answer,
+ graph_ratio=graph_ratio,
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ answer_prompt=answer_prompt,
+ keywords_extract_prompt=keywords_extract_prompt,
+ gremlin_tmpl_num=gremlin_tmpl_num,
+ gremlin_prompt=gremlin_prompt,
+ )
- @router.post("/rag/graph", status_code=status.HTTP_200_OK)
- def graph_rag_recall_api(req: GraphRAGRequest):
- try:
- result = graph_rag_recall_func(
- query=req.query,
- gremlin_tmpl_num=req.gremlin_tmpl_num,
+ response_data = {
+ "query": text,
+ **{
+ key: value
+ for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
+ if eval(key) # Convert string to boolean
+ },
+ }
+
+ yield f"data: {json.dumps(response_data)}\n\n"
+ # Signal end of stream
+ yield "data: [DONE]\n\n"
+
+ async def stream_graph_rag_recall(
+ query,
+ gremlin_tmpl_num,
+ rerank_method,
+ near_neighbor_first,
+ custom_related_information,
+ gremlin_prompt,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Stream the graph RAG recall results
+ """
+ if graph_rag_recall_stream_func:
+ # If a streaming-specific function exists, use it
+ async for chunk in graph_rag_recall_stream_func(
+ query=query,
+ gremlin_tmpl_num=gremlin_tmpl_num,
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ gremlin_prompt=gremlin_prompt,
+ ):
+ yield f"data: {json.dumps({'chunk': chunk})}\n\n"
+ else:
+ # Otherwise, use the normal function but adapt it for streaming
+ try:
+ result = graph_rag_recall_func(
+ query=query,
+ gremlin_tmpl_num=gremlin_tmpl_num,
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ gremlin_prompt=gremlin_prompt,
+ )
+
+ if isinstance(result, dict):
+ params = [
+ "query",
+ "keywords",
+ "match_vids",
+ "graph_result_flag",
+ "gremlin",
+ "graph_result",
+ "vertex_degree_list",
+ ]
+ user_result = {key: result[key] for key in params if key
in result}
+ yield f"data: {json.dumps({'graph_recall':
user_result})}\n\n"
+ else:
+ # Note: Maybe only for qianfan/wenxin
+ yield f"data: {json.dumps({'graph_recall':
json.dumps(result)})}\n\n"
+
+ # Signal end of stream
+ yield "data: [DONE]\n\n"
+
+ except TypeError as e:
+ log.error("TypeError in stream_graph_rag_recall: %s", e)
+ yield f"data: {json.dumps({'error': str(e), 'status':
400})}\n\n"
+ except Exception as e:
+ log.error("Unexpected error occurred: %s", e)
+ yield f"data: {json.dumps({'error': 'An unexpected error
occurred.', 'status': 500})}\n\n"
+
+ @router.post("/rag", status_code=status.HTTP_200_OK)
+ async def rag_answer_api(req: RAGRequest):
+ if req.stream:
+ # Return a streaming response
+ return StreamingResponse(
+ stream_rag_answer(
+ text=req.query,
+ raw_answer=req.raw_answer,
+ vector_only_answer=req.vector_only,
+ graph_only_answer=req.graph_only,
+ graph_vector_answer=req.graph_vector_answer,
+ graph_ratio=req.graph_ratio,
+ rerank_method=req.rerank_method,
+ near_neighbor_first=req.near_neighbor_first,
+ custom_related_information=req.custom_priority_info,
+ answer_prompt=req.answer_prompt or prompt.answer_prompt,
+ keywords_extract_prompt=req.keywords_extract_prompt or
prompt.keywords_extract_prompt,
+ gremlin_tmpl_num=req.gremlin_tmpl_num,
+ gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
+ ),
+ media_type="text/event-stream",
+ )
+ else:
+ # Synchronous response (original behavior)
+ result = rag_answer_func(
+ text=req.query,
+ raw_answer=req.raw_answer,
+ vector_only_answer=req.vector_only,
+ graph_only_answer=req.graph_only,
+ graph_vector_answer=req.graph_vector_answer,
+ graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
+ answer_prompt=req.answer_prompt or prompt.answer_prompt,
+ keywords_extract_prompt=req.keywords_extract_prompt or
prompt.keywords_extract_prompt,
+ gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
)
+ # TODO: we need more info in the response for users to understand
the query logic
+ return {
+ "query": req.query,
+ **{
+ key: value
+ for key, value in zip(["raw_answer", "vector_only",
"graph_only", "graph_vector_answer"], result)
+ if getattr(req, key)
+ },
+ }
- if isinstance(result, dict):
- params = [
- "query",
- "keywords",
- "match_vids",
- "graph_result_flag",
- "gremlin",
- "graph_result",
- "vertex_degree_list",
- ]
- user_result = {key: result[key] for key in params if key in
result}
- return {"graph_recall": user_result}
- # Note: Maybe only for qianfan/wenxin
- return {"graph_recall": json.dumps(result)}
-
- except TypeError as e:
- log.error("TypeError in graph_rag_recall_api: %s", e)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)) from e
- except Exception as e:
- log.error("Unexpected error occurred: %s", e)
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An
unexpected error occurred."
- ) from e
+ @router.post("/rag/graph", status_code=status.HTTP_200_OK)
+ async def graph_rag_recall_api(req: GraphRAGRequest):
+ if req.stream:
+ # Return a streaming response
+ return StreamingResponse(
+ stream_graph_rag_recall(
+ query=req.query,
+ gremlin_tmpl_num=req.gremlin_tmpl_num,
+ rerank_method=req.rerank_method,
+ near_neighbor_first=req.near_neighbor_first,
+ custom_related_information=req.custom_priority_info,
+ gremlin_prompt=req.gremlin_prompt or
prompt.gremlin_generate_prompt,
+ ),
Review Comment:
## Information exposure through an exception
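[Stack trace information](1) flows to this location and may be exposed to an
external user. In the diff above, the source is the `except TypeError` handler
in `stream_graph_rag_recall`, which puts `str(e)` into the streamed `error`
payload that this `StreamingResponse` returns to the client.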
[Show more details](https://github.com/apache/incubator-hugegraph-ai/security/code-scanning/25)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]