This is an automated email from the ASF dual-hosted git repository.
ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 31b1720 feat: initialize rag based on HugeGraph (#20)
31b1720 is described below
commit 31b17206bfa9aff451fcfc0ce5954202c010afc0
Author: SweeeetYogurt <[email protected]>
AuthorDate: Thu Oct 26 09:50:31 2023 +0800
feat: initialize rag based on HugeGraph (#20)
* feat: init rag based on HugeGraph
* chore: format files
* rebase
* format code
* fix
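
The change adds a chainable GraphRAG pipeline: keyword extraction, graph
query, and answer synthesis. A minimal usage sketch, mirroring the example
file added below (connection values are placeholders for a local HugeGraph
server):

    from hugegraph_llm.operators.graph_rag_operator import GraphRAG

    # The keyword arguments to run() become the shared context dict that
    # each operator in the chain reads from and writes to.
    result = GraphRAG() \
        .extract_keyword() \
        .query_graph_for_rag() \
        .synthesize_answer() \
        .run(ip="localhost", port=18080, graph="hugegraph",
             user="admin", pwd="admin",
             query="Tell me about Al Pacino.")
    print(result["answer"])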
---
hugegraph-llm/examples/graph_rag_test.py | 114 ++++++++++
hugegraph-llm/requirements.txt | 1 +
hugegraph-llm/src/hugegraph_llm/llms/base.py | 6 +-
hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py | 22 +-
.../src/hugegraph_llm/operators/__init__.py | 55 ++---
.../hugegraph_llm/operators/build_kg_operator.py | 2 +-
.../hugegraph_llm/operators/graph_rag_operator.py | 89 ++++++++
.../operators/hugegraph_op/graph_rag_query.py | 253 +++++++++++++++++++++
.../operators/llm_op/answer_synthesize.py | 98 ++++++++
.../operators/llm_op/disambiguate_data.py | 3 +-
.../operators/llm_op/keyword_extract.py | 148 ++++++++++++
.../operators/llm_op/parse_text_to_data.py | 22 +-
.../hugegraph_llm/operators/utils_op/__init__.py | 58 ++---
.../operators/utils_op/nltk_helper.py | 83 +++++++
.../src/pyhugegraph/api/gremlin.py | 2 +-
.../structure/{respon_data.py => response_data.py} | 0
16 files changed, 852 insertions(+), 104 deletions(-)
diff --git a/hugegraph-llm/examples/graph_rag_test.py b/hugegraph-llm/examples/graph_rag_test.py
new file mode 100644
index 0000000..5c163fe
--- /dev/null
+++ b/hugegraph-llm/examples/graph_rag_test.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+
+from hugegraph_llm.operators.graph_rag_operator import GraphRAG
+from pyhugegraph.client import PyHugeClient
+
+
+def prepare_data():
+ client = PyHugeClient(
+ "127.0.0.1", 18080, "hugegraph", "admin", "admin"
+ )
+ schema = client.schema()
+ schema.propertyKey("name").asText().ifNotExist().create()
+ schema.propertyKey("birthDate").asText().ifNotExist().create()
+ schema.vertexLabel("Person").properties("name", "birthDate") \
+ .useCustomizeStringId().ifNotExist().create()
+ schema.vertexLabel("Movie").properties("name").useCustomizeStringId().ifNotExist().create()
+ schema.indexLabel("PersonByName").onV("Person").by("name").secondary().ifNotExist().create()
+ schema.indexLabel("MovieByName").onV("Movie").by("name").secondary().ifNotExist().create()
+ schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create()
+
+ graph = client.graph()
+ graph.addVertex("Person", {"name": "Al Pacino", "birthDate":
"1940-04-25"}, id="Al Pacino")
+ graph.addVertex(
+ "Person", {"name": "Robert De Niro", "birthDate": "1943-08-17"},
id="Robert De Niro")
+ graph.addVertex("Movie", {"name": "The Godfather"}, id="The Godfather")
+ graph.addVertex("Movie", {"name": "The Godfather Part II"}, id="The
Godfather Part II")
+ graph.addVertex("Movie", {"name": "The Godfather Coda The Death of Michael
Corleone"},
+ id="The Godfather Coda The Death of Michael Corleone")
+
+ graph.addEdge("ActedIn", "Al Pacino", "The Godfather", {})
+ graph.addEdge("ActedIn", "Al Pacino", "The Godfather Part II", {})
+ graph.addEdge("ActedIn", "Al Pacino", "The Godfather Coda The Death of
Michael Corleone", {})
+ graph.addEdge("ActedIn", "Robert De Niro", "The Godfather Part II", {})
+
+ graph.close()
+
+
+if __name__ == '__main__':
+ os.environ["http_proxy"] = ""
+ os.environ["https_proxy"] = ""
+ os.environ["OPENAI_API_KEY"] = ""
+
+ # prepare_data()
+
+ graph_rag = GraphRAG()
+
+ # configure operator with context dict
+ context = {
+ # hugegraph client
+ "ip": "localhost", # default to "localhost" if not set
+ "port": 18080, # default to 8080 if not set
+ "user": "admin", # default to "admin" if not set
+ "pwd": "admin", # default to "admin" if not set
+ "graph": "hugegraph", # default to "hugegraph" if not set
+
+ # query question
+ "query": "Tell me about Al Pacino.", # must be set
+
+ # keywords extraction
+ "max_keywords": 5, # default to 5 if not set
+ "language": "english", # default to "english" if not set
+
+ # graph rag query
+ "prop_to_match": "name", # default to None if not set
+ "max_deep": 2, # default to 2 if not set
+ "max_items": 30, # default to 30 if not set
+
+ # print intermediate processes result
+ "verbose": True, # default to False if not set
+ }
+ result = graph_rag \
+ .extract_keyword() \
+ .query_graph_for_rag() \
+ .synthesize_answer() \
+ .run(**context)
+ print(f"Query:\n- {context['query']}")
+ print(f"Answer:\n- {result['answer']}")
+
+ print("--------------------------------------------------------")
+
+ # configure operator with parameters
+ graph_client = PyHugeClient(
+ "127.0.0.1", 18080, "hugegraph", "admin", "admin"
+ )
+ result = graph_rag.extract_keyword(
+ text="Tell me about Al Pacino.",
+ max_keywords=5, # default to 5 if not set
+ language="english", # default to "english" if not set
+ ).query_graph_for_rag(
+ graph_client=graph_client,
+ max_deep=2, # default to 2 if not set
+ max_items=30, # default to 30 if not set
+ prop_to_match=None, # default to None if not set
+ ).synthesize_answer().run(verbose=True)
+ print("Query:\n- Tell me about Al Pacino.")
+ print(f"Answer:\n- {result['answer']}")
diff --git a/hugegraph-llm/requirements.txt b/hugegraph-llm/requirements.txt
index 5341b0b..e8cd6d0 100644
--- a/hugegraph-llm/requirements.txt
+++ b/hugegraph-llm/requirements.txt
@@ -1,3 +1,4 @@
openai==0.28.1
retry==0.9.2
tiktoken==0.5.1
+nltk==3.8.1
\ No newline at end of file
diff --git a/hugegraph-llm/src/hugegraph_llm/llms/base.py b/hugegraph-llm/src/hugegraph_llm/llms/base.py
index a9a1bdb..17b51be 100644
--- a/hugegraph-llm/src/hugegraph_llm/llms/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/llms/base.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Callable
+from typing import Any, List, Optional, Callable, Dict
class BaseLLM(ABC):
@@ -26,7 +26,7 @@ class BaseLLM(ABC):
@abstractmethod
def generate(
self,
- messages: Optional[List[str]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Comment"""
@@ -34,7 +34,7 @@ class BaseLLM(ABC):
@abstractmethod
async def generate_streaming(
self,
- messages: Optional[List[str]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> List[Any]:
diff --git a/hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py b/hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py
index 7766abf..47735b6 100644
--- a/hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py
@@ -16,7 +16,8 @@
# under the License.
-from typing import Callable, List, Optional
+import os
+from typing import Callable, List, Optional, Dict, Any
import openai
import tiktoken
from retry import retry
@@ -34,7 +35,7 @@ class OpenAIChat(BaseLLM):
max_tokens: int = 1000,
temperature: float = 0.0,
) -> None:
- openai.api_key = api_key
+ openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.model = model_name
self.max_tokens = max_tokens
self.temperature = temperature
@@ -42,9 +43,13 @@ class OpenAIChat(BaseLLM):
@retry(tries=3, delay=1)
def generate(
self,
- messages: Optional[List[str]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
+ """Generate a response to the query messages/prompt."""
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
try:
completions = openai.ChatCompletion.create(
model=self.model,
@@ -57,18 +62,19 @@ class OpenAIChat(BaseLLM):
except openai.error.InvalidRequestError as e:
return str(f"Error: {e}")
# catch authorization errors / do not retry
- except openai.error.AuthenticationError as e:
- return f"Error: The provided OpenAI API key is invalid, {e}"
+ except openai.error.AuthenticationError:
+ return "Error: The provided OpenAI API key is invalid"
except Exception as e:
print(f"Retrying LLM call {e}")
- raise Exception() from e
+ raise e
async def generate_streaming(
self,
- messages: Optional[List[str]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> str:
+ """Generate a response to the query messages/prompt in streaming
mode."""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
@@ -89,10 +95,12 @@ class OpenAIChat(BaseLLM):
return result
async def num_tokens_from_string(self, string: str) -> int:
+ """Get token count from string."""
encoding = tiktoken.encoding_for_model(self.model)
num_tokens = len(encoding.encode(string))
return num_tokens
async def max_allowed_token_length(self) -> int:
+ """Get max-allowed token length"""
# TODO: list all models and their max tokens from api
return 2049
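With the api_key fallback above, OpenAIChat can be constructed without an
explicit key whenever OPENAI_API_KEY is exported. A minimal sketch (the key
value is a placeholder):

    import os

    from hugegraph_llm.llms.openai_llm import OpenAIChat

    os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # placeholder key
    llm = OpenAIChat()  # api_key=None now falls back to the env variable
    print(llm.generate(prompt="Say hello."))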
diff --git a/hugegraph-python-client/src/pyhugegraph/structure/respon_data.py b/hugegraph-llm/src/hugegraph_llm/operators/__init__.py
similarity index 60%
copy from hugegraph-python-client/src/pyhugegraph/structure/respon_data.py
copy to hugegraph-llm/src/hugegraph_llm/operators/__init__.py
index 0f4f1b1..13a8339 100644
--- a/hugegraph-python-client/src/pyhugegraph/structure/respon_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/__init__.py
@@ -1,39 +1,16 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-class ResponseData:
- def __init__(self, dic):
- self.__id = dic["requestId"]
- self.__status = dic["status"]
- self.__result = dic["result"]
-
- @property
- def id(self):
- return self.__id
-
- @property
- def status(self):
- return self.__status
-
- @property
- def result(self):
- return self.__result
-
- def __repr__(self):
- res = f"id: {self.__id}, status: {self.__status}, result:
{self.__result}"
- return res
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/build_kg_operator.py b/hugegraph-llm/src/hugegraph_llm/operators/build_kg_operator.py
index deb7a2b..7e4b51c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/build_kg_operator.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/build_kg_operator.py
@@ -16,13 +16,13 @@
# under the License.
+from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.operators.hugegraph_op.commit_data_to_kg import CommitDataToKg
from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData
from hugegraph_llm.operators.llm_op.parse_text_to_data import (
ParseTextToData,
ParseTextToDataWithSchemas,
)
-from hugegraph_llm.llms.base import BaseLLM
class KgBuilder:
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_operator.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_operator.py
new file mode 100644
index 0000000..e3c27bd
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_operator.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from typing import Dict, Any, Optional, List
+
+from hugegraph_llm.llms.base import BaseLLM
+from hugegraph_llm.llms.openai_llm import OpenAIChat
+from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
+from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
+from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
+from pyhugegraph.client import PyHugeClient
+
+
+class GraphRAG:
+ def __init__(self, llm: Optional[BaseLLM] = None):
+ self._llm = llm or OpenAIChat()
+ self._operators: List[Any] = []
+
+ def extract_keyword(
+ self,
+ text: Optional[str] = None,
+ max_keywords: int = 5,
+ language: str = 'english',
+ extract_template: Optional[str] = None,
+ expand_template: Optional[str] = None,
+ ):
+ self._operators.append(
+ KeywordExtract(
+ text=text,
+ max_keywords=max_keywords,
+ language=language,
+ extract_template=extract_template,
+ expand_template=expand_template,
+ )
+ )
+ return self
+
+ def query_graph_for_rag(
+ self,
+ graph_client: Optional[PyHugeClient] = None,
+ max_deep: int = 2,
+ max_items: int = 30,
+ prop_to_match: Optional[str] = None,
+ ):
+ self._operators.append(
+ GraphRAGQuery(
+ client=graph_client,
+ max_deep=max_deep,
+ max_items=max_items,
+ prop_to_match=prop_to_match,
+ )
+ )
+ return self
+
+ def synthesize_answer(
+ self,
+ prompt_template: Optional[str] = None,
+ ):
+ self._operators.append(
+ AnswerSynthesize(
+ prompt_template=prompt_template,
+ )
+ )
+ return self
+
+ def run(self, **kwargs) -> Dict[str, Any]:
+ if len(self._operators) == 0:
+ self.extract_keyword().query_graph_for_rag().synthesize_answer()
+
+ context = kwargs
+ context["llm"] = self._llm
+ for op in self._operators:
+ context = op.run(context=context)
+ return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
new file mode 100644
index 0000000..b50d1c1
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import re
+from typing import Any, Dict, Optional, List, Set, Tuple
+
+from pyhugegraph.client import PyHugeClient
+
+
+class GraphRAGQuery:
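+ # Gremlin query templates, filled in via str.format(): start from vertices
+ # matched by id (or by a property value in the second template), walk up to
+ # {max_deep} hops over the given edge labels, and collect at most
+ # {max_items} paths.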
+ ID_RAG_GREMLIN_QUERY_TEMPL = (
+ "g.V().hasId({keywords}).as('subj')"
+ ".repeat("
+ " bothE({edge_labels}).as('rel').otherV().as('obj')"
+ ").times({max_deep})"
+ ".path()"
+ ".by(project('label', 'id', 'props')"
+ " .by(label())"
+ " .by(id())"
+ " .by(valueMap().by(unfold()))"
+ ")"
+ ".by(project('label', 'inV', 'outV', 'props')"
+ " .by(label())"
+ " .by(inV().id())"
+ " .by(outV().id())"
+ " .by(valueMap().by(unfold()))"
+ ")"
+ ".limit({max_items})"
+ ".toList()"
+ )
+ PROP_RAG_GREMLIN_QUERY_TEMPL = (
+ "g.V().has('{prop}', within({keywords})).as('subj')"
+ ".repeat("
+ " bothE({edge_labels}).as('rel').otherV().as('obj')"
+ ").times({max_deep})"
+ ".path()"
+ ".by(project('label', 'props')"
+ " .by(label())"
+ " .by(valueMap().by(unfold()))"
+ ")"
+ ".by(project('label', 'inV', 'outV', 'props')"
+ " .by(label())"
+ " .by(inV().values('{prop}'))"
+ " .by(outV().values('{prop}'))"
+ " .by(valueMap().by(unfold()))"
+ ")"
+ ".limit({max_items})"
+ ".toList()"
+ )
+
+ def __init__(
+ self,
+ client: Optional[PyHugeClient] = None,
+ max_deep: int = 2,
+ max_items: int = 30,
+ prop_to_match: Optional[str] = None,
+ ):
+ self._client = client
+ self._max_deep = max_deep
+ self._max_items = max_items
+ self._prop_to_match = prop_to_match
+ self._schema = ""
+
+ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+ if self._client is None:
+ if isinstance(context.get("graph_client"), PyHugeClient):
+ self._client = context["graph_client"]
+ else:
+ ip = context.get("ip") or "localhost"
+ port = context.get("port") or 8080
+ graph = context.get("graph") or "hugegraph"
+ user = context.get("user") or "admin"
+ pwd = context.get("pwd") or "admin"
+ self._client = PyHugeClient(
+ ip=ip, port=port, graph=graph, user=user, pwd=pwd
+ )
+ assert self._client is not None, "No graph for query."
+
+ keywords = context.get("keywords")
+ assert keywords is not None, "No keywords for query."
+
+ if isinstance(context.get("max_deep"), int):
+ self._max_deep = context["max_deep"]
+ if isinstance(context.get("max_items"), int):
+ self._max_items = context["max_items"]
+ if isinstance(context.get("prop_to_match"), str):
+ self._prop_to_match = context["prop_to_match"]
+
+ _, edge_labels = self._extract_labels_from_schema()
+ edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
+
+ use_id_to_match = self._prop_to_match is None
+
+ if not use_id_to_match:
+ keywords_str = ",".join("'" + kw + "'" for kw in keywords)
+ rag_gremlin_query_template = self.PROP_RAG_GREMLIN_QUERY_TEMPL
+ rag_gremlin_query = rag_gremlin_query_template.format(
+ prop=self._prop_to_match,
+ keywords=keywords_str,
+ max_deep=self._max_deep,
+ max_items=self._max_items,
+ edge_labels=edge_labels_str,
+ )
+ else:
+ id_format = self._get_graph_id_format()
+ if id_format == "STRING":
+ keywords_str = ",".join("'" + kw + "'" for kw in keywords)
+ else:
+ raise RuntimeError("Unsupported ID format for Graph RAG.")
+
+ rag_gremlin_query_template = self.ID_RAG_GREMLIN_QUERY_TEMPL
+ rag_gremlin_query = rag_gremlin_query_template.format(
+ keywords=keywords_str,
+ max_deep=self._max_deep,
+ max_items=self._max_items,
+ edge_labels=edge_labels_str,
+ )
+
+ result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+ knowledge: Set[str] = self._format_knowledge_from_query_result(
+ query_result=result
+ )
+
+ context["synthesize_context_body"] = list(knowledge)
+ context["synthesize_context_head"] = (
+ f"The following are knowledge sequence in max depth
{self._max_deep} "
+ f"in the form of directed graph like:\n"
+ "`subject -[predicate]-> object <-[predicate_next_hop]-
object_next_hop ...` "
+ "extracted based on key entities as subject:"
+ )
+
+ verbose = context.get("verbose") or False
+ if verbose:
+ print("\033[93mKNOWLEDGE FROM GRAPH:")
+ print("\n".join(rel for rel in context["synthesize_context_body"])
+ "\033[0m")
+
+ return context
+
+ def _format_knowledge_from_query_result(
+ self,
+ query_result: List[Any],
+ ) -> Set[str]:
+ use_id_to_match = self._prop_to_match is None
+ knowledge = set()
+ for line in query_result:
+ flat_rel = ""
+ raw_flat_rel = line["objects"]
+ assert len(raw_flat_rel) % 2 == 1
+ node_cache = set()
+ prior_edge_str_len = 0
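+ # A Gremlin path alternates vertex, edge, vertex, ...: even indices carry
+ # the vertex projection, odd indices the edge projection.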
+ for i, item in enumerate(raw_flat_rel):
+ if i % 2 == 0:
+ matched_str = (
+ item["id"]
+ if use_id_to_match
+ else item["props"][self._prop_to_match]
+ )
+ if matched_str in node_cache:
+ flat_rel = flat_rel[:-prior_edge_str_len]
+ break
+ node_cache.add(matched_str)
+ props_str = ", ".join(f"{k}: {v}" for k, v in
item["props"].items())
+ node_str = f"{item['label']}{{{props_str}}}"
+ flat_rel += node_str
+ if flat_rel in knowledge:
+ knowledge.remove(flat_rel)
+ else:
+ props_str = ", ".join(f"{k}: {v}" for k, v in
item["props"].items())
+ props_str = f"{{{props_str}}}" if len(props_str) > 0 else
""
+ prev_matched_str = (
+ raw_flat_rel[i - 1]["id"]
+ if use_id_to_match
+ else raw_flat_rel[i - 1]["props"][self._prop_to_match]
+ )
+ if item["outV"] == prev_matched_str:
+ edge_str = f" -[{item['label']}{props_str}]-> "
+ else:
+ edge_str = f" <-[{item['label']}{props_str}]- "
+ flat_rel += edge_str
+ prior_edge_str_len = len(edge_str)
+ knowledge.add(flat_rel)
+ return knowledge
+
+ def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
+ schema = self._get_graph_schema()
+ node_props_str, edge_props_str = schema.split("\n")[:2]
+ node_props_str = (
+ node_props_str[len("Node properties: "):].strip("[").strip("]")
+ )
+ edge_props_str = (
+ edge_props_str[len("Edge properties: "):].strip("[").strip("]")
+ )
+ node_labels = self._extract_label_names(node_props_str)
+ edge_labels = self._extract_label_names(edge_props_str)
+ return node_labels, edge_labels
+
+ @staticmethod
+ def _extract_label_names(
+ source: str,
+ head: str = "name: ",
+ tail: str = ", ",
+ ) -> List[str]:
+ result = []
+ for s in source.split(head):
+ end = s.find(tail)
+ label = s[:end]
+ if label:
+ result.append(label)
+ return result
+
+ def _get_graph_id_format(self) -> str:
+ sample = self._client.gremlin().exec("g.V().limit(1)")["data"]
+ if len(sample) == 0:
+ return "EMPTY"
+ sample_id = sample[0]["id"]
+ if isinstance(sample_id, int):
+ return "INT"
+ if isinstance(sample_id, str):
+ if re.match(r"^\d+:.*", sample_id):
+ return "INT:STRING"
+ return "STRING"
+ return "UNKNOWN"
+
+ def _get_graph_schema(self, refresh: bool = False) -> str:
+ if self._schema and not refresh:
+ return self._schema
+
+ schema = self._client.schema()
+ vertex_schema = schema.get_vertex_labels()
+ edge_schema = schema.get_edge_labels()
+ relationships = schema.get_relations()
+
+ self._schema = (
+ f"Node properties: {vertex_schema}\n"
+ f"Edge properties: {edge_schema}\n"
+ f"Relationships: {relationships}\n"
+ )
+ return self._schema
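For reference, the two templates are plain str.format() strings; rendering
the ID template with values from the example data gives a concrete traversal
(a sketch, assuming the 'ActedIn' edge label from graph_rag_test.py):

    from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery

    # Substitute keywords, edge labels, and limits into the class template.
    query = GraphRAGQuery.ID_RAG_GREMLIN_QUERY_TEMPL.format(
        keywords="'Al Pacino'",
        edge_labels="'ActedIn'",
        max_deep=2,
        max_items=30,
    )
    # query == "g.V().hasId('Al Pacino').as('subj').repeat( bothE('ActedIn')..."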
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
new file mode 100644
index 0000000..7a24ebc
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from typing import Any, Dict, Optional
+
+from hugegraph_llm.llms.base import BaseLLM
+from hugegraph_llm.llms.openai_llm import OpenAIChat
+
+DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL = (
+ "Context information is below.\n"
+ "---------------------\n"
+ "{context_str}\n"
+ "---------------------\n"
+ "Given the context information and not prior knowledge, answer the
query.\n"
+ "Query: {query_str}\n"
+ "Answer: "
+)
+
+
+class AnswerSynthesize:
+ def __init__(
+ self,
+ llm: Optional[BaseLLM] = None,
+ prompt_template: Optional[str] = None,
+ question: Optional[str] = None,
+ context_body: Optional[str] = None,
+ context_head: Optional[str] = None,
+ context_tail: Optional[str] = None,
+ ):
+ self._llm = llm
+ self._prompt_template = (
+ prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
+ )
+ self._question = question
+ self._context_body = context_body
+ self._context_head = context_head
+ self._context_tail = context_tail
+
+ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+ if self._llm is None:
+ self._llm = context.get("llm") or OpenAIChat()
+ if context.get("llm") is None:
+ context["llm"] = self._llm
+
+ if self._question is None:
+ self._question = context.get("query") or None
+
+ if self._context_body is None:
+ self._context_body = context.get("synthesize_context_body") or None
+
+ assert self._context_body is not None, "No context for synthesizing."
+ assert self._question is not None, "No question for synthesizing."
+
+ if isinstance(self._context_body, str):
+ context_body_str = self._context_body
+ elif isinstance(self._context_body, (list, set)):
+ context_body_str = "\n".join(line for line in self._context_body)
+ elif isinstance(self._context_body, dict):
+ context_body_str = "\n".join(f"{k}: {v}" for k, v in
self._context_body.items())
+ else:
+ context_body_str = str(self._context_body)
+
+ context_head_str = context.get("synthesize_context_head") or self._context_head or ""
+ context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""
+
+ context_str = (
+ f"{context_head_str}\n"
+ f"{context_body_str}\n"
+ f"{context_tail_str}"
+ ).strip("\n")
+
+ prompt = self._prompt_template.format(
+ context_str=context_str,
+ query_str=self._question,
+ )
+ response = self._llm.generate(prompt=prompt)
+ context["answer"] = response
+
+ verbose = context.get("verbose") or False
+ if verbose:
+ print(f"\033[91mANSWER: {response}\033[0m")
+
+ return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
index f540673..f7c16b9 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
@@ -19,6 +19,7 @@
import json
import re
from itertools import groupby
+from typing import Dict, List, Any
from hugegraph_llm.operators.llm_op.unstructured_data_utils import (
nodes_text_to_list_of_dict,
@@ -112,7 +113,7 @@ class DisambiguateData:
self.llm = llm
self.is_user_schema = is_user_schema
- def run(self, data: dict) -> dict[str, list[any]]:
+ def run(self, data: Dict) -> Dict[str, List[Any]]:
nodes = sorted(data["nodes"], key=lambda x: x.get("label", ""))
relationships = data["relationships"]
nodes_schemas = data["nodes_schemas"]
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
new file mode 100644
index 0000000..60d816a
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import re
+from typing import Set, Dict, Any, Optional
+
+from hugegraph_llm.llms.base import BaseLLM
+from hugegraph_llm.llms.openai_llm import OpenAIChat
+from hugegraph_llm.operators.utils_op import nltk_helper
+
+
+DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL = (
+ "A question is provided below. Given the question, "
+ "extract up to {max_keywords} keywords from the text. "
+ "Focus on extracting the keywords that we can use "
+ "to best lookup answers to the question. "
+ "Avoid stopwords.\n"
+ "---------------------\n"
+ "{question}\n"
+ "---------------------\n"
+ "Provide keywords in the following comma-separated format: 'KEYWORDS:
<keywords>'"
+)
+
+DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL = (
+ "Generate synonyms or possible form of keywords up to {max_keywords} in
total,\n"
+ "considering possible cases of capitalization, pluralization, common
expressions, etc.\n"
+ "Provide all synonyms of keywords in comma-separated format: 'SYNONYMS:
<keywords>'\n"
+ "Note, result should be in one-line with only one 'SYNONYMS: ' prefix\n"
+ "----\n"
+ "KEYWORDS: {question}\n"
+ "----"
+)
+
+
+class KeywordExtract:
+ def __init__(
+ self,
+ text: Optional[str] = None,
+ llm: Optional[BaseLLM] = None,
+ max_keywords: int = 5,
+ extract_template: Optional[str] = None,
+ expand_template: Optional[str] = None,
+ language: str = 'english',
+ ):
+ self._llm = llm
+ self._query = text
+ self._language = language.lower()
+ self._max_keywords = max_keywords
+ self._extract_template = (
+ extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
+ )
+ self._expand_template = expand_template or DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL
+
+ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+ if self._query is None:
+ self._query = context.get("query")
+ assert self._query is not None, "No query for keywords extraction."
+ else:
+ context["query"] = self._query
+
+ if self._llm is None:
+ self._llm = context.get("llm") or OpenAIChat()
+ assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
+ if context.get("llm") is None:
+ context["llm"] = self._llm
+
+ if isinstance(context.get('language'), str):
+ self._language = context['language'].lower()
+ else:
+ context["language"] = self._language
+
+ if isinstance(context.get("max_keywords"), int):
+ self._max_keywords = context["max_keywords"]
+
+ prompt = self._extract_template.format(
+ question=self._query,
+ max_keywords=self._max_keywords,
+ )
+ response = self._llm.generate(prompt=prompt)
+
+ keywords = self._extract_keywords_from_response(
+ response=response, lowercase=False, start_token="KEYWORDS:"
+ )
+ keywords = keywords.union(self._expand_synonyms(keywords=keywords))
+ context["keywords"] = list(keywords)
+
+ verbose = context.get("verbose") or False
+ if verbose:
+ print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+
+ return context
+
+ def _expand_synonyms(self, keywords: Set[str]) -> Set[str]:
+ prompt = self._expand_template.format(
+ question=str(keywords),
+ max_keywords=self._max_keywords,
+ )
+ response = self._llm.generate(prompt=prompt)
+ keywords = self._extract_keywords_from_response(
+ response=response, lowercase=False, start_token="SYNONYMS:"
+ )
+ return keywords
+
+ def _extract_keywords_from_response(
+ self,
+ response: str,
+ lowercase: bool = True,
+ start_token: str = "",
+ ) -> Set[str]:
+ keywords = []
+ response = response.strip() # Strip newlines from responses.
+
+ if response.startswith(start_token):
+ response = response[len(start_token):]
+
+ for k in response.split(","):
+ rk = k
+ if lowercase:
+ rk = rk.lower()
+ keywords.append(rk.strip())
+
+ # if keyword consists of multiple words, split into sub-words
+ # (removing stopwords)
+ results = set()
+ for token in keywords:
+ results.add(token)
+ sub_tokens = re.findall(r"\w+", token)
+ if len(sub_tokens) > 1:
+ results.update(
+ {w for w in sub_tokens if w not in nltk_helper.stopwords(lang=self._language)}
+ )
+
+ return results
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/parse_text_to_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/parse_text_to_data.py
index 42d36ac..05d41b2 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/parse_text_to_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/parse_text_to_data.py
@@ -17,15 +17,15 @@
import re
-from typing import List
+from typing import List, Any, Dict
+from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.operators.llm_op.unstructured_data_utils import (
nodes_text_to_list_of_dict,
nodes_schemas_text_to_list_of_dict,
relationships_schemas_text_to_list_of_dict,
relationships_text_to_list_of_dict,
)
-from hugegraph_llm.llms.base import BaseLLM
def generate_system_message() -> str:
@@ -79,11 +79,11 @@ RelationshipsSchemas: {relationships_schemas}"""
def split_string(string, max_length) -> List[str]:
- return [string[i : i + max_length] for i in range(0, len(string), max_length)]
+ return [string[i: i + max_length] for i in range(0, len(string), max_length)]
def split_string_to_fit_token_space(
- llm: BaseLLM, string: str, token_use_per_string: int
+ llm: BaseLLM, string: str, token_use_per_string: int
) -> List[str]:
allowed_tokens = llm.max_allowed_token_length() - token_use_per_string
chunked_data = split_string(string, 500)
@@ -91,8 +91,8 @@ def split_string_to_fit_token_space(
current_chunk = ""
for chunk in chunked_data:
if (
- llm.num_tokens_from_string(current_chunk) + llm.num_tokens_from_string(chunk)
- < allowed_tokens
+ llm.num_tokens_from_string(current_chunk) + llm.num_tokens_from_string(chunk)
+ < allowed_tokens
):
current_chunk += chunk
else:
@@ -125,11 +125,7 @@ def get_nodes_and_relationships_from_result(result):
relationships.extend(re.findall(internal_regex, raw_relationships))
nodes_schemas.extend(re.findall(internal_regex, raw_nodes_schemas))
relationships_schemas.extend(re.findall(internal_regex, raw_relationships_schemas))
- result = {}
- result["nodes"] = []
- result["relationships"] = []
- result["nodes_schemas"] = []
- result["relationships_schemas"] = []
+ result = {"nodes": [], "relationships": [], "nodes_schemas": [],
"relationships_schemas": []}
result["nodes"].extend(nodes_text_to_list_of_dict(nodes))
result["relationships"].extend(relationships_text_to_list_of_dict(relationships))
result["nodes_schemas"].extend(nodes_schemas_text_to_list_of_dict(nodes_schemas))
@@ -155,7 +151,7 @@ class ParseTextToData:
output = self.llm.generate(messages)
return output
- def run(self, data: dict) -> dict[str, list[any]]:
+ def run(self, data: Dict) -> Dict[str, List[Any]]:
system_message = generate_system_message()
prompt_string = generate_prompt("")
token_usage_per_prompt = self.llm.num_tokens_from_string(system_message + prompt_string)
@@ -196,7 +192,7 @@ class ParseTextToDataWithSchemas:
output = self.llm.generate(messages)
return output
- def run(self) -> dict[str, list[any]]:
+ def run(self) -> Dict[str, List[Any]]:
system_message = generate_system_message_with_schemas()
prompt_string = generate_prompt_with_schemas("", "", "")
token_usage_per_prompt = self.llm.num_tokens_from_string(system_message + prompt_string)
diff --git a/hugegraph-python-client/src/pyhugegraph/structure/respon_data.py b/hugegraph-llm/src/hugegraph_llm/operators/utils_op/__init__.py
similarity index 60%
copy from hugegraph-python-client/src/pyhugegraph/structure/respon_data.py
copy to hugegraph-llm/src/hugegraph_llm/operators/utils_op/__init__.py
index 0f4f1b1..309b3ca 100644
--- a/hugegraph-python-client/src/pyhugegraph/structure/respon_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/utils_op/__init__.py
@@ -1,39 +1,19 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-class ResponseData:
- def __init__(self, dic):
- self.__id = dic["requestId"]
- self.__status = dic["status"]
- self.__result = dic["result"]
-
- @property
- def id(self):
- return self.__id
-
- @property
- def status(self):
- return self.__status
-
- @property
- def result(self):
- return self.__result
-
- def __repr__(self):
- res = f"id: {self.__id}, status: {self.__status}, result:
{self.__result}"
- return res
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from .nltk_helper import nltk_helper
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/utils_op/nltk_helper.py b/hugegraph-llm/src/hugegraph_llm/operators/utils_op/nltk_helper.py
new file mode 100644
index 0000000..35aa921
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/utils_op/nltk_helper.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+import sys
+from pathlib import Path
+from typing import List, Optional, Dict
+
+import nltk
+from nltk.corpus import stopwords
+
+
+class NLTKHelper:
+
+ _stopwords: Dict[str, Optional[List[str]]] = {
+ "english": None,
+ "chinese": None,
+ }
+
+ def stopwords(self, lang: str = "english") -> List[str]:
+ """Get stopwords."""
+ if self._stopwords.get(lang) is None:
+ cache_dir = self.get_cache_dir()
+ nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
+
+ # update nltk path for nltk so that it finds the data
+ if nltk_data_dir not in nltk.data.path:
+ nltk.data.path.append(nltk_data_dir)
+
+ try:
+ nltk.data.find("corpora/stopwords")
+ except LookupError:
+ nltk.download("stopwords", download_dir=nltk_data_dir)
+ self._stopwords[lang] = stopwords.words(lang)
+
+ return self._stopwords[lang]
+
+ @staticmethod
+ def get_cache_dir() -> str:
+ """Locate a platform-appropriate cache directory for hugegraph-llm,
+ and create it if it doesn't yet exist
+ """
+ # User override
+ if "HG_AI_CACHE_DIR" in os.environ:
+ path = Path(os.environ["HG_AI_CACHE_DIR"])
+
+ # Linux, Unix, AIX, etc.
+ elif os.name == "posix" and sys.platform != "darwin":
+ path = Path("/tmp/hugegraph_llm")
+
+ # Mac OS
+ elif sys.platform == "darwin":
+ path = Path(os.path.expanduser("~"), "Library/Caches/hugegraph_llm")
+
+ # Windows (hopefully)
+ else:
+ local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
+ "~\\AppData\\Local"
+ )
+ path = Path(local, "hugegraph_llm")
+
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True)
+
+ return str(path)
+
+
+nltk_helper = NLTKHelper()
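The module-level nltk_helper singleton (re-exported via utils_op/__init__.py
above) makes stopword lookup a one-liner; on first use it downloads the NLTK
stopwords corpus into the cache directory resolved by get_cache_dir(). A
quick sketch:

    from hugegraph_llm.operators.utils_op import nltk_helper

    # First call may download the corpus; results are cached per language.
    stops = nltk_helper.stopwords(lang="english")
    print(len(stops))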
diff --git a/hugegraph-python-client/src/pyhugegraph/api/gremlin.py b/hugegraph-python-client/src/pyhugegraph/api/gremlin.py
index 8126e5a..97b63cb 100644
--- a/hugegraph-python-client/src/pyhugegraph/api/gremlin.py
+++ b/hugegraph-python-client/src/pyhugegraph/api/gremlin.py
@@ -18,7 +18,7 @@ import json
import re
from pyhugegraph.api.common import HugeParamsBase
-from pyhugegraph.structure.respon_data import ResponseData
+from pyhugegraph.structure.response_data import ResponseData
from pyhugegraph.utils.exceptions import NotFoundError
from pyhugegraph.utils.huge_requests import HugeSession
from pyhugegraph.utils.util import check_if_success
diff --git a/hugegraph-python-client/src/pyhugegraph/structure/respon_data.py b/hugegraph-python-client/src/pyhugegraph/structure/response_data.py
similarity index 100%
rename from hugegraph-python-client/src/pyhugegraph/structure/respon_data.py
rename to hugegraph-python-client/src/pyhugegraph/structure/response_data.py