This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new c26f5bf refactor(llm): enhance the regex extraction func (#194)
c26f5bf is described below
commit c26f5bfb44bf6e349adadf071c0497f5d5d6ea95
Author: HaoJin Yang <[email protected]>
AuthorDate: Fri Mar 7 14:01:40 2025 +0800
refactor(llm): enhance the regex extraction func (#194)
---
.../operators/llm_op/gremlin_generate.py | 34 +++++++++++-----------
1 file changed, 17 insertions(+), 17 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 219a358..09e01e5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -32,7 +32,7 @@ class GremlinGenerateSynthesize:
llm: BaseLLM = None,
schema: Optional[Union[dict, str]] = None,
vertices: Optional[List[str]] = None,
- gremlin_prompt: Optional[str] = None
+ gremlin_prompt: Optional[str] = None,
) -> None:
self.llm = llm or LLMs().get_text2gql_llm()
if isinstance(schema, dict):
@@ -41,10 +41,10 @@ class GremlinGenerateSynthesize:
self.vertices = vertices
self.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
- def _extract_gremlin(self, response: str) -> str:
- match = re.search("```gremlin.*```", response, re.DOTALL)
- assert match is not None, f"No gremlin found in response: {response}"
- return match.group()[len("```gremlin"):-len("```")].strip()
+ def _extract_response(self, response: str, label: str = "gremlin") -> str:
+ match = re.search(f"```{label}(.*?)```", response, re.DOTALL)
+ assert match is not None, f"No {label} found in response: {response}"
+ return match.group(1).strip()
def _format_examples(self, examples: Optional[List[Dict[str, str]]]) -> Optional[str]:
if not examples:
@@ -52,8 +52,8 @@ class GremlinGenerateSynthesize:
example_strings = []
for example in examples:
example_strings.append(
- f"- query: {example['query']}\n"
- f"- gremlin:\n```gremlin\n{example['gremlin']}\n```")
+ f"- query: {example['query']}\n" f"- gremlin:\n```gremlin\n{example['gremlin']}\n```"
+ )
return "\n\n".join(example_strings)
def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]:
@@ -64,12 +64,12 @@ class GremlinGenerateSynthesize:
async def async_generate(self, context: Dict[str, Any]):
async_tasks = {}
query = context.get("query")
- raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}]
+ raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 'peter')"}]
raw_prompt = self.gremlin_prompt.format(
query=query,
schema=self.schema,
example=self._format_examples(examples=raw_example),
- vertices=self._format_vertices(vertices=self.vertices)
+ vertices=self._format_vertices(vertices=self.vertices),
)
async_tasks["raw_answer"] = asyncio.create_task(self.llm.agenerate(prompt=raw_prompt))
@@ -78,7 +78,7 @@ class GremlinGenerateSynthesize:
query=query,
schema=self.schema,
example=self._format_examples(examples=examples),
- vertices=self._format_vertices(vertices=self.vertices)
+ vertices=self._format_vertices(vertices=self.vertices),
)
async_tasks["initialized_answer"] = asyncio.create_task(self.llm.agenerate(prompt=init_prompt))
@@ -86,20 +86,20 @@ class GremlinGenerateSynthesize:
initialized_response = await async_tasks["initialized_answer"]
log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response)
- context["result"] = self._extract_gremlin(response=initialized_response)
- context["raw_result"] = self._extract_gremlin(response=raw_response)
+ context["result"] = self._extract_response(response=initialized_response)
+ context["raw_result"] = self._extract_response(response=raw_response)
context["call_count"] = context.get("call_count", 0) + 2
return context
def sync_generate(self, context: Dict[str, Any]):
query = context.get("query")
- raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}]
+ raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 'peter')"}]
raw_prompt = self.gremlin_prompt.format(
query=query,
schema=self.schema,
example=self._format_examples(examples=raw_example),
- vertices=self._format_vertices(vertices=self.vertices)
+ vertices=self._format_vertices(vertices=self.vertices),
)
raw_response = self.llm.generate(prompt=raw_prompt)
@@ -108,14 +108,14 @@ class GremlinGenerateSynthesize:
query=query,
schema=self.schema,
example=self._format_examples(examples=examples),
- vertices=self._format_vertices(vertices=self.vertices)
+ vertices=self._format_vertices(vertices=self.vertices),
)
initialized_response = self.llm.generate(prompt=init_prompt)
log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response)
- context["result"] = self._extract_gremlin(response=initialized_response)
- context["raw_result"] = self._extract_gremlin(response=raw_response)
+ context["result"] = self._extract_response(response=initialized_response)
+ context["raw_result"] = self._extract_response(response=raw_response)
context["call_count"] = context.get("call_count", 0) + 2
return context