This is an automated email from the ASF dual-hosted git repository.
ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new e6136ae feat(llm): knowledge graph construction by llm (#7)
e6136ae is described below
commit e6136aef1d56ebc24237bacc73f5b37885ecd33e
Author: lzyxx <[email protected]>
AuthorDate: Fri Oct 20 14:57:48 2023 +0800
feat(llm): knowledge graph construction by llm (#7)
* add hugegraph-llm text2kg
* fix codestyle
* Delete .idea directory
* Update build_kg_test.py(delete api)
* fix codestyle
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/README.md | 22 ++
hugegraph-llm/examples/__init__.py | 16 ++
hugegraph-llm/examples/build_kg_test.py | 70 ++++++
hugegraph-llm/src/__init__.py | 16 ++
hugegraph-llm/src/operators/__init__.py | 16 ++
hugegraph-llm/src/operators/build_kg/__init__.py | 16 ++
.../src/operators/build_kg/commit_data_to_kg.py | 188 ++++++++++++++++
.../src/operators/build_kg/disambiguate_data.py | 244 +++++++++++++++++++++
.../src/operators/build_kg/parse_text_to_data.py | 219 ++++++++++++++++++
.../operators/build_kg/unstructured_data_utils.py | 138 ++++++++++++
hugegraph-llm/src/operators/build_kg_operator.py | 69 ++++++
hugegraph-llm/src/operators/llm/__init__.py | 16 ++
hugegraph-llm/src/operators/llm/base.py | 52 +++++
hugegraph-llm/src/operators/llm/openai_llm.py | 96 ++++++++
14 files changed, 1178 insertions(+)
diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
new file mode 100644
index 0000000..3fb6441
--- /dev/null
+++ b/hugegraph-llm/README.md
@@ -0,0 +1,22 @@
+# hugegraph-llm
+
+## Summary
+
+hugegraph-llm is a tool for implementation and research related to large language models. The project includes runnable demos and can also be used as a third-party library.
+
+Graph systems can help large models address challenges like timeliness and hallucination, while large models can help graph systems with cost-related issues.
+
+With this project, we aim to reduce the cost of using graph systems and decrease the complexity of building knowledge graphs. This project will offer more applications and integration solutions for graph systems and large language models:
+1. Construct knowledge graphs with LLM + HugeGraph
+2. Use natural language to operate graph databases (Gremlin)
+3. Use knowledge graphs to supplement answer context (RAG)
+
+# Examples
+
+## Examples (knowledge graph construction by LLM)
+
+1. Start the HugeGraph database; you can run it via Docker. Refer to this [link](https://hub.docker.com/r/hugegraph/hugegraph) for guidance.
+2. Run an example: `python hugegraph-llm/examples/build_kg_test.py`
+
+Note: If you need a proxy to access OpenAI's API, please set your HTTP proxy in `build_kg_test.py`.
+
diff --git a/hugegraph-llm/examples/__init__.py b/hugegraph-llm/examples/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/examples/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/examples/build_kg_test.py b/hugegraph-llm/examples/build_kg_test.py
new file mode 100644
index 0000000..ef6dad5
--- /dev/null
+++ b/hugegraph-llm/examples/build_kg_test.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from src.operators.build_kg_operator import KgBuilder
+from src.operators.llm.openai_llm import OpenAIChat
+
+if __name__ == "__main__":
+    # If you need a proxy to access OpenAI's API, please set your HTTP proxy here
+ os.environ["http_proxy"] = ""
+ os.environ["https_proxy"] = ""
+ api_key = ""
+
+ default_llm = OpenAIChat(
+ api_key=api_key, model_name="gpt-3.5-turbo-16k", max_tokens=4000
+ )
+    text = (
+        "Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a home with since 2010. James, "
+        "in his professional life, works as a journalist. Additionally, Sarah is the proud owner of the website "
+        "www.sarahsplace.com, while James manages his own webpage, though the specific URL is not mentioned here. "
+        "These two individuals, Sarah and James, have not only forged a strong personal bond as roommates but have "
+        "also carved out their distinctive digital presence through their respective webpages, showcasing their "
+        "varied interests and experiences."
+    )
+ builder = KgBuilder(default_llm)
+ # build kg with only text
+
builder.parse_text_to_data(text).disambiguate_data().commit_data_to_kg().run()
+ # build kg with text and schemas
+ nodes_schemas = [
+ {
+ "label": "Person",
+ "primary_key": "name",
+ "properties": {"age": "int", "name": "text", "occupation": "text"},
+ },
+ {
+ "label": "Webpage",
+ "primary_key": "name",
+ "properties": {"name": "text", "url": "text"},
+ },
+ ]
+ relationships_schemas = [
+ {
+ "start": "Person",
+ "end": "Person",
+ "type": "roommate",
+ "properties": {"start": "int"},
+ },
+ {"start": "Person", "end": "Webpage", "type": "owns", "properties":
{}},
+ ]
+ (
+ builder.parse_text_to_data_with_schemas(
+ text, nodes_schemas, relationships_schemas
+ )
+ .disambiguate_data_with_schemas()
+ .commit_data_to_kg()
+ .run()
+ )
diff --git a/hugegraph-llm/src/__init__.py b/hugegraph-llm/src/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/operators/__init__.py b/hugegraph-llm/src/operators/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/operators/build_kg/__init__.py b/hugegraph-llm/src/operators/build_kg/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/operators/build_kg/commit_data_to_kg.py b/hugegraph-llm/src/operators/build_kg/commit_data_to_kg.py
new file mode 100644
index 0000000..79aec79
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg/commit_data_to_kg.py
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from hugegraph.connection import PyHugeGraph
+
+
+def generate_new_relationships(nodes_schemas_data, relationships_data):
+    # Map each distinct label to a numeric id, in first-seen order.
+    label_id = {}
+    next_id = 1
+    for item in nodes_schemas_data:
+        label = item["label"]
+        if label not in label_id:
+            label_id[label] = next_id
+            next_id += 1
+    new_relationships_data = []
+    for relationship in relationships_data:
+        start = relationship["start"]
+        end = relationship["end"]
+        new_start = ""
+        new_end = ""
+        for label, num in label_id.items():
+            for key, name in start.items():
+                if key == label:
+                    new_start = f"{num}:{name}"
+            for key, name in end.items():
+                if key == label:
+                    new_end = f"{num}:{name}"
+        new_relationships_data.append(
+            {
+                "start": new_start,
+                "end": new_end,
+                "type": relationship["type"],
+                "properties": relationship["properties"],
+            }
+        )
+    return new_relationships_data
+
+
+def generate_schema_properties(data):
+    # Node schemas and relationship schemas share the same property layout,
+    # so a single pass handles either input.
+    schema_properties_statements = []
+    for item in data:
+        properties = item["properties"]
+        for key, value in properties.items():
+            if value == "int":
+                schema_properties_statements.append(
+                    f"schema.propertyKey('{key}').asInt().ifNotExist().create()"
+                )
+            elif value == "text":
+                schema_properties_statements.append(
+                    f"schema.propertyKey('{key}').asText().ifNotExist().create()"
+                )
+    return schema_properties_statements
+
+
+def generate_schema_nodes(data):
+    schema_nodes_statements = []
+    for item in data:
+        label = item["label"]
+        primary_key = item["primary_key"]
+        properties = item["properties"]
+        schema_statement = f"schema.vertexLabel('{label}').properties("
+        schema_statement += ", ".join(f"'{prop}'" for prop in properties)
+        schema_statement += ").nullableKeys("
+        schema_statement += ", ".join(
+            f"'{prop}'" for prop in properties if prop != primary_key
+        )
+        schema_statement += (
+            f").usePrimaryKeyId().primaryKeys('{primary_key}').ifNotExist().create()"
+        )
+        schema_nodes_statements.append(schema_statement)
+    return schema_nodes_statements
+
+
+def generate_schema_relationships(data):
+    schema_relationships_statements = []
+    for item in data:
+        start = item["start"]
+        end = item["end"]
+        schema_relationships_type = item["type"]
+        properties = item["properties"]
+        schema_statement = (
+            f"schema.edgeLabel('{schema_relationships_type}')"
+            f".sourceLabel('{start}').targetLabel('{end}').properties("
+        )
+        schema_statement += ", ".join(f"'{prop}'" for prop in properties)
+        schema_statement += ").nullableKeys("
+        schema_statement += ", ".join(f"'{prop}'" for prop in properties)
+        schema_statement += ").ifNotExist().create()"
+        schema_relationships_statements.append(schema_statement)
+    return schema_relationships_statements
+
+
+def generate_nodes(data):
+ nodes = []
+ for item in data:
+ label = item["label"]
+ properties = item["properties"]
+ nodes.append(f"g.addVertex('{label}', {properties})")
+ return nodes
+
+
+def generate_relationships(data):
+    relationships = []
+    for item in data:
+        start = item["start"]
+        end = item["end"]
+        types = item["type"]
+        properties = item["properties"]
+        relationships.append(f"g.addEdge('{types}', '{start}', '{end}', {properties})")
+    return relationships
+
+
+class CommitDataToKg:
+ def __init__(self):
+ self.client = PyHugeGraph(
+ "127.0.0.1", "8080", user="admin", pwd="admin", graph="hugegraph"
+ )
+ self.schema = self.client.schema()
+
+    def run(self, data: dict):
+        # If an HTTP proxy was set to reach OpenAI, unset it here so the
+        # local HugeGraph server is contacted directly (the default value
+        # avoids a KeyError when the variables are absent).
+        os.environ.pop("http_proxy", None)
+        os.environ.pop("https_proxy", None)
+ nodes = data["nodes"]
+ relationships = data["relationships"]
+ nodes_schemas = data["nodes_schemas"]
+ relationships_schemas = data["relationships_schemas"]
+        schema = self.schema  # referenced by name in the exec()'d statements below
+        # properties schema
+ schema_nodes_properties = generate_schema_properties(nodes_schemas)
+ schema_relationships_properties = generate_schema_properties(
+ relationships_schemas
+ )
+ for schema_nodes_property in schema_nodes_properties:
+ exec(schema_nodes_property)
+
+ for schema_relationships_property in schema_relationships_properties:
+ exec(schema_relationships_property)
+
+ # nodes schema
+ schema_nodes = generate_schema_nodes(nodes_schemas)
+ for schema_node in schema_nodes:
+ exec(schema_node)
+
+ # relationships schema
+        schema_relationships = generate_schema_relationships(relationships_schemas)
+ for schema_relationship in schema_relationships:
+ exec(schema_relationship)
+
+        g = self.client.graph()  # referenced by the exec()'d addVertex/addEdge calls
+        # nodes
+        node_statements = generate_nodes(nodes)
+        for node in node_statements:
+            exec(node)
+
+ # relationships
+        new_relationships = generate_new_relationships(nodes_schemas, relationships)
+        relationship_statements = generate_relationships(new_relationships)
+        for relationship in relationship_statements:
+            exec(relationship)
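
For orientation, a minimal sketch (assuming hugegraph-llm/ is on sys.path, as when running examples/build_kg_test.py, and the HugeGraph client import above succeeds) of the statement string generate_schema_nodes() builds before CommitDataToKg.run() exec()s it:

    from src.operators.build_kg.commit_data_to_kg import generate_schema_nodes

    node_schemas = [{
        "label": "Person",
        "primary_key": "name",
        "properties": {"age": "int", "name": "text"},
    }]
    # Prints one line (wrapped here for readability):
    # schema.vertexLabel('Person').properties('age', 'name')
    #     .nullableKeys('age').usePrimaryKeyId().primaryKeys('name')
    #     .ifNotExist().create()
    print(generate_schema_nodes(node_schemas)[0])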
diff --git a/hugegraph-llm/src/operators/build_kg/disambiguate_data.py b/hugegraph-llm/src/operators/build_kg/disambiguate_data.py
new file mode 100644
index 0000000..cd7f3fe
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg/disambiguate_data.py
@@ -0,0 +1,244 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import re
+from itertools import groupby
+from typing import Any, Dict, List
+
+from src.operators.build_kg.unstructured_data_utils import (
+ nodes_text_to_list_of_dict,
+ relationships_text_to_list_of_dict,
+ relationships_schemas_text_to_list_of_dict,
+ nodes_schemas_text_to_list_of_dict,
+)
+from src.operators.llm.base import BaseLLM
+
+
+def disambiguate_nodes() -> str:
+    return """
+Your task is to identify duplicated nodes and, if found, merge them into one node. Only merge nodes that refer to the same entity.
+You will be given different datasets of nodes, and some of these nodes may be duplicated or refer to the same entity.
+The dataset contains nodes in the form [ENTITY_ID, TYPE, PROPERTIES]. When you have completed your task please give me the resulting nodes in the same format. Only return the nodes and relationships, no other text. If there are no duplicated nodes return the original nodes.
+
+Here is an example
+The input you will be given:
+["Alice", "Person", {"age": 25, "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"url": "www.alice.com"}], ["bob.com", "Webpage", {"url": "www.bob.com"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}]
+The output you need to provide:
+["Alice", "Person", {"age": 25, "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"url": "www.alice.com"}], ["bob.com", "Webpage", {"url": "www.bob.com"}]
+"""
+
+
+def disambiguate_relationships() -> str:
+    return """
+Your task is to identify whether a set of relationships makes sense.
+If they do not make sense please remove them from the dataset.
+Some relationships may be duplicated or refer to the same entity.
+Please merge relationships that refer to the same entity.
+The dataset contains relationships in the form [{"ENTITY_TYPE_1": "ENTITY_ID_1"}, RELATIONSHIP, {"ENTITY_TYPE_2": "ENTITY_ID_2"}, PROPERTIES].
+You will also be given a set of ENTITY_IDs that are valid.
+Some relationships may use ENTITY_IDs that are not in the valid set but refer to an entity in the valid set.
+If a relationship refers to an entity in the valid set under a different ENTITY_ID, please change the ID so it matches the valid ID.
+When you have completed your task please give me the valid relationships in the same format. Only return the relationships, no other text.
+
+Here is an example
+The input you will be given:
+[{"Person": "Alice"}, "roommate", {"Person": "bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}]
+The output you need to provide:
+[{"Person": "Alice"}, "roommate", {"Person": "bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}]
+"""
+
+
+def disambiguate_nodes_schemas() -> str:
+    return """
+Your task is to identify duplicated node schemas and, if found, merge them into one node schema. Only merge node schemas that refer to the same entity type.
+You will be given different node schemas, some of which may duplicate or reference the same entity type. Note: for node schemas with the same entity type, you need to merge them, merging all properties of that entity type.
+The dataset contains node schemas in the form [ENTITY_TYPE, PRIMARY_KEY, PROPERTIES]. When you have completed your task please give me the resulting node schemas in the same format. Only return the node schemas, no other text. If there are no duplicated node schemas return the original node schemas.
+
+Here is an example
+The input you will be given:
+["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "url", {"url": "text"}], ["Webpage", "url", {"url": "text"}]
+The output you need to provide:
+["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "url", {"url": "text"}]
+"""
+
+
+def disambiguate_relationships_schemas() -> str:
+    return """
+Your task is to identify whether a set of relationship schemas makes sense.
+If they do not make sense please remove them from the dataset.
+Some relationship schemas may be duplicated or refer to the same label.
+Please merge relationship schemas that refer to the same label.
+The dataset contains relationship schemas in the form [LABEL_ID_1, RELATIONSHIP, LABEL_ID_2, PROPERTIES].
+You will also be given a set of LABEL_IDs that are valid.
+Some relationship schemas may use LABEL_IDs that are not in the valid set but refer to a LABEL in the valid set.
+If a relationship schema refers to a LABEL in the valid set under a different LABEL_ID, please change the ID so it matches the valid ID.
+When you have completed your task please give me the valid relationship schemas in the same format. Only return the relationship schemas, no other text.
+
+Here is an example
+The input you will be given:
+["Person", "roommate", "Person", {"start": 2021}], ["Person", "owns", "Webpage", {}], ["Person", "roommate", "Person", {"start": 2021}]
+The output you need to provide:
+["Person", "roommate", "Person", {"start": 2021}], ["Person", "owns", "Webpage", {}]
+"""
+
+
+def generate_prompt(data) -> str:
+ return f""" Here is the data:
+{data}
+"""
+
+
+internalRegex = r"\[(.*?)\]"
+
+
+class DisambiguateData:
+ def __init__(self, llm: BaseLLM, is_user_schema: bool) -> None:
+ self.llm = llm
+ self.is_user_schema = is_user_schema
+
+    def run(self, data: dict) -> Dict[str, List[Any]]:
+ nodes = sorted(data["nodes"], key=lambda x: x.get("label", ""))
+ relationships = data["relationships"]
+ nodes_schemas = data["nodes_schemas"]
+ relationships_schemas = data["relationships_schemas"]
+ new_nodes = []
+ new_relationships = []
+ new_nodes_schemas = []
+ new_relationships_schemas = []
+
+ node_groups = groupby(nodes, lambda x: x["label"])
+ for group in node_groups:
+ dis_string = ""
+ nodes_in_group = list(group[1])
+ if len(nodes_in_group) == 1:
+ new_nodes.extend(nodes_in_group)
+ continue
+
+ for node in nodes_in_group:
+ dis_string += (
+ '["'
+ + node["name"]
+ + '", "'
+ + node["label"]
+ + '", '
+ + json.dumps(node["properties"])
+ + "]\n"
+ )
+
+ messages = [
+ {"role": "system", "content": disambiguate_nodes()},
+ {"role": "user", "content": generate_prompt(dis_string)},
+ ]
+ raw_nodes = self.llm.generate(messages)
+ n = re.findall(internalRegex, raw_nodes)
+ new_nodes.extend(nodes_text_to_list_of_dict(n))
+
+ relationship_data = ""
+        for relation in relationships:
+            # Emit rows in the documented form
+            # [{"TYPE": "ID"}, RELATIONSHIP, {"TYPE": "ID"}, PROPERTIES]
+            # (the start/end dicts are not wrapped in extra quotes).
+            relationship_data += (
+                "["
+                + json.dumps(relation["start"])
+                + ', "'
+                + relation["type"]
+                + '", '
+                + json.dumps(relation["end"])
+                + ", "
+                + json.dumps(relation["properties"])
+                + "]\n"
+            )
+
+ node_labels = [node["name"] for node in new_nodes]
+ relationship_data += "Valid Nodes:\n" + "\n".join(node_labels)
+
+ messages = [
+ {
+ "role": "system",
+ "content": disambiguate_relationships(),
+ },
+ {"role": "user", "content": generate_prompt(relationship_data)},
+ ]
+ raw_relationships = self.llm.generate(messages)
+ rels = re.findall(internalRegex, raw_relationships)
+ new_relationships.extend(relationships_text_to_list_of_dict(rels))
+
+ if not self.is_user_schema:
+ nodes_schemas_data = ""
+ for node_schema in nodes_schemas:
+                nodes_schemas_data += (
+                    '["'
+                    + node_schema["label"]
+                    + '", "'
+                    + node_schema["primary_key"]
+                    + '", '
+                    + json.dumps(node_schema["properties"])
+                    + "]\n"
+                )
+
+ messages = [
+ {"role": "system", "content": disambiguate_nodes_schemas()},
+ {"role": "user", "content":
generate_prompt(nodes_schemas_data)},
+ ]
+ raw_nodes_schemas = self.llm.generate(messages)
+ n = re.findall(internalRegex, raw_nodes_schemas)
+ new_nodes_schemas.extend(nodes_schemas_text_to_list_of_dict(n))
+
+ relationships_schemas_data = ""
+ for relationships_schema in relationships_schemas:
+ relationships_schemas_data += (
+ '["'
+ + relationships_schema["start"]
+ + '", "'
+ + relationships_schema["type"]
+ + '", "'
+ + relationships_schema["end"]
+ + '", '
+ + json.dumps(relationships_schema["properties"])
+ + "]\n"
+ )
+
+            node_schemas_labels = [
+                node_schema["label"] for node_schema in new_nodes_schemas
+            ]
+ relationships_schemas_data += "Valid Labels:\n" + "\n".join(
+ node_schemas_labels
+ )
+
+ messages = [
+ {
+ "role": "system",
+ "content": disambiguate_relationships_schemas(),
+ },
+ {
+ "role": "user",
+ "content": generate_prompt(relationships_schemas_data),
+ },
+ ]
+ raw_relationships_schemas = self.llm.generate(messages)
+ schemas_rels = re.findall(internalRegex, raw_relationships_schemas)
+ new_relationships_schemas.extend(
+ relationships_schemas_text_to_list_of_dict(schemas_rels)
+ )
+ else:
+ new_nodes_schemas = nodes_schemas
+ new_relationships_schemas = relationships_schemas
+
+ return {
+ "nodes": new_nodes,
+ "relationships": new_relationships,
+ "nodes_schemas": new_nodes_schemas,
+ "relationships_schemas": new_relationships_schemas,
+ }
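
A quick sketch of the round trip above: the LLM's plain-text reply is cut back into rows with the same r"\[(.*?)\]" pattern (internalRegex) and re-parsed into dicts by the utils module (assuming the package is importable as in the examples):

    import re

    from src.operators.build_kg.unstructured_data_utils import nodes_text_to_list_of_dict

    raw = ('["Alice", "Person", {"age": 25, "name": "Alice"}]\n'
           '["Bob", "Person", {"name": "Bob"}]')
    rows = re.findall(r"\[(.*?)\]", raw)  # same pattern as internalRegex
    print(nodes_text_to_list_of_dict(rows))
    # [{'name': 'Alice', 'label': 'Person', 'properties': {'age': 25, 'name': 'Alice'}},
    #  {'name': 'Bob', 'label': 'Person', 'properties': {'name': 'Bob'}}]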
diff --git a/hugegraph-llm/src/operators/build_kg/parse_text_to_data.py b/hugegraph-llm/src/operators/build_kg/parse_text_to_data.py
new file mode 100644
index 0000000..8e1a3fb
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg/parse_text_to_data.py
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import re
+from typing import Any, Dict, List
+
+from src.operators.build_kg.unstructured_data_utils import (
+ nodes_text_to_list_of_dict,
+ nodes_schemas_text_to_list_of_dict,
+ relationships_schemas_text_to_list_of_dict,
+ relationships_text_to_list_of_dict,
+)
+from src.operators.llm.base import BaseLLM
+
+
+def generate_system_message() -> str:
+    return """
+You are a data scientist working for a company that is building a graph database. Your task is to extract information from data and convert it into a graph database.
+Provide a set of Nodes in the form [ENTITY_ID, TYPE, PROPERTIES], a set of Relationships in the form [ENTITY_ID_1, RELATIONSHIP, ENTITY_ID_2, PROPERTIES], a set of NodesSchemas in the form [ENTITY_TYPE, PRIMARY_KEY, PROPERTIES], and a set of RelationshipsSchemas in the form [ENTITY_TYPE_1, RELATIONSHIP, ENTITY_TYPE_2, PROPERTIES].
+It is important that ENTITY_ID_1 and ENTITY_ID_2 exist as nodes with a matching ENTITY_ID. If you can't pair a relationship with a pair of nodes don't add it.
+When you find a node or relationship you want to add, try to create a generic TYPE for it that describes the entity; you can also think of it as a label.
+
+Here is an example
+The input you will be given:
+Data: Alice is a lawyer and is 25 years old and Bob is her roommate since 2021. Bob works as a journalist. Alice owns the webpage www.alice.com and Bob owns the webpage www.bob.com.
+The output you need to provide:
+Nodes: ["Alice", "Person", {"age": 25, "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"name": "alice.com", "url": "www.alice.com"}], ["bob.com", "Webpage", {"name": "bob.com", "url": "www.bob.com"}]
+Relationships: [{"Person": "Alice"}, "roommate", {"Person": "Bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}]
+NodesSchemas: ["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "name", {"name": "text", "url": "text"}]
+RelationshipsSchemas: ["Person", "roommate", "Person", {"start": "int"}], ["Person", "owns", "Webpage", {}]
+"""
+
+
+def generate_system_message_with_schemas() -> str:
+    return """
+You are a data scientist working for a company that is building a graph database. Your task is to extract information from data and convert it into a graph database.
+Provide a set of Nodes in the form [ENTITY_ID, TYPE, PROPERTIES], a set of Relationships in the form [ENTITY_ID_1, RELATIONSHIP, ENTITY_ID_2, PROPERTIES], a set of NodesSchemas in the form [ENTITY_TYPE, PRIMARY_KEY, PROPERTIES], and a set of RelationshipsSchemas in the form [ENTITY_TYPE_1, RELATIONSHIP, ENTITY_TYPE_2, PROPERTIES].
+It is important that ENTITY_ID_1 and ENTITY_ID_2 exist as nodes with a matching ENTITY_ID. If you can't pair a relationship with a pair of nodes don't add it.
+When you find a node or relationship you want to add, try to create a generic TYPE for it that describes the entity; you can also think of it as a label.
+
+Here is an example
+The input you will be given:
+Data: Alice is a lawyer and is 25 years old and Bob is her roommate since 2021. Bob works as a journalist. Alice owns the webpage www.alice.com and Bob owns the webpage www.bob.com.
+NodesSchemas: ["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "name", {"name": "text", "url": "text"}]
+RelationshipsSchemas: ["Person", "roommate", "Person", {"start": "int"}], ["Person", "owns", "Webpage", {}]
+The output you need to provide:
+Nodes: ["Alice", "Person", {"age": 25, "occupation": "lawyer", "name": "Alice"}], ["Bob", "Person", {"occupation": "journalist", "name": "Bob"}], ["alice.com", "Webpage", {"name": "alice.com", "url": "www.alice.com"}], ["bob.com", "Webpage", {"name": "bob.com", "url": "www.bob.com"}]
+Relationships: [{"Person": "Alice"}, "roommate", {"Person": "Bob"}, {"start": 2021}], [{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}], [{"Person": "Bob"}, "owns", {"Webpage": "bob.com"}, {}]
+NodesSchemas: ["Person", "name", {"age": "int", "name": "text", "occupation": "text"}], ["Webpage", "name", {"name": "text", "url": "text"}]
+RelationshipsSchemas: ["Person", "roommate", "Person", {"start": "int"}], ["Person", "owns", "Webpage", {}]
+"""
+
+
+def generate_prompt(data) -> str:
+ return f"""
+Data: {data}"""
+
+
+def generate_prompt_with_schemas(data, nodes_schemas, relationships_schemas) -> str:
+ return f"""
+Data: {data}
+NodesSchemas: {nodes_schemas}
+RelationshipsSchemas: {relationships_schemas}"""
+
+
+def split_string(string, max_length) -> List[str]:
+    return [string[i : i + max_length] for i in range(0, len(string), max_length)]
+
+
+def split_string_to_fit_token_space(
+ llm: BaseLLM, string: str, token_use_per_string: int
+) -> List[str]:
+ allowed_tokens = llm.max_allowed_token_length() - token_use_per_string
+ chunked_data = split_string(string, 500)
+ combined_chunks = []
+ current_chunk = ""
+ for chunk in chunked_data:
+ if (
+ llm.num_tokens_from_string(current_chunk)
+ + llm.num_tokens_from_string(chunk)
+ < allowed_tokens
+ ):
+ current_chunk += chunk
+ else:
+ combined_chunks.append(current_chunk)
+ current_chunk = chunk
+ combined_chunks.append(current_chunk)
+
+ return combined_chunks
+
+
+def get_nodes_and_relationships_from_result(result):
+    regex = (
+        r"Nodes:\s+(.*?)\s?\s?Relationships:\s+(.*?)\s?\s?NodesSchemas:\s+(.*?)\s?\s?\s?"
+        r"RelationshipsSchemas:\s?\s?(.*)"
+    )
+ internal_regex = r"\[(.*?)\]"
+ nodes = []
+ relationships = []
+ nodes_schemas = []
+ relationships_schemas = []
+ for row in result:
+ parsing = re.match(regex, row, flags=re.S)
+ if parsing is None:
+ continue
+ raw_nodes = str(parsing.group(1))
+ raw_relationships = parsing.group(2)
+ raw_nodes_schemas = parsing.group(3)
+ raw_relationships_schemas = parsing.group(4)
+ nodes.extend(re.findall(internal_regex, raw_nodes))
+ relationships.extend(re.findall(internal_regex, raw_relationships))
+ nodes_schemas.extend(re.findall(internal_regex, raw_nodes_schemas))
+ relationships_schemas.extend(
+ re.findall(internal_regex, raw_relationships_schemas)
+ )
+ result = dict()
+ result["nodes"] = []
+ result["relationships"] = []
+ result["nodes_schemas"] = []
+ result["relationships_schemas"] = []
+ result["nodes"].extend(nodes_text_to_list_of_dict(nodes))
+    result["relationships"].extend(relationships_text_to_list_of_dict(relationships))
+    result["nodes_schemas"].extend(nodes_schemas_text_to_list_of_dict(nodes_schemas))
+ result["relationships_schemas"].extend(
+ relationships_schemas_text_to_list_of_dict(relationships_schemas)
+ )
+ return result
+
+
+class ParseTextToData:
+ llm: BaseLLM
+
+ def __init__(self, llm: BaseLLM, text: str) -> None:
+ self.llm = llm
+ self.text = text
+
+ def process(self, chunk):
+ messages = [
+ {"role": "system", "content": generate_system_message()},
+ {"role": "user", "content": generate_prompt(chunk)},
+ ]
+
+ output = self.llm.generate(messages)
+ return output
+
+    def run(self, data: dict) -> Dict[str, List[Any]]:
+ system_message = generate_system_message()
+ prompt_string = generate_prompt("")
+ token_usage_per_prompt = self.llm.num_tokens_from_string(
+ system_message + prompt_string
+ )
+ chunked_data = split_string_to_fit_token_space(
+            llm=self.llm, string=self.text, token_use_per_string=token_usage_per_prompt
+ )
+
+ results = []
+        for chunk in chunked_data:
+            processed_chunk = self.process(chunk)
+            results.append(processed_chunk)
+ results = get_nodes_and_relationships_from_result(results)
+
+ return results
+
+
+class ParseTextToDataWithSchemas:
+ llm: BaseLLM
+
+ def __init__(
+ self, llm: BaseLLM, text: str, nodes_schema, relationships_schemas
+ ) -> None:
+ self.llm = llm
+ self.text = text
+ self.data = {}
+ self.nodes_schemas = nodes_schema
+ self.relationships_schemas = relationships_schemas
+
+ def process_with_schemas(self, chunk):
+ messages = [
+ {"role": "system", "content":
generate_system_message_with_schemas()},
+ {
+ "role": "user",
+ "content": generate_prompt_with_schemas(
+ chunk, self.nodes_schemas, self.relationships_schemas
+ ),
+ },
+ ]
+
+ output = self.llm.generate(messages)
+ return output
+
+    def run(self, data: dict) -> Dict[str, List[Any]]:
+ system_message = generate_system_message_with_schemas()
+ prompt_string = generate_prompt_with_schemas("", "", "")
+ token_usage_per_prompt = self.llm.num_tokens_from_string(
+ system_message + prompt_string
+ )
+ chunked_data = split_string_to_fit_token_space(
+            llm=self.llm, string=self.text, token_use_per_string=token_usage_per_prompt
+ )
+
+ results = []
+        for chunk in chunked_data:
+            processed_chunk = self.process_with_schemas(chunk)
+            results.append(processed_chunk)
+ results = get_nodes_and_relationships_from_result(results)
+ return results
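
The chunking helper above only duck-types its llm argument (the annotation is not enforced), so its behaviour can be checked without an API key; a sketch with a hypothetical word-counting stand-in:

    from src.operators.build_kg.parse_text_to_data import split_string_to_fit_token_space

    class WordCountLLM:  # hypothetical stub, not part of this commit
        def num_tokens_from_string(self, s):
            return len(s.split())

        def max_allowed_token_length(self):
            return 2049

    chunks = split_string_to_fit_token_space(
        llm=WordCountLLM(), string="word " * 3000, token_use_per_string=100
    )
    # The input is cut into fixed 500-character slices, which are then merged
    # greedily while each combined chunk stays under 2049 - 100 "tokens".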
diff --git a/hugegraph-llm/src/operators/build_kg/unstructured_data_utils.py b/hugegraph-llm/src/operators/build_kg/unstructured_data_utils.py
new file mode 100644
index 0000000..9451b65
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg/unstructured_data_utils.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import re
+
+# Mirrors the four-section pattern parsed in parse_text_to_data.py
+# (the Relationships capture group was missing here).
+regex = (
+    r"Nodes:\s+(.*?)\s?\s?Relationships:\s+(.*?)\s?\s?NodesSchemas:\s+(.*?)\s?\s?"
+    r"RelationshipsSchemas:\s?\s?(.*)"
+)
+internalRegex = r"\[(.*?)\]"
+jsonRegex = r"\{.*\}"
+jsonRegex_relationships = r"\{.*?\}"
+
+
+def nodes_text_to_list_of_dict(nodes):
+ result = []
+ for node in nodes:
+ node_list = node.split(",")
+ if len(node_list) < 2:
+ continue
+
+ name = node_list[0].strip().replace('"', "")
+ label = node_list[1].strip().replace('"', "")
+ properties = re.search(jsonRegex, node)
+ if properties is None:
+ properties = "{}"
+ else:
+ properties = properties.group(0)
+ properties = properties.replace("True", "true")
+ try:
+ properties = json.loads(properties)
+ except json.decoder.JSONDecodeError:
+ properties = {}
+ result.append({"name": name, "label": label, "properties": properties})
+ return result
+
+
+def relationships_text_to_list_of_dict(relationships):
+ result = []
+ for relationship in relationships:
+ relationship_list = relationship.split(",")
+ if len(relationship_list) < 3:
+ continue
+ start = {}
+ end = {}
+ properties = {}
+ relationship_type = relationship_list[1].strip().replace('"', "")
+ matches = re.findall(jsonRegex_relationships, relationship)
+ i = 1
+ for match in matches:
+ if i == 1:
+ start = json.loads(match)
+ i = 2
+ continue
+ if i == 2:
+ end = json.loads(match)
+ i = 3
+ continue
+ if i == 3:
+ properties = json.loads(match)
+ result.append(
+ {
+ "start": start,
+ "end": end,
+ "type": relationship_type,
+ "properties": properties,
+ }
+ )
+ return result
+
+
+def nodes_schemas_text_to_list_of_dict(nodes_schemas):
+ result = []
+ for nodes_schema in nodes_schemas:
+ nodes_schema_list = nodes_schema.split(",")
+        if len(nodes_schema_list) < 2:
+ continue
+
+ label = nodes_schema_list[0].strip().replace('"', "")
+ primary_key = nodes_schema_list[1].strip().replace('"', "")
+ properties = re.search(jsonRegex, nodes_schema)
+ if properties is None:
+ properties = "{}"
+ else:
+ properties = properties.group(0)
+ properties = properties.replace("True", "true")
+ try:
+ properties = json.loads(properties)
+ except json.decoder.JSONDecodeError:
+ properties = {}
+        result.append(
+            {"label": label, "primary_key": primary_key, "properties": properties}
+        )
+ return result
+
+
+def relationships_schemas_text_to_list_of_dict(relationships_schemas):
+ result = []
+ for relationships_schema in relationships_schemas:
+ relationships_schema_list = relationships_schema.split(",")
+ if len(relationships_schema_list) < 3:
+ continue
+ start = relationships_schema_list[0].strip().replace('"', "")
+ end = relationships_schema_list[2].strip().replace('"', "")
+ relationships_schema_type = (
+ relationships_schema_list[1].strip().replace('"', "")
+ )
+
+ properties = re.search(jsonRegex, relationships_schema)
+ if properties is None:
+ properties = "{}"
+ else:
+ properties = properties.group(0)
+ properties = properties.replace("True", "true")
+ try:
+ properties = json.loads(properties)
+ except json.decoder.JSONDecodeError:
+ properties = {}
+ result.append(
+ {
+ "start": start,
+ "end": end,
+ "type": relationships_schema_type,
+ "properties": properties,
+ }
+ )
+ return result
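
A sketch of the relationship parser above: the non-greedy jsonRegex_relationships pulls the three embedded JSON objects (start, end, properties) out of a row in order (assuming the package is importable as in the examples):

    from src.operators.build_kg.unstructured_data_utils import (
        relationships_text_to_list_of_dict,
    )

    row = '{"Person": "Alice"}, "owns", {"Webpage": "alice.com"}, {}'
    print(relationships_text_to_list_of_dict([row]))
    # [{'start': {'Person': 'Alice'}, 'end': {'Webpage': 'alice.com'},
    #   'type': 'owns', 'properties': {}}]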
diff --git a/hugegraph-llm/src/operators/build_kg_operator.py b/hugegraph-llm/src/operators/build_kg_operator.py
new file mode 100644
index 0000000..0b6753f
--- /dev/null
+++ b/hugegraph-llm/src/operators/build_kg_operator.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from src.operators.build_kg.commit_data_to_kg import CommitDataToKg
+from src.operators.build_kg.disambiguate_data import DisambiguateData
+from src.operators.build_kg.parse_text_to_data import (
+ ParseTextToData,
+ ParseTextToDataWithSchemas,
+)
+from src.operators.llm.base import BaseLLM
+
+
+class KgBuilder:
+ def __init__(self, llm: BaseLLM):
+ self.parse_text_to_kg = []
+ self.llm = llm
+ self.data = {}
+
+ def parse_text_to_data(self, text: str):
+ self.parse_text_to_kg.append(ParseTextToData(llm=self.llm, text=text))
+ return self
+
+ def parse_text_to_data_with_schemas(
+ self, text: str, nodes_schemas, relationships_schemas
+ ):
+ self.parse_text_to_kg.append(
+ ParseTextToDataWithSchemas(
+ llm=self.llm,
+ text=text,
+ nodes_schema=nodes_schemas,
+ relationships_schemas=relationships_schemas,
+ )
+ )
+ return self
+
+ def disambiguate_data(self):
+ self.parse_text_to_kg.append(
+ DisambiguateData(llm=self.llm, is_user_schema=False)
+ )
+ return self
+
+ def disambiguate_data_with_schemas(self):
+ self.parse_text_to_kg.append(
+ DisambiguateData(llm=self.llm, is_user_schema=True)
+ )
+ return self
+
+ def commit_data_to_kg(self):
+ self.parse_text_to_kg.append(CommitDataToKg())
+ return self
+
+    def run(self):
+        result = ""
+        for operator in self.parse_text_to_kg:
+            # Each operator consumes the previous operator's output.
+            result = operator.run(result)
+        print(result)
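
KgBuilder relies only on each stage exposing run(data) and returning whatever the next stage consumes; a sketch of that contract with hypothetical stub operators:

    class ParseStub:  # hypothetical; plays the role of ParseTextToData
        def run(self, _ignored):
            return {
                "nodes": [{"name": "Alice", "label": "Person", "properties": {}}],
                "relationships": [],
                "nodes_schemas": [],
                "relationships_schemas": [],
            }

    class CountStub:  # hypothetical terminal stage
        def run(self, data):
            return f"{len(data['nodes'])} node(s)"

    result = ""
    for op in [ParseStub(), CountStub()]:
        result = op.run(result)
    print(result)  # 1 node(s)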
diff --git a/hugegraph-llm/src/operators/llm/__init__.py b/hugegraph-llm/src/operators/llm/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/operators/llm/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/operators/llm/base.py b/hugegraph-llm/src/operators/llm/base.py
new file mode 100644
index 0000000..1ff2923
--- /dev/null
+++ b/hugegraph-llm/src/operators/llm/base.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Callable
+
+
+class BaseLLM(ABC):
+ """LLM wrapper should take in a prompt and return a string."""
+
+    @abstractmethod
+    def generate(
+        self,
+        messages: Optional[List[dict]] = None,
+        prompt: Optional[str] = None,
+    ) -> str:
+        """Generate a completion for the given messages or prompt."""
+
+    @abstractmethod
+    async def generate_streaming(
+        self,
+        messages: Optional[List[dict]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Callable = None,
+    ) -> List[Any]:
+        """Stream a completion, emitting each token via on_token_callback."""
+
+    # These two are implemented synchronously by subclasses and called
+    # synchronously in parse_text_to_data.py, so they are declared sync.
+    @abstractmethod
+    def num_tokens_from_string(
+        self,
+        string: str,
+    ) -> int:
+        """Given a string, returns the number of tokens it consists of"""
+
+    @abstractmethod
+    def max_allowed_token_length(
+        self,
+    ) -> int:
+        """Returns the maximum number of tokens the LLM can handle"""
diff --git a/hugegraph-llm/src/operators/llm/openai_llm.py b/hugegraph-llm/src/operators/llm/openai_llm.py
new file mode 100644
index 0000000..b8da930
--- /dev/null
+++ b/hugegraph-llm/src/operators/llm/openai_llm.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Callable, List, Optional
+import openai
+import tiktoken
+from retry import retry
+
+from src.operators.llm.base import BaseLLM
+
+
+class OpenAIChat(BaseLLM):
+ """Wrapper around OpenAI Chat large language models."""
+
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ model_name: str = "gpt-3.5-turbo",
+ max_tokens: int = 1000,
+ temperature: float = 0.0,
+ ) -> None:
+ openai.api_key = api_key
+ self.model = model_name
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+
+    @retry(tries=3, delay=1)
+    def generate(
+        self,
+        messages: Optional[List[dict]] = None,
+        prompt: Optional[str] = None,
+    ) -> str:
+        try:
+            completions = openai.ChatCompletion.create(
+                model=self.model,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                messages=messages,
+            )
+            return completions.choices[0].message.content
+        # catch context length errors / do not retry
+        except openai.error.InvalidRequestError as e:
+            return f"Error: {e}"
+        # catch authorization errors / do not retry
+        except openai.error.AuthenticationError:
+            return "Error: The provided OpenAI API key is invalid"
+        except Exception as e:
+            print(f"Retrying LLM call {e}")
+            raise
+
+    async def generate_streaming(
+        self,
+        messages: Optional[List[dict]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Callable = None,
+    ) -> str:
+ if messages is None:
+ assert prompt is not None, "Messages or prompt must be provided."
+ messages = [{"role": "user", "content": prompt}]
+ completions = openai.ChatCompletion.create(
+ model=self.model,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ messages=messages,
+ stream=True,
+ )
+ result = ""
+ for message in completions:
+ # Process the streamed messages or perform any other desired action
+ delta = message["choices"][0]["delta"]
+ if "content" in delta:
+ result += delta["content"]
+ await on_token_callback(message)
+ return result
+
+ def num_tokens_from_string(self, string: str) -> int:
+ encoding = tiktoken.encoding_for_model(self.model)
+ num_tokens = len(encoding.encode(string))
+ return num_tokens
+
+ def max_allowed_token_length(self) -> int:
+ # TODO: list all models and their max tokens from api
+ return 2049
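
For reference, num_tokens_from_string() above is a thin wrapper over tiktoken; the equivalent standalone call (model name as in the class default):

    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count = len(encoding.encode("Meet Sarah, a 30-year-old attorney."))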