This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 71fe026e feat(llm): add text2gremlin api (#258)
71fe026e is described below
commit 71fe026e9bee0f62da70f37164bc68277ccd0562
Author: SoJGooo <[email protected]>
AuthorDate: Mon Aug 11 14:24:17 2025 +0800
feat(llm): add text2gremlin api (#258)
Co-authored-by: jinsong04 <[email protected]>
Co-authored-by: Seanium <[email protected]>
Co-authored-by: imbajin <[email protected]>
---
.../src/hugegraph_llm/api/models/rag_requests.py | 45 ++++-
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 30 ++-
.../src/hugegraph_llm/config/prompt_config.py | 9 +
.../src/hugegraph_llm/demo/rag_demo/app.py | 2 +
.../demo/rag_demo/text2gremlin_block.py | 218 +++++++++++++++++----
5 files changed, 260 insertions(+), 44 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 89bdd9bc..cf227e8b 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Literal
-
+from typing import Optional, Literal, List
+from enum import Enum
from fastapi import Query
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
from hugegraph_llm.config import prompt
class GraphConfigRequest(BaseModel):
url: str = Query('127.0.0.1:8080', description="hugegraph client url.")
- name: str = Query('hugegraph', description="hugegraph client name.")
+ graph: str = Query('hugegraph', description="hugegraph client name.")
user: str = Query('', description="hugegraph client user.")
pwd: str = Query('', description="hugegraph client pwd.")
gs: str = None
@@ -114,3 +114,40 @@ class RerankerConfigRequest(BaseModel):
class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
log_file: Optional[str] = "llm-server.log"
+
+class GremlinOutputType(str, Enum):
+ MATCH_RESULT = "match_result"
+ TEMPLATE_GREMLIN = "template_gremlin"
+ RAW_GREMLIN = "raw_gremlin"
+ TEMPLATE_EXECUTION_RESULT = "template_execution_result"
+ RAW_EXECUTION_RESULT = "raw_execution_result"
+
+class GremlinGenerateRequest(BaseModel):
+ query: str
+ example_num: Optional[int] = Query(
+ 0,
+        description="Number of Gremlin templates to use. (0 means no templates)"
+ )
+ gremlin_prompt: Optional[str] = Query(
+ prompt.gremlin_generate_prompt,
+ description="Prompt for the Text2Gremlin query.",
+ )
+ client_config: Optional[GraphConfigRequest] = Query(None,
description="hugegraph server config.")
+ output_types: Optional[List[GremlinOutputType]] = Query(
+ default=[GremlinOutputType.TEMPLATE_GREMLIN],
+ description="""
+ a list can contain "match_result","template_gremlin",
+ "raw_gremlin","template_execution_result","raw_execution_result"
+        You can specify which types of result you need. Empty means all
types.
+ """
+ )
+
+ @field_validator('gremlin_prompt')
+ @classmethod
+ def validate_prompt_placeholders(cls, v):
+ if v is not None:
+ required_placeholders = ['{query}', '{schema}', '{example}',
'{vertices}']
+ missing = [p for p in required_placeholders if p not in v]
+ if missing:
+ raise ValueError(f"Prompt template is missing required
placeholders: {', '.join(missing)}")
+ return v
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index b7e8a6f7..c39c7771 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -26,6 +26,7 @@ from hugegraph_llm.api.models.rag_requests import (
LLMConfigRequest,
RerankerConfigRequest,
GraphRAGRequest,
+ GremlinGenerateRequest,
)
from hugegraph_llm.config import huge_settings
from hugegraph_llm.api.models.rag_response import RAGResponse
@@ -41,6 +42,7 @@ def rag_http_api(
apply_llm_conf,
apply_embedding_conf,
apply_reranker_conf,
+ gremlin_generate_selective_func,
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
@@ -79,7 +81,7 @@ def rag_http_api(
def set_graph_config(req):
if req.client_config:
huge_settings.graph_url = req.client_config.url
- huge_settings.graph_name = req.client_config.name
+ huge_settings.graph_name = req.client_config.graph
huge_settings.graph_user = req.client_config.user
huge_settings.graph_pwd = req.client_config.pwd
huge_settings.graph_space = req.client_config.gs
@@ -173,3 +175,29 @@ def rag_http_api(
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing
Value"))
+
+ @router.post("/text2gremlin", status_code=status.HTTP_200_OK)
+ def text2gremlin_api(req: GremlinGenerateRequest):
+ try:
+ set_graph_config(req)
+
+ output_types_str_list = None
+ if req.output_types:
+ output_types_str_list = [ot.value for ot in req.output_types]
+
+ response_dict = gremlin_generate_selective_func(
+ inp=req.query,
+ example_num=req.example_num,
+ schema_input=huge_settings.graph_name,
+ gremlin_prompt_input=req.gremlin_prompt,
+ requested_outputs=output_types_str_list,
+ )
+ return response_dict
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ log.error("Error in text2gremlin_api: %s", e)
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="An unexpected error occurred during Gremlin
generation.",
+ ) from e
diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
index cc92d2fe..e5e1c926 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
@@ -220,6 +220,8 @@ Assess the user's query to determine its complexity based
on the following crite
- You may use the vertex ID directly if it’s provided in the context.
- If the provided question contains entity names that are very similar to the
Vertices IDs, then in the generated Gremlin statement, replace the approximate
entities from the original question.
For example, if the question includes the name ABC, and the provided
VerticesIDs do not contain ABC but only abC, then use abC instead of ABC from
the original question when generating the gremlin.
+- Similarly, if the user's query refers to specific property names or their
values, and these are present or align with the 'Referenced Extracted
Properties', actively utilize these properties in your Gremlin query.
+For instance, you can use them for filtering vertices or edges (e.g., using
`has('propertyName', 'propertyValue')`), or for projecting specific values.
The output format must be as follows:
```gremlin
@@ -233,6 +235,9 @@ Refer Gremlin Example Pair:
Referenced Extracted Vertex IDs Related to the Query:
{vertices}
+Referenced Extracted Properties Related to the Query (Format:
[('property_name', 'property_value'), ...]):
+{properties}
+
Generate Gremlin from the Following User Query:
{query}
@@ -334,6 +339,7 @@ and experiences.
- 如果在上下文中提供了顶点 ID,可以直接使用。
- 如果提供的问题包含与顶点 ID 非常相似的实体名称,则在生成的 Gremlin 语句中替换原始问题中的近似实体。
例如,如果问题包含名称 ABC,而提供的顶点 ID 不包含 ABC 而只有 abC,则在生成 gremlin 时使用 abC 而不是原始问题中的 ABC。
+- 同样地,如果用户查询中提及特定的属性名称或属性值,并且这些属性在"查询相关的已提取属性"中存在或匹配,请在生成的 Gremlin
查询中充分利用这些属性信息。比如可以用它们进行顶点或边的过滤(如使用 `has('属性名', '属性值')`),或者用于特定值的投影查询。
输出格式必须如下:
```gremlin
@@ -347,6 +353,9 @@ g.V().limit(10)
与查询相关的已提取顶点 ID:
{vertices}
+查询相关的已提取属性(格式:[('属性名', '属性值'), ...]):
+{properties}
+
从以下用户查询生成 Gremlin:
{query}
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index ec8adeab..2f9c3b34 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -39,6 +39,7 @@ from hugegraph_llm.demo.rag_demo.rag_block import
create_rag_block, rag_answer
from hugegraph_llm.demo.rag_demo.text2gremlin_block import (
create_text2gremlin_block,
graph_rag_recall,
+ gremlin_generate_selective,
)
from hugegraph_llm.demo.rag_demo.vector_graph_block import
create_vector_graph_block
from hugegraph_llm.resources.demo.css import CSS
@@ -179,6 +180,7 @@ def create_app():
apply_llm_config,
apply_embedding_config,
apply_reranker_config,
+ gremlin_generate_selective,
)
admin_http_api(api_auth, log_stream)
diff --git
a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index e051eef3..7d682403 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -18,7 +18,8 @@
import json
import os
from datetime import datetime
-from typing import Any, Tuple, Dict, Union, Literal
+from dataclasses import dataclass
+from typing import Any, Tuple, Dict, Literal, Optional, List
import gradio as gr
import pandas as pd
@@ -34,6 +35,36 @@ from hugegraph_llm.utils.hugegraph_utils import
run_gremlin_query
from hugegraph_llm.utils.log import log
+@dataclass
+class GremlinResult:
+ """Standardized result class for gremlin_generate function"""
+ success: bool
+ match_result: str
+ template_gremlin: Optional[str] = None
+ raw_gremlin: Optional[str] = None
+ template_exec_result: Optional[str] = None
+ raw_exec_result: Optional[str] = None
+ error_message: Optional[str] = None
+
+ @classmethod
+ def error(cls, message: str) -> 'GremlinResult':
+ """Create an error result"""
+ return cls(success=False, match_result=message, error_message=message)
+
+ @classmethod
+ def success_result(cls, match_result: str, template_gremlin: str,
+ raw_gremlin: str, template_exec: str, raw_exec: str) ->
'GremlinResult':
+ """Create a successful result"""
+ return cls(
+ success=True,
+ match_result=match_result,
+ template_gremlin=template_gremlin,
+ raw_gremlin=raw_gremlin,
+ template_exec_result=template_exec,
+ raw_exec_result=raw_exec
+ )
+
+
def store_schema(schema, question, gremlin_prompt):
if (
prompt.text2gql_graph_schema != schema
@@ -83,52 +114,95 @@ def build_example_vector_index(temp_file) -> dict:
return builder.example_index_build(examples).run()
+def _process_schema(schema, generator, sm):
+ """Process and validate schema input"""
+ short_schema = False
+ if not schema:
+ return None, short_schema
+
+ schema = schema.strip()
+ if not schema.startswith("{"):
+ short_schema = True
+ log.info("Try to get schema from graph '%s'", schema)
+ generator.import_schema(from_hugegraph=schema)
+ schema = sm.schema.getSchema()
+ else:
+ try:
+ schema = json.loads(schema)
+ generator.import_schema(from_user_defined=schema)
+ except json.JSONDecodeError as e:
+ log.error("Invalid JSON schema provided: %s", e)
+ return None, None # Error case
+ return schema, short_schema
+
+
+def _configure_output_types(requested_outputs):
+ """Configure which outputs are requested"""
+ output_types = {
+ "match_result": True,
+ "template_gremlin": True,
+ "raw_gremlin": True,
+ "template_execution_result": True,
+ "raw_execution_result": True
+ }
+ if requested_outputs:
+ for key in output_types:
+ output_types[key] = False
+ for key in requested_outputs:
+ if key in output_types:
+ output_types[key] = True
+ return output_types
+
+
+def _execute_queries(context, output_types):
+ """Execute gremlin queries based on output requirements"""
+ if output_types["template_execution_result"]:
+ try:
+ context["template_exec_res"] =
run_gremlin_query(query=context["result"])
+ except Exception as e: # pylint: disable=broad-except
+ context["template_exec_res"] = f"{e}"
+ else:
+ context["template_exec_res"] = ""
+
+ if output_types["raw_execution_result"]:
+ try:
+ context["raw_exec_res"] =
run_gremlin_query(query=context["raw_result"])
+ except Exception as e: # pylint: disable=broad-except
+ context["raw_exec_res"] = f"{e}"
+ else:
+ context["raw_exec_res"] = ""
+
+
def gremlin_generate(
- inp, example_num, schema, gremlin_prompt
-) -> Union[tuple[str, str], tuple[str, Any, Any, Any, Any]]:
+ inp, example_num, schema, gremlin_prompt, requested_outputs:
Optional[List[str]] = None
+) -> GremlinResult:
generator = GremlinGenerator(llm=LLMs().get_text2gql_llm(),
embedding=Embeddings().get_embedding())
sm = SchemaManager(graph_name=schema)
- short_schema = False
- if schema:
- schema = schema.strip()
- if not schema.startswith("{"):
- short_schema = True
- log.info("Try to get schema from graph '%s'", schema)
- generator.import_schema(from_hugegraph=schema)
- # FIXME: update the logic here
- schema = sm.schema.getSchema()
- else:
- try:
- schema = json.loads(schema)
- generator.import_schema(from_user_defined=schema)
- except json.JSONDecodeError as e:
- log.error("Invalid JSON schema provided: %s", e)
- return "Invalid JSON schema, please check the format
carefully.", ""
- # FIXME: schema is not used in gremlin_generate() step, no context for it
(enhance the logic here)
- updated_schema = sm.simple_schema(schema) if short_schema else schema
+ processed_schema, short_schema = _process_schema(schema, generator, sm)
+ if processed_schema is None and short_schema is None:
+ return GremlinResult.error("Invalid JSON schema, please check the
format carefully.")
+
+ updated_schema = sm.simple_schema(processed_schema) if short_schema else
processed_schema
store_schema(str(updated_schema), inp, gremlin_prompt)
+
+ output_types = _configure_output_types(requested_outputs)
+
context = (
generator.example_index_query(example_num)
.gremlin_generate_synthesize(updated_schema, gremlin_prompt)
.run(query=inp)
)
- try:
- context["template_exec_res"] =
run_gremlin_query(query=context["result"])
- except Exception as e: # pylint: disable=broad-except
- context["template_exec_res"] = f"{e}"
- try:
- context["raw_exec_res"] =
run_gremlin_query(query=context["raw_result"])
- except Exception as e: # pylint: disable=broad-except
- context["raw_exec_res"] = f"{e}"
+
+ _execute_queries(context, output_types)
match_result = json.dumps(context.get("match_result", "No Results"),
ensure_ascii=False, indent=2)
- return (
- match_result,
- context["result"],
- context["raw_result"],
- context["template_exec_res"],
- context["raw_exec_res"],
+ return GremlinResult.success_result(
+ match_result=match_result,
+ template_gremlin=context["result"],
+ raw_gremlin=context["raw_result"],
+ template_exec=context["template_exec_res"],
+ raw_exec=context["raw_exec_res"],
)
@@ -152,12 +226,28 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str,
Any]:
return mini_schema
+def gremlin_generate_for_ui(inp, example_num, schema, gremlin_prompt):
+ """UI wrapper for gremlin_generate that returns tuple for Gradio
compatibility"""
+ result = gremlin_generate(inp, example_num, schema, gremlin_prompt)
+
+ if not result.success:
+ return result.match_result, "", "", "", ""
+
+ return (
+ result.match_result,
+ result.template_gremlin or "",
+ result.raw_gremlin or "",
+ result.template_exec_result or "",
+ result.raw_exec_result or ""
+ )
+
+
def create_text2gremlin_block() -> Tuple:
gr.Markdown(
"""## Build Vector Template Index (Optional)
- > Uploaded CSV file should be in `query,gremlin` format below:
- > e.g. `who is peter?`,`g.V().has('name', 'peter')`
- > JSON file should be in format below:
+ > Uploaded CSV file should be in `query,gremlin` format below:
+ > e.g. `who is peter?`,`g.V().has('name', 'peter')`
+ > JSON file should be in format below:
> e.g. `[{"query":"who is peter", "gremlin":"g.V().has('name', 'peter')"}]`
"""
)
@@ -192,7 +282,7 @@ def create_text2gremlin_block() -> Tuple:
)
btn = gr.Button("Text2Gremlin", variant="primary")
btn.click( # pylint: disable=no-member
- fn=gremlin_generate,
+ fn=gremlin_generate_for_ui,
inputs=[input_box, example_num_slider, schema_box, prompt_box],
outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out],
)
@@ -233,3 +323,53 @@ def graph_rag_recall(
)
context = rag.run(verbose=True, query=query, graph_search=True)
return context
+
+def gremlin_generate_selective(
+ inp: str,
+ example_num: int,
+ schema_input: str,
+ gremlin_prompt_input: str,
+ requested_outputs: Optional[List[str]] = None,
+) -> Dict[str, Any]:
+ """
+ Wraps the gremlin_generate function to return a dictionary of outputs
+ based on the requested_outputs list of strings.
+ """
+ output_keys = [
+ "match_result",
+ "template_gremlin",
+ "raw_gremlin",
+ "template_execution_result",
+ "raw_execution_result",
+ ]
+ if not requested_outputs: # None or empty list
+ requested_outputs = output_keys
+
+ result = gremlin_generate(
+ inp, example_num, schema_input, gremlin_prompt_input, requested_outputs
+ )
+
+ outputs_dict: Dict[str, Any] = {}
+
+ if not result.success:
+ # Handle error case
+ if "match_result" in requested_outputs:
+ outputs_dict["match_result"] = result.match_result
+ if result.error_message:
+ outputs_dict["error_detail"] = result.error_message
+ return outputs_dict
+
+ # Handle successful case
+ output_mapping = {
+ "match_result": result.match_result,
+ "template_gremlin": result.template_gremlin,
+ "raw_gremlin": result.raw_gremlin,
+ "template_execution_result": result.template_exec_result,
+ "raw_execution_result": result.raw_exec_result,
+ }
+
+ for key in requested_outputs:
+ if key in output_mapping:
+ outputs_dict[key] = output_mapping[key]
+
+ return outputs_dict