This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit 0bdbf8997edec7661ef79172a0efc1e03dbd78ba
Author: Linyu <[email protected]>
AuthorDate: Tue Sep 16 14:54:03 2025 +0800

    refactor: refactor scheduler to support dynamic workflow scheduling and pipeline pooling (#48)
---
 hugegraph-llm/pyproject.toml                       |   2 +
 hugegraph-llm/src/hugegraph_llm/flows/__init__.py  |  16 ++
 .../src/hugegraph_llm/flows/build_vector_index.py  |  55 +++++
 hugegraph-llm/src/hugegraph_llm/flows/common.py    |  45 ++++
 .../src/hugegraph_llm/flows/graph_extract.py       | 127 ++++++++++
 hugegraph-llm/src/hugegraph_llm/flows/scheduler.py |  90 +++++++
 .../models/embeddings/init_embedding.py            |  36 ++-
 .../src/hugegraph_llm/models/llms/init_llm.py      |  80 ++++++-
 .../operators/common_op/check_schema.py            | 258 +++++++++++++++++++--
 .../operators/document_op/chunk_split.py           |  59 +++++
 .../operators/hugegraph_op/schema_manager.py       |  88 ++++++-
 .../operators/index_op/build_vector_index.py       |  65 +++++-
 .../hugegraph_llm/operators/llm_op/info_extract.py | 220 ++++++++++++++++--
 .../operators/llm_op/property_graph_extract.py     | 190 +++++++++++++--
 hugegraph-llm/src/hugegraph_llm/operators/util.py  |  27 +++
 hugegraph-llm/src/hugegraph_llm/state/__init__.py  |  16 ++
 hugegraph-llm/src/hugegraph_llm/state/ai_state.py  |  81 +++++++
 .../src/hugegraph_llm/utils/graph_index_utils.py   |  83 +++++--
 .../src/hugegraph_llm/utils/vector_index_utils.py  |  67 +++---
 19 files changed, 1472 insertions(+), 133 deletions(-)

diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index 1bd3b748..2b0f29ac 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -61,6 +61,7 @@ dependencies = [
     "apscheduler",
     "litellm",
     "hugegraph-python-client",
+    "pycgraph",
 ]
 [project.urls]
 homepage = "https://hugegraph.apache.org/"
@@ -88,3 +89,4 @@ allow-direct-references = true
 
 [tool.uv.sources]
 hugegraph-python-client = { workspace = true }
+pycgraph = { git = "https://github.com/ChunelFeng/CGraph.git", subdirectory = "python", rev = "main", marker = "sys_platform == 'linux'" }
diff --git a/hugegraph-llm/src/hugegraph_llm/flows/__init__.py b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py
new file mode 100644
index 00000000..13a83393
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py
new file mode 100644
index 00000000..f1ee8c1c
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput + +import json +from PyCGraph import GPipeline + +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode +from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode +from hugegraph_llm.state.ai_state import WkFlowState + + +class BuildVectorIndexFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, texts): + prepared_input.texts = texts + prepared_input.language = "zh" + prepared_input.split_type = "paragraph" + return + + def build_flow(self, texts): + pipeline = GPipeline() + # prepare for workflow input + prepared_input = WkFlowInput() + self.prepare(prepared_input, texts) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + chunk_split_node = ChunkSplitNode() + build_vector_node = BuildVectorIndexNode() + pipeline.registerGElement(chunk_split_node, set(), "chunk_split") + pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector") + + return pipeline + + def post_deal(self, pipeline=None): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + return json.dumps(res, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/common.py b/hugegraph-llm/src/hugegraph_llm/flows/common.py new file mode 100644 index 00000000..4c552626 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/common.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from hugegraph_llm.state.ai_state import WkFlowInput + + +class BaseFlow(ABC): + """ + Base class for flows, defines three interface methods: prepare, build_flow, and post_deal. + """ + + @abstractmethod + def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + """ + Pre-processing interface. + """ + pass + + @abstractmethod + def build_flow(self, *args, **kwargs): + """ + Interface for building the flow. + """ + pass + + @abstractmethod + def post_deal(self, *args, **kwargs): + """ + Post-processing interface. 
+ """ + pass diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py new file mode 100644 index 00000000..f1a6c5f6 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from PyCGraph import GPipeline +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.operators.common_op.check_schema import CheckSchemaNode +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManagerNode +from hugegraph_llm.operators.llm_op.info_extract import InfoExtractNode +from hugegraph_llm.operators.llm_op.property_graph_extract import ( + PropertyGraphExtractNode, +) +from hugegraph_llm.utils.log import log + + +class GraphExtractFlow(BaseFlow): + def __init__(self): + pass + + def _import_schema( + self, + from_hugegraph=None, + from_extraction=None, + from_user_defined=None, + ): + if from_hugegraph: + return SchemaManagerNode() + elif from_user_defined: + return CheckSchemaNode() + elif from_extraction: + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("No input data / invalid schema type") + + def prepare( + self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type + ): + # prepare input data + prepared_input.texts = texts + prepared_input.language = "zh" + prepared_input.split_type = "document" + prepared_input.example_prompt = example_prompt + prepared_input.schema = schema + schema = schema.strip() + if schema.startswith("{"): + try: + schema = json.loads(schema) + prepared_input.schema = schema + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", schema) + prepared_input.graph_name = schema + return + + def build_flow(self, schema, texts, example_prompt, extract_type): + pipeline = GPipeline() + prepared_input = WkFlowInput() + # prepare input data + self.prepare(prepared_input, schema, texts, example_prompt, extract_type) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + schema = schema.strip() + schema_node = None + if schema.startswith("{"): + try: + schema = json.loads(schema) + schema_node = self._import_schema(from_user_defined=schema) + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. 
Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", schema) + schema_node = self._import_schema(from_hugegraph=schema) + + chunk_split_node = ChunkSplitNode() + graph_extract_node = None + if extract_type == "triples": + graph_extract_node = InfoExtractNode() + elif extract_type == "property_graph": + graph_extract_node = PropertyGraphExtractNode() + else: + raise ValueError(f"Unsupported extract_type: {extract_type}") + pipeline.registerGElement(schema_node, set(), "schema_node") + pipeline.registerGElement(chunk_split_node, set(), "chunk_split") + pipeline.registerGElement( + graph_extract_node, {schema_node, chunk_split_node}, "graph_extract" + ) + + return pipeline + + def post_deal(self, pipeline=None): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + vertices = res.get("vertices", []) + edges = res.get("edges", []) + if not vertices and not edges: + log.info("Please check the schema.(The schema may not match the Doc)") + return json.dumps( + { + "vertices": vertices, + "edges": edges, + "warning": "The schema may not match the Doc", + }, + ensure_ascii=False, + indent=2, + ) + return json.dumps( + {"vertices": vertices, "edges": edges}, + ensure_ascii=False, + indent=2, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py new file mode 100644 index 00000000..b096310d --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +from typing import Dict, Any +from PyCGraph import GPipelineManager +from hugegraph_llm.flows.build_vector_index import BuildVectorIndexFlow +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.flows.graph_extract import GraphExtractFlow +from hugegraph_llm.utils.log import log + + +class Scheduler: + pipeline_pool: Dict[str, Any] = None + max_pipeline: int + + def __init__(self, max_pipeline: int = 10): + self.pipeline_pool = {} + # pipeline_pool act as a manager of GPipelineManager which used for pipeline management + self.pipeline_pool["build_vector_index"] = { + "manager": GPipelineManager(), + "flow": BuildVectorIndexFlow(), + } + self.pipeline_pool["graph_extract"] = { + "manager": GPipelineManager(), + "flow": GraphExtractFlow(), + } + self.max_pipeline = max_pipeline + + # TODO: Implement Agentic Workflow + def agentic_flow(self): + pass + + def schedule_flow(self, flow: str, *args, **kwargs): + if flow not in self.pipeline_pool: + raise ValueError(f"Unsupported workflow {flow}") + manager = self.pipeline_pool[flow]["manager"] + flow: BaseFlow = self.pipeline_pool[flow]["flow"] + pipeline = manager.fetch() + if pipeline is None: + # call coresponding flow_func to create new workflow + pipeline = flow.build_flow(*args, **kwargs) + status = pipeline.init() + if status.isErr(): + error_msg = f"Error in flow init: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + status = pipeline.run() + if status.isErr(): + error_msg = f"Error in flow execution: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + res = flow.post_deal(pipeline) + manager.add(pipeline) + return res + else: + # fetch pipeline & prepare input for flow + prepared_input = pipeline.getGParamWithNoEmpty("wkflow_input") + flow.prepare(prepared_input, *args, **kwargs) + status = pipeline.run() + if status.isErr(): + raise RuntimeError(f"Error in flow execution {status.getInfo()}") + res = flow.post_deal(pipeline) + manager.release(pipeline) + return res + + +class SchedulerSingleton: + _instance = None + _instance_lock = threading.Lock() + + @classmethod + def get_instance(cls): + if cls._instance is None: + with cls._instance_lock: + if cls._instance is None: + cls._instance = Scheduler() + return cls._instance diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py index 48e4968c..3ad50b3e 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py @@ -17,10 +17,40 @@ from hugegraph_llm.config import llm_settings +from hugegraph_llm.config import LLMConfig from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding +model_map = { + "openai": llm_settings.openai_embedding_model, + "ollama/local": llm_settings.ollama_embedding_model, + "litellm": llm_settings.litellm_embedding_model, +} + + +def get_embedding(llm_settings: LLMConfig): + if llm_settings.embedding_type == "openai": + return OpenAIEmbedding( + model_name=llm_settings.openai_embedding_model, + api_key=llm_settings.openai_embedding_api_key, + api_base=llm_settings.openai_embedding_api_base, + ) + if llm_settings.embedding_type == "ollama/local": + return OllamaEmbedding( + model_name=llm_settings.ollama_embedding_model, + 
host=llm_settings.ollama_embedding_host, + port=llm_settings.ollama_embedding_port, + ) + if llm_settings.embedding_type == "litellm": + return LiteLLMEmbedding( + model_name=llm_settings.litellm_embedding_model, + api_key=llm_settings.litellm_embedding_api_key, + api_base=llm_settings.litellm_embedding_api_base, + ) + + raise Exception("embedding type is not supported !") + class Embeddings: def __init__(self): @@ -31,19 +61,19 @@ class Embeddings: return OpenAIEmbedding( model_name=llm_settings.openai_embedding_model, api_key=llm_settings.openai_embedding_api_key, - api_base=llm_settings.openai_embedding_api_base + api_base=llm_settings.openai_embedding_api_base, ) if self.embedding_type == "ollama/local": return OllamaEmbedding( model_name=llm_settings.ollama_embedding_model, host=llm_settings.ollama_embedding_host, - port=llm_settings.ollama_embedding_port + port=llm_settings.ollama_embedding_port, ) if self.embedding_type == "litellm": return LiteLLMEmbedding( model_name=llm_settings.litellm_embedding_model, api_key=llm_settings.litellm_embedding_api_key, - api_base=llm_settings.litellm_embedding_api_base + api_base=llm_settings.litellm_embedding_api_base, ) raise Exception("embedding type is not supported !") diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py index e70b0d9d..7e1eaab6 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py @@ -15,13 +15,85 @@ # specific language governing permissions and limitations # under the License. - +from hugegraph_llm.config import LLMConfig from hugegraph_llm.models.llms.ollama import OllamaClient from hugegraph_llm.models.llms.openai import OpenAIClient from hugegraph_llm.models.llms.litellm import LiteLLMClient from hugegraph_llm.config import llm_settings +def get_chat_llm(llm_settings: LLMConfig): + if llm_settings.chat_llm_type == "openai": + return OpenAIClient( + api_key=llm_settings.openai_chat_api_key, + api_base=llm_settings.openai_chat_api_base, + model_name=llm_settings.openai_chat_language_model, + max_tokens=llm_settings.openai_chat_tokens, + ) + if llm_settings.chat_llm_type == "ollama/local": + return OllamaClient( + model=llm_settings.ollama_chat_language_model, + host=llm_settings.ollama_chat_host, + port=llm_settings.ollama_chat_port, + ) + if llm_settings.chat_llm_type == "litellm": + return LiteLLMClient( + api_key=llm_settings.litellm_chat_api_key, + api_base=llm_settings.litellm_chat_api_base, + model_name=llm_settings.litellm_chat_language_model, + max_tokens=llm_settings.litellm_chat_tokens, + ) + raise Exception("chat llm type is not supported !") + + +def get_extract_llm(llm_settings: LLMConfig): + if llm_settings.extract_llm_type == "openai": + return OpenAIClient( + api_key=llm_settings.openai_extract_api_key, + api_base=llm_settings.openai_extract_api_base, + model_name=llm_settings.openai_extract_language_model, + max_tokens=llm_settings.openai_extract_tokens, + ) + if llm_settings.extract_llm_type == "ollama/local": + return OllamaClient( + model=llm_settings.ollama_extract_language_model, + host=llm_settings.ollama_extract_host, + port=llm_settings.ollama_extract_port, + ) + if llm_settings.extract_llm_type == "litellm": + return LiteLLMClient( + api_key=llm_settings.litellm_extract_api_key, + api_base=llm_settings.litellm_extract_api_base, + model_name=llm_settings.litellm_extract_language_model, + max_tokens=llm_settings.litellm_extract_tokens, + ) + raise 
Exception("extract llm type is not supported !") + + +def get_text2gql_llm(llm_settings: LLMConfig): + if llm_settings.text2gql_llm_type == "openai": + return OpenAIClient( + api_key=llm_settings.openai_text2gql_api_key, + api_base=llm_settings.openai_text2gql_api_base, + model_name=llm_settings.openai_text2gql_language_model, + max_tokens=llm_settings.openai_text2gql_tokens, + ) + if llm_settings.text2gql_llm_type == "ollama/local": + return OllamaClient( + model=llm_settings.ollama_text2gql_language_model, + host=llm_settings.ollama_text2gql_host, + port=llm_settings.ollama_text2gql_port, + ) + if llm_settings.text2gql_llm_type == "litellm": + return LiteLLMClient( + api_key=llm_settings.litellm_text2gql_api_key, + api_base=llm_settings.litellm_text2gql_api_base, + model_name=llm_settings.litellm_text2gql_language_model, + max_tokens=llm_settings.litellm_text2gql_tokens, + ) + raise Exception("text2gql llm type is not supported !") + + class LLMs: def __init__(self): self.chat_llm_type = llm_settings.chat_llm_type @@ -101,4 +173,8 @@ class LLMs: if __name__ == "__main__": client = LLMs().get_chat_llm() print(client.generate(prompt="What is the capital of China?")) - print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}])) + print( + client.generate( + messages=[{"role": "user", "content": "What is the capital of China?"}] + ) + ) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index 3220d9f3..7a533517 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -20,8 +20,12 @@ from typing import Any, Optional, Dict from hugegraph_llm.enums.property_cardinality import PropertyCardinality from hugegraph_llm.enums.property_data_type import PropertyDataType +from hugegraph_llm.operators.util import init_context from hugegraph_llm.utils.log import log +from PyCGraph import GNode, CStatus +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + def log_and_raise(message: str) -> None: log.warning(message) @@ -59,64 +63,270 @@ class CheckSchema: check_type(schema, dict, "Input data is not a dictionary.") if "vertexlabels" not in schema or "edgelabels" not in schema: log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.") - check_type(schema["vertexlabels"], list, "'vertexlabels' in input data is not a list.") - check_type(schema["edgelabels"], list, "'edgelabels' in input data is not a list.") + check_type( + schema["vertexlabels"], list, "'vertexlabels' in input data is not a list." + ) + check_type( + schema["edgelabels"], list, "'edgelabels' in input data is not a list." 
+ ) + + def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): + property_labels = schema.get("propertykeys", []) + check_type( + property_labels, + list, + "'propertykeys' in input data is not of correct type.", + ) + property_label_set = {label["name"] for label in property_labels} + return property_labels, property_label_set + + def _process_vertex_labels( + self, schema: Dict[str, Any], property_labels: list, property_label_set: set + ) -> None: + for vertex_label in schema["vertexlabels"]: + self._validate_vertex_label(vertex_label) + properties = vertex_label["properties"] + primary_keys = self._process_keys( + vertex_label, "primary_keys", properties[:1] + ) + if len(primary_keys) == 0: + log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.") + vertex_label["primary_keys"] = primary_keys + nullable_keys = self._process_keys( + vertex_label, "nullable_keys", properties[1:] + ) + vertex_label["nullable_keys"] = nullable_keys + self._add_missing_properties( + properties, property_labels, property_label_set + ) + + def _process_edge_labels( + self, schema: Dict[str, Any], property_labels: list, property_label_set: set + ) -> None: + for edge_label in schema["edgelabels"]: + self._validate_edge_label(edge_label) + properties = edge_label.get("properties", []) + self._add_missing_properties( + properties, property_labels, property_label_set + ) + + def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: + check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.") + if "name" not in vertex_label: + log_and_raise("VertexLabel in input data does not contain 'name'.") + check_type( + vertex_label["name"], str, "'name' in vertex_label is not of correct type." + ) + if "properties" not in vertex_label: + log_and_raise("VertexLabel in input data does not contain 'properties'.") + check_type( + vertex_label["properties"], + list, + "'properties' in vertex_label is not of correct type.", + ) + if len(vertex_label["properties"]) == 0: + log_and_raise("'properties' in vertex_label is empty.") + + def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: + check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.") + if ( + "name" not in edge_label + or "source_label" not in edge_label + or "target_label" not in edge_label + ): + log_and_raise( + "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." + ) + check_type( + edge_label["name"], str, "'name' in edge_label is not of correct type." + ) + check_type( + edge_label["source_label"], + str, + "'source_label' in edge_label is not of correct type.", + ) + check_type( + edge_label["target_label"], + str, + "'target_label' in edge_label is not of correct type.", + ) + + def _process_keys( + self, label: Dict[str, Any], key_type: str, default_keys: list + ) -> list: + keys = label.get(key_type, default_keys) + check_type( + keys, list, f"'{key_type}' in {label['name']} is not of correct type." 
+ ) + new_keys = [key for key in keys if key in label["properties"]] + return new_keys + + def _add_missing_properties( + self, properties: list, property_labels: list, property_label_set: set + ) -> None: + for prop in properties: + if prop not in property_label_set: + property_labels.append( + { + "name": prop, + "data_type": PropertyDataType.DEFAULT.value, + "cardinality": PropertyCardinality.DEFAULT.value, + } + ) + property_label_set.add(prop) + + +class CheckSchemaNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + if self.wk_input.schema is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + self.data = self.wk_input.schema + return CStatus() + + def run(self) -> CStatus: + # init workflow input + sts = self.node_init() + if sts.isErr(): + return sts + # 1. Validate the schema structure + self.context.lock() + schema = self.data or self.context.schema + self._validate_schema(schema) + # 2. Process property labels and also create a set for it + property_labels, property_label_set = self._process_property_labels(schema) + # 3. Process properties in given vertex/edge labels + self._process_vertex_labels(schema, property_labels, property_label_set) + self._process_edge_labels(schema, property_labels, property_label_set) + # 4. Update schema with processed pks + schema["propertykeys"] = property_labels + self.context.schema = schema + self.context.unlock() + return CStatus() + + def _validate_schema(self, schema: Dict[str, Any]) -> None: + check_type(schema, dict, "Input data is not a dictionary.") + if "vertexlabels" not in schema or "edgelabels" not in schema: + log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.") + check_type( + schema["vertexlabels"], list, "'vertexlabels' in input data is not a list." + ) + check_type( + schema["edgelabels"], list, "'edgelabels' in input data is not a list." 
+ ) def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): property_labels = schema.get("propertykeys", []) - check_type(property_labels, list, "'propertykeys' in input data is not of correct type.") + check_type( + property_labels, + list, + "'propertykeys' in input data is not of correct type.", + ) property_label_set = {label["name"] for label in property_labels} return property_labels, property_label_set - def _process_vertex_labels(self, schema: Dict[str, Any], property_labels: list, property_label_set: set) -> None: + def _process_vertex_labels( + self, schema: Dict[str, Any], property_labels: list, property_label_set: set + ) -> None: for vertex_label in schema["vertexlabels"]: self._validate_vertex_label(vertex_label) properties = vertex_label["properties"] - primary_keys = self._process_keys(vertex_label, "primary_keys", properties[:1]) + primary_keys = self._process_keys( + vertex_label, "primary_keys", properties[:1] + ) if len(primary_keys) == 0: log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.") vertex_label["primary_keys"] = primary_keys - nullable_keys = self._process_keys(vertex_label, "nullable_keys", properties[1:]) + nullable_keys = self._process_keys( + vertex_label, "nullable_keys", properties[1:] + ) vertex_label["nullable_keys"] = nullable_keys - self._add_missing_properties(properties, property_labels, property_label_set) + self._add_missing_properties( + properties, property_labels, property_label_set + ) - def _process_edge_labels(self, schema: Dict[str, Any], property_labels: list, property_label_set: set) -> None: + def _process_edge_labels( + self, schema: Dict[str, Any], property_labels: list, property_label_set: set + ) -> None: for edge_label in schema["edgelabels"]: self._validate_edge_label(edge_label) properties = edge_label.get("properties", []) - self._add_missing_properties(properties, property_labels, property_label_set) + self._add_missing_properties( + properties, property_labels, property_label_set + ) def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.") if "name" not in vertex_label: log_and_raise("VertexLabel in input data does not contain 'name'.") - check_type(vertex_label["name"], str, "'name' in vertex_label is not of correct type.") + check_type( + vertex_label["name"], str, "'name' in vertex_label is not of correct type." 
+ ) if "properties" not in vertex_label: log_and_raise("VertexLabel in input data does not contain 'properties'.") - check_type(vertex_label["properties"], list, "'properties' in vertex_label is not of correct type.") + check_type( + vertex_label["properties"], + list, + "'properties' in vertex_label is not of correct type.", + ) if len(vertex_label["properties"]) == 0: log_and_raise("'properties' in vertex_label is empty.") def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.") - if "name" not in edge_label or "source_label" not in edge_label or "target_label" not in edge_label: - log_and_raise("EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'.") - check_type(edge_label["name"], str, "'name' in edge_label is not of correct type.") - check_type(edge_label["source_label"], str, "'source_label' in edge_label is not of correct type.") - check_type(edge_label["target_label"], str, "'target_label' in edge_label is not of correct type.") + if ( + "name" not in edge_label + or "source_label" not in edge_label + or "target_label" not in edge_label + ): + log_and_raise( + "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." + ) + check_type( + edge_label["name"], str, "'name' in edge_label is not of correct type." + ) + check_type( + edge_label["source_label"], + str, + "'source_label' in edge_label is not of correct type.", + ) + check_type( + edge_label["target_label"], + str, + "'target_label' in edge_label is not of correct type.", + ) - def _process_keys(self, label: Dict[str, Any], key_type: str, default_keys: list) -> list: + def _process_keys( + self, label: Dict[str, Any], key_type: str, default_keys: list + ) -> list: keys = label.get(key_type, default_keys) - check_type(keys, list, f"'{key_type}' in {label['name']} is not of correct type.") + check_type( + keys, list, f"'{key_type}' in {label['name']} is not of correct type." 
+ ) new_keys = [key for key in keys if key in label["properties"]] return new_keys - def _add_missing_properties(self, properties: list, property_labels: list, property_label_set: set) -> None: + def _add_missing_properties( + self, properties: list, property_labels: list, property_label_set: set + ) -> None: for prop in properties: if prop not in property_label_set: - property_labels.append({ - "name": prop, - "data_type": PropertyDataType.DEFAULT.value, - "cardinality": PropertyCardinality.DEFAULT.value, - }) + property_labels.append( + { + "name": prop, + "data_type": PropertyDataType.DEFAULT.value, + "cardinality": PropertyCardinality.DEFAULT.value, + } + ) property_label_set.add(prop) + + def get_result(self): + self.context.lock() + res = self.context.to_json() + self.context.unlock() + return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py index 8c2dd80f..d779a40a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py @@ -19,6 +19,8 @@ from typing import Literal, Dict, Any, Optional, Union, List from langchain_text_splitters import RecursiveCharacterTextSplitter +from hugegraph_llm.operators.util import init_context +from PyCGraph import GNode, CStatus # Constants LANGUAGE_ZH = "zh" @@ -27,6 +29,63 @@ SPLIT_TYPE_DOCUMENT = "document" SPLIT_TYPE_PARAGRAPH = "paragraph" SPLIT_TYPE_SENTENCE = "sentence" + +class ChunkSplitNode(GNode): + def init(self): + return init_context(self) + + def node_init(self): + if ( + self.wk_input.texts is None + or self.wk_input.language is None + or self.wk_input.split_type is None + ): + return CStatus(-1, "Error occurs when prepare for workflow input") + texts = self.wk_input.texts + language = self.wk_input.language + split_type = self.wk_input.split_type + if isinstance(texts, str): + texts = [texts] + self.texts = texts + self.separators = self._get_separators(language) + self.text_splitter = self._get_text_splitter(split_type) + return CStatus() + + def _get_separators(self, language: str) -> List[str]: + if language == LANGUAGE_ZH: + return ["\n\n", "\n", "。", ",", ""] + if language == LANGUAGE_EN: + return ["\n\n", "\n", ".", ",", " ", ""] + raise ValueError("language must be zh or en") + + def _get_text_splitter(self, split_type: str): + if split_type == SPLIT_TYPE_DOCUMENT: + return lambda text: [text] + if split_type == SPLIT_TYPE_PARAGRAPH: + return RecursiveCharacterTextSplitter( + chunk_size=500, chunk_overlap=30, separators=self.separators + ).split_text + if split_type == SPLIT_TYPE_SENTENCE: + return RecursiveCharacterTextSplitter( + chunk_size=50, chunk_overlap=0, separators=self.separators + ).split_text + raise ValueError("Type must be document, paragraph or sentence") + + def run(self): + sts = self.node_init() + if sts.isErr(): + return sts + all_chunks = [] + for text in self.texts: + chunks = self.text_splitter(text) + all_chunks.extend(chunks) + + self.context.lock() + self.context.chunks = all_chunks + self.context.unlock() + return CStatus() + + class ChunkSplit: def __init__( self, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index 2f50bb81..670c18b4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -17,8 +17,12 @@ from typing import Dict, Any, Optional from hugegraph_llm.config import huge_settings +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from pyhugegraph.client import PyHugeClient +from PyCGraph import GNode, CStatus + class SchemaManager: def __init__(self, graph_name: str): @@ -39,15 +43,22 @@ class SchemaManager: if "vertexlabels" in schema: mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: - new_vertex = {key: vertex[key] for key in ["id", "name", "properties"] if key in vertex} + new_vertex = { + key: vertex[key] + for key in ["id", "name", "properties"] + if key in vertex + } mini_schema["vertexlabels"].append(new_vertex) # Add necessary edgelabels items (4) if "edgelabels" in schema: mini_schema["edgelabels"] = [] for edge in schema["edgelabels"]: - new_edge = {key: edge[key] for key in - ["name", "source_label", "target_label", "properties"] if key in edge} + new_edge = { + key: edge[key] + for key in ["name", "source_label", "target_label", "properties"] + if key in edge + } mini_schema["edgelabels"].append(new_edge) return mini_schema @@ -63,3 +74,74 @@ class SchemaManager: # TODO: enhance the logic here context["simple_schema"] = self.simple_schema(schema) return context + + +class SchemaManagerNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + if self.wk_input.graph_name is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + graph_name = self.wk_input.graph_name + self.graph_name = graph_name + self.client = PyHugeClient( + url=huge_settings.graph_url, + graph=self.graph_name, + user=huge_settings.graph_user, + pwd=huge_settings.graph_pwd, + graphspace=huge_settings.graph_space, + ) + self.schema = self.client.schema() + return CStatus() + + def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: + mini_schema = {} + + # Add necessary vertexlabels items (3) + if "vertexlabels" in schema: + mini_schema["vertexlabels"] = [] + for vertex in schema["vertexlabels"]: + new_vertex = { + key: vertex[key] + for key in ["id", "name", "properties"] + if key in vertex + } + mini_schema["vertexlabels"].append(new_vertex) + + # Add necessary edgelabels items (4) + if "edgelabels" in schema: + mini_schema["edgelabels"] = [] + for edge in schema["edgelabels"]: + new_edge = { + key: edge[key] + for key in ["name", "source_label", "target_label", "properties"] + if key in edge + } + mini_schema["edgelabels"].append(new_edge) + + return mini_schema + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + schema = self.schema.getSchema() + if not schema["vertexlabels"] and not schema["edgelabels"]: + raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") + + self.context.lock() + self.context.schema = schema + # TODO: enhance the logic here + self.context.simple_schema = self.simple_schema(schema) + self.context.unlock() + return CStatus() + + def get_result(self): + self.context.lock() + res = self.context.to_json() + self.context.unlock() + return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index ffb35564..ee89d330 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -23,20 +23,75 @@ from typing import Dict, Any from hugegraph_llm.config import huge_settings, resource_path, llm_settings from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel, get_filename_prefix, get_index_folder_name +from hugegraph_llm.utils.embedding_utils import ( + get_embeddings_parallel, + get_filename_prefix, + get_index_folder_name, +) from hugegraph_llm.utils.log import log +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from PyCGraph import GNode, CStatus + + +class BuildVectorIndexNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + self.embedding = get_embedding(llm_settings) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(self.embedding, "model_name", None) + ) + self.vector_index = VectorIndex.from_index_file( + self.index_dir, self.filename_prefix + ) + return CStatus() + + def run(self): + # init workflow input + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + if self.context.chunks is None: + raise ValueError("chunks not found in context.") + chunks = self.context.chunks + finally: + self.context.unlock() + chunks_embedding = [] + log.debug("Building vector index for %s chunks...", len(chunks)) + # TODO: use async_get_texts_embedding instead of single sync method + chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks)) + if len(chunks_embedding) > 0: + self.vector_index.add(chunks_embedding, chunks) + self.vector_index.to_index_file(self.index_dir, self.filename_prefix) + return CStatus() + class BuildVectorIndex: def __init__(self, embedding: BaseEmbedding): self.embedding = embedding - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) self.filename_prefix = get_filename_prefix( - llm_settings.embedding_type, - getattr(self.embedding, "model_name", None) + llm_settings.embedding_type, getattr(self.embedding, "model_name", None) + ) + self.vector_index = VectorIndex.from_index_file( + self.index_dir, self.filename_prefix ) - self.vector_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if "chunks" not in context: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 42bb6b10..15a8fdda 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -18,20 +18,26 @@ import re from typing import List, Any, Dict, Optional +from hugegraph_llm.config import llm_settings from hugegraph_llm.document.chunk_split import ChunkSplitter from 
hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from PyCGraph import GNode, CStatus + SCHEMA_EXAMPLE_PROMPT = """## Main Task Extract Triples from the given text and graph schema ## Basic Rules 1. The output format must be: (X,Y,Z) - LABEL -In this format, Y must be a value from "properties" or "edge_label", +In this format, Y must be a value from "properties" or "edge_label", and LABEL must be X's vertex_label or Y's edge_label. 2. Don't extract attribute/property fields that do not exist in the given schema 3. Ensure the extract property is in the same type as the schema (like 'age' should be a number) -4. Translate the given schema filed into Chinese if the given text is Chinese but the schema is in English (Optional) +4. Translate the given schema filed into Chinese if the given text is Chinese but the schema is in English (Optional) ## Example (Note: Update the example to correspond to the given text and schema) ### Input example: @@ -75,8 +81,10 @@ The extracted text is: {text}""" if schema: return schema_real_prompt - log.warning("Recommend to provide a graph schema to improve the extraction accuracy. " - "Now using the default schema.") + log.warning( + "Recommend to provide a graph schema to improve the extraction accuracy. " + "Now using the default schema." + ) return text_based_prompt @@ -105,11 +113,17 @@ def extract_triples_by_regex_with_schema(schema, text, graph): # TODO: use a more efficient way to compare the extract & input property p_lower = p.lower() for vertex in schema["vertices"]: - if vertex["vertex_label"] == label and any(pp.lower() == p_lower - for pp in vertex["properties"]): + if vertex["vertex_label"] == label and any( + pp.lower() == p_lower for pp in vertex["properties"] + ): id = f"{label}-{s}" if id not in vertices_dict: - vertices_dict[id] = {"id": id, "name": s, "label": label, "properties": {p: o}} + vertices_dict[id] = { + "id": id, + "name": s, + "label": label, + "properties": {p: o}, + } else: vertices_dict[id]["properties"].update({p: o}) break @@ -118,25 +132,35 @@ def extract_triples_by_regex_with_schema(schema, text, graph): source_label = edge["source_vertex_label"] source_id = f"{source_label}-{s}" if source_id not in vertices_dict: - vertices_dict[source_id] = {"id": source_id, "name": s, "label": source_label, - "properties": {}} + vertices_dict[source_id] = { + "id": source_id, + "name": s, + "label": source_label, + "properties": {}, + } target_label = edge["target_vertex_label"] target_id = f"{target_label}-{o}" if target_id not in vertices_dict: - vertices_dict[target_id] = {"id": target_id, "name": o, "label": target_label, - "properties": {}} - graph["edges"].append({"start": source_id, "end": target_id, "type": label, - "properties": {}}) + vertices_dict[target_id] = { + "id": target_id, + "name": o, + "label": target_label, + "properties": {}, + } + graph["edges"].append( + { + "start": source_id, + "end": target_id, + "type": label, + "properties": {}, + } + ) break - graph["vertices"] = vertices_dict.values() + graph["vertices"] = list(vertices_dict.values()) class InfoExtract: - def __init__( - self, - llm: BaseLLM, - example_prompt: Optional[str] = None - ) -> None: + def __init__(self, llm: BaseLLM, example_prompt: Optional[str] = None) -> None: self.llm = llm self.example_prompt = example_prompt @@ -152,7 
+176,12 @@ class InfoExtract: for sentence in chunks: proceeded_chunk = self.extract_triples_by_llm(schema, sentence) - log.debug("[Legacy] %s input: %s \n output:%s", self.__class__.__name__, sentence, proceeded_chunk) + log.debug( + "[Legacy] %s input: %s \n output:%s", + self.__class__.__name__, + sentence, + proceeded_chunk, + ) if schema: extract_triples_by_regex_with_schema(schema, proceeded_chunk, context) else: @@ -175,7 +204,152 @@ class InfoExtract: return True def _filter_long_id(self, graph) -> Dict[str, List[Any]]: - graph["vertices"] = [vertex for vertex in graph["vertices"] if self.valid(vertex["id"])] - graph["edges"] = [edge for edge in graph["edges"] - if self.valid(edge["start"]) and self.valid(edge["end"])] + graph["vertices"] = [ + vertex for vertex in graph["vertices"] if self.valid(vertex["id"]) + ] + graph["edges"] = [ + edge + for edge in graph["edges"] + if self.valid(edge["start"]) and self.valid(edge["end"]) + ] return graph + + +class InfoExtractNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + self.llm = get_chat_llm(llm_settings) + if self.wk_input.example_prompt is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + self.example_prompt = self.wk_input.example_prompt + return CStatus() + + def extract_triples_by_regex_with_schema(self, schema, text): + text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" + matches = re.findall(pattern, text) + + vertices_dict = {v["id"]: v for v in self.context.vertices} + for match in matches: + s, p, o, label = [item.strip() for item in match] + if None in [label, s, p, o]: + continue + # TODO: use a more efficient way to compare the extract & input property + p_lower = p.lower() + for vertex in schema["vertices"]: + if vertex["vertex_label"] == label and any( + pp.lower() == p_lower for pp in vertex["properties"] + ): + id = f"{label}-{s}" + if id not in vertices_dict: + vertices_dict[id] = { + "id": id, + "name": s, + "label": label, + "properties": {p: o}, + } + else: + vertices_dict[id]["properties"].update({p: o}) + break + for edge in schema["edges"]: + if edge["edge_label"] == label: + source_label = edge["source_vertex_label"] + source_id = f"{source_label}-{s}" + if source_id not in vertices_dict: + vertices_dict[source_id] = { + "id": source_id, + "name": s, + "label": source_label, + "properties": {}, + } + target_label = edge["target_vertex_label"] + target_id = f"{target_label}-{o}" + if target_id not in vertices_dict: + vertices_dict[target_id] = { + "id": target_id, + "name": o, + "label": target_label, + "properties": {}, + } + self.context.edges.append( + { + "start": source_id, + "end": target_id, + "type": label, + "properties": {}, + } + ) + break + self.context.vertices = list(vertices_dict.values()) + + def extract_triples_by_regex(self, text): + text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\)" + self.context.triples += re.findall(pattern, text) + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + if self.context.chunks is None: + self.context.unlock() + raise ValueError("parameter required by extract node not found in context.") + schema = self.context.schema + chunks = self.context.chunks + + if schema: + self.context.vertices = [] + self.context.edges = [] + else: + self.context.triples = [] + + 
self.context.unlock() + + for sentence in chunks: + proceeded_chunk = self.extract_triples_by_llm(schema, sentence) + log.debug( + "[Legacy] %s input: %s \n output:%s", + self.__class__.__name__, + sentence, + proceeded_chunk, + ) + if schema: + self.extract_triples_by_regex_with_schema(schema, proceeded_chunk) + else: + self.extract_triples_by_regex(proceeded_chunk) + + if self.context.call_count: + self.context.call_count += len(chunks) + else: + self.context.call_count = len(chunks) + self._filter_long_id() + return CStatus() + + def extract_triples_by_llm(self, schema, chunk) -> str: + prompt = generate_extract_triple_prompt(chunk, schema) + if self.example_prompt is not None: + prompt = self.example_prompt + prompt + return self.llm.generate(prompt=prompt) + + # TODO: make 'max_length' be a configurable param in settings.py/settings.cfg + def valid(self, element_id: str, max_length: int = 256) -> bool: + if len(element_id.encode("utf-8")) >= max_length: + log.warning("Filter out GraphElementID too long: %s", element_id) + return False + return True + + def _filter_long_id(self): + self.context.vertices = [ + vertex for vertex in self.context.vertices if self.valid(vertex["id"]) + ] + self.context.edges = [ + edge + for edge in self.context.edges + if self.valid(edge["start"]) and self.valid(edge["end"]) + ] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index faff1c6b..6e492b8f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -21,16 +21,19 @@ import json import re from typing import List, Any, Dict -from hugegraph_llm.config import prompt +from hugegraph_llm.config import llm_settings, prompt from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -""" -TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. -Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on -prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. -""" +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput +from PyCGraph import GNode, CStatus + +# TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. +# Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on +# prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. 
SCHEMA_EXAMPLE_PROMPT = prompt.extract_graph_prompt @@ -60,20 +63,18 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: properties_map["vertex"][vertex["name"]] = { "primary_keys": vertex["primary_keys"], "nullable_keys": vertex["nullable_keys"], - "properties": vertex["properties"] + "properties": vertex["properties"], } for edge in schema["edgelabels"]: - properties_map["edge"][edge["name"]] = { - "properties": edge["properties"] - } + properties_map["edge"][edge["name"]] = {"properties": edge["properties"]} log.info("properties_map: %s", properties_map) for item in items: item_type = item["type"] if item_type == "vertex": label = item["label"] - non_nullable_keys = ( - set(properties_map[item_type][label]["properties"]) - .difference(set(properties_map[item_type][label]["nullable_keys"]))) + non_nullable_keys = set( + properties_map[item_type][label]["properties"] + ).difference(set(properties_map[item_type][label]["nullable_keys"])) for key in non_nullable_keys: if key not in item["properties"]: item["properties"][key] = "NULL" @@ -87,9 +88,7 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: class PropertyGraphExtract: def __init__( - self, - llm: BaseLLM, - example_prompt: str = prompt.extract_graph_prompt + self, llm: BaseLLM, example_prompt: str = prompt.extract_graph_prompt ) -> None: self.llm = llm self.example_prompt = example_prompt @@ -105,7 +104,12 @@ class PropertyGraphExtract: items = [] for chunk in chunks: proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) - log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, proceeded_chunk) + log.debug( + "[LLM] %s input: %s \n output:%s", + self.__class__.__name__, + chunk, + proceeded_chunk, + ) items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) items = filter_item(schema, items) for item in items: @@ -125,10 +129,132 @@ class PropertyGraphExtract: def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: # Use regex to extract a JSON object with curly braces - json_match = re.search(r'({.*})', text, re.DOTALL) + json_match = re.search(r"({.*})", text, re.DOTALL) + if not json_match: + log.critical( + "Invalid property graph! No JSON object found, " + "please check the output format example in prompt." + ) + return [] + json_str = json_match.group(1).strip() + + items = [] + try: + property_graph = json.loads(json_str) + # Expect property_graph to be a dict with keys "vertices" and "edges" + if not ( + isinstance(property_graph, dict) + and "vertices" in property_graph + and "edges" in property_graph + ): + log.critical( + "Invalid property graph format; expecting 'vertices' and 'edges'." 
+ ) + return items + + # Create sets for valid vertex and edge labels based on the schema + vertex_label_set = {vertex["name"] for vertex in schema["vertexlabels"]} + edge_label_set = {edge["name"] for edge in schema["edgelabels"]} + + def process_items(item_list, valid_labels, item_type): + for item in item_list: + if not isinstance(item, dict): + log.warning( + "Invalid property graph item type '%s'.", type(item) + ) + continue + if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): + log.warning("Invalid item keys '%s'.", item.keys()) + continue + if item["label"] not in valid_labels: + log.warning( + "Invalid %s label '%s' has been ignored.", + item_type, + item["label"], + ) + continue + items.append(item) + + process_items(property_graph["vertices"], vertex_label_set, "vertex") + process_items(property_graph["edges"], edge_label_set, "edge") + except json.JSONDecodeError: + log.critical( + "Invalid property graph JSON! Please check the extracted JSON data carefully" + ) + return items + + +class PropertyGraphExtractNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + self.NECESSARY_ITEM_KEYS = {"label", "type", "properties"} # pylint: disable=invalid-name + return init_context(self) + + def node_init(self): + self.llm = get_chat_llm(llm_settings) + if self.wk_input.example_prompt is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + self.example_prompt = self.wk_input.example_prompt + return CStatus() + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + if self.context.schema is None or self.context.chunks is None: + raise ValueError( + "parameter required by extract node not found in context." + ) + schema = self.context.schema + chunks = self.context.chunks + if self.context.vertices is None: + self.context.vertices = [] + if self.context.edges is None: + self.context.edges = [] + finally: + self.context.unlock() + + items = [] + for chunk in chunks: + proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) + log.debug( + "[LLM] %s input: %s \n output:%s", + self.__class__.__name__, + chunk, + proceeded_chunk, + ) + items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) + items = filter_item(schema, items) + self.context.lock() + try: + for item in items: + if item["type"] == "vertex": + self.context.vertices.append(item) + elif item["type"] == "edge": + self.context.edges.append(item) + finally: + self.context.unlock() + self.context.call_count = (self.context.call_count or 0) + len(chunks) + return CStatus() + + def extract_property_graph_by_llm(self, schema, chunk): + prompt = generate_extract_property_graph_prompt(chunk, schema) + if self.example_prompt is not None: + prompt = self.example_prompt + prompt + return self.llm.generate(prompt=prompt) + + def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: + # Use regex to extract a JSON object with curly braces + json_match = re.search(r"({.*})", text, re.DOTALL) if not json_match: - log.critical("Invalid property graph! No JSON object found, " - "please check the output format example in prompt.") + log.critical( + "Invalid property graph! No JSON object found, " + "please check the output format example in prompt." 
+ ) return [] json_str = json_match.group(1).strip() @@ -136,8 +262,14 @@ class PropertyGraphExtract: try: property_graph = json.loads(json_str) # Expect property_graph to be a dict with keys "vertices" and "edges" - if not (isinstance(property_graph, dict) and "vertices" in property_graph and "edges" in property_graph): - log.critical("Invalid property graph format; expecting 'vertices' and 'edges'.") + if not ( + isinstance(property_graph, dict) + and "vertices" in property_graph + and "edges" in property_graph + ): + log.critical( + "Invalid property graph format; expecting 'vertices' and 'edges'." + ) return items # Create sets for valid vertex and edge labels based on the schema @@ -147,18 +279,26 @@ class PropertyGraphExtract: def process_items(item_list, valid_labels, item_type): for item in item_list: if not isinstance(item, dict): - log.warning("Invalid property graph item type '%s'.", type(item)) + log.warning( + "Invalid property graph item type '%s'.", type(item) + ) continue if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): log.warning("Invalid item keys '%s'.", item.keys()) continue if item["label"] not in valid_labels: - log.warning("Invalid %s label '%s' has been ignored.", item_type, item["label"]) + log.warning( + "Invalid %s label '%s' has been ignored.", + item_type, + item["label"], + ) continue items.append(item) process_items(property_graph["vertices"], vertex_label_set, "vertex") process_items(property_graph["edges"], edge_label_set, "edge") except json.JSONDecodeError: - log.critical("Invalid property graph JSON! Please check the extracted JSON data carefully") + log.critical( + "Invalid property graph JSON! Please check the extracted JSON data carefully" + ) return items diff --git a/hugegraph-llm/src/hugegraph_llm/operators/util.py b/hugegraph-llm/src/hugegraph_llm/operators/util.py new file mode 100644 index 00000000..60bdc2e8 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/operators/util.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus + + +def init_context(obj) -> CStatus: + try: + obj.context = obj.getGParamWithNoEmpty("wkflow_state") + obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") + if obj.context is None or obj.wk_input is None: + return CStatus(-1, "Required workflow parameters not found") + return CStatus() + except Exception as e: + return CStatus(-1, f"Failed to initialize context: {str(e)}") diff --git a/hugegraph-llm/src/hugegraph_llm/state/__init__.py b/hugegraph-llm/src/hugegraph_llm/state/__init__.py new file mode 100644 index 00000000..13a83393 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/state/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
new file mode 100644
index 00000000..0543aa2b
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PyCGraph import GParam, CStatus
+
+from typing import Union, List, Optional, Any
+
+
+class WkFlowInput(GParam):
+    texts: Union[str, List[str]] = None  # texts input used by ChunkSplit Node
+    language: str = None  # language configuration used by ChunkSplit Node
+    split_type: str = None  # split type used by ChunkSplit Node
+    example_prompt: str = None  # needed by graph information extraction
+    schema: str = None  # Schema information required by SchemaNode
+    graph_name: str = None
+
+    def reset(self, _: CStatus) -> None:
+        self.texts = None
+        self.language = None
+        self.split_type = None
+        self.example_prompt = None
+        self.schema = None
+        self.graph_name = None
+
+
+class WkFlowState(GParam):
+    schema: Optional[str] = None  # schema message
+    simple_schema: Optional[str] = None
+    chunks: Optional[List[str]] = None
+    edges: Optional[List[Any]] = None
+    vertices: Optional[List[Any]] = None
+    triples: Optional[List[Any]] = None
+    call_count: Optional[int] = None
+
+    keywords: Optional[List[str]] = None
+    vector_result = None
+    graph_result = None
+    keywords_embeddings = None
+
+    def setup(self):
+        self.schema = None
+        self.simple_schema = None
+        self.chunks = None
+        self.edges = None
+        self.vertices = None
+        self.triples = None
+        self.call_count = None
+
+        self.keywords = None
+        self.vector_result = None
+        self.graph_result = None
+        self.keywords_embeddings = None
+
+        return CStatus()
+
+    def to_json(self):
+        """
+        Automatically returns a JSON-formatted dictionary of all non-None instance members,
+        eliminating the need to manually maintain the member list.
+
+        Returns:
+            dict: A dictionary containing non-None instance members and their serialized values.
+ """ + # Only export instance attributes (excluding methods and class attributes) whose values are not None + return { + k: v + for k, v in self.__dict__.items() + if not k.startswith("_") and v is not None + } diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index 9fef06d2..f61b5f84 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -22,6 +22,7 @@ import traceback from typing import Dict, Any, Union, Optional import gradio as gr +from hugegraph_llm.flows.scheduler import SchedulerSingleton from .embedding_utils import get_filename_prefix, get_index_folder_name from .hugegraph_utils import get_hg_client, clean_hg_data @@ -35,11 +36,17 @@ from ..operators.kg_construction_task import KgBuilder def get_graph_index_info(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) graph_summary_info = builder.fetch_graph_data().run() - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, getattr(builder.embedding, "model_name", None)) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(builder.embedding, "model_name", None) + ) vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) graph_summary_info["vid_index"] = { "embed_dim": vector_index.index.d, @@ -50,15 +57,20 @@ def get_graph_index_info(): def clean_all_graph_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), "model_name", None)) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, + getattr(Embeddings().get_embedding(), "model_name", None), + ) VectorIndex.clean( - str(os.path.join(resource_path, folder_name, "graph_vids")), - filename_prefix) + str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix + ) VectorIndex.clean( str(os.path.join(resource_path, folder_name, "gremlin_examples")), - filename_prefix) + filename_prefix, + ) log.warning("Clear graph index and text2gql index successfully!") gr.Info("Clear graph index and text2gql index successfully!") @@ -71,7 +83,7 @@ def clean_all_graph_data(): def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: schema = schema.strip() - if schema.startswith('{'): + if schema.startswith("{"): try: schema = json.loads(schema) builder.import_schema(from_user_defined=schema) @@ -84,16 +96,20 @@ def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: return None -def extract_graph(input_file, input_text, schema, example_prompt) -> str: +def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str: texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if not schema: return 
"ERROR: please input with correct schema/format." error_message = parse_schema(schema, builder) if error_message: return error_message - builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") + builder.chunk_split(texts, "document", "zh").extract_info( + example_prompt, "property_graph" + ) try: context = builder.run() @@ -103,19 +119,40 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: { "vertices": context["vertices"], "edges": context["edges"], - "warning": "The schema may not match the Doc" + "warning": "The schema may not match the Doc", }, ensure_ascii=False, - indent=2 + indent=2, ) - return json.dumps({"vertices": context["vertices"], "edges": context["edges"]}, ensure_ascii=False, indent=2) + return json.dumps( + {"vertices": context["vertices"], "edges": context["edges"]}, + ensure_ascii=False, + indent=2, + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def extract_graph(input_file, input_text, schema, example_prompt) -> str: + texts = read_documents(input_file, input_text) + scheduler = SchedulerSingleton.get_instance() + if not schema: + return "ERROR: please input with correct schema/format." + + try: + return scheduler.schedule_flow( + "graph_extract", schema, texts, example_prompt, "property_graph" + ) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) def update_vid_embedding(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) builder.fetch_graph_data().build_vertex_id_semantic_index() log.debug("Operators: %s", builder.operators) try: @@ -132,7 +169,9 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if schema: error_message = parse_schema(schema, builder) if error_message: @@ -154,7 +193,7 @@ def build_schema(input_text, query_example, few_shot): context = { "raw_texts": [input_text] if input_text else [], "query_examples": [], - "few_shot_schema": {} + "few_shot_schema": {}, } if few_shot: @@ -170,7 +209,7 @@ def build_schema(input_text, query_example, few_shot): context["query_examples"] = [ { "description": ex.get("description", ""), - "gremlin": ex.get("gremlin", "") + "gremlin": ex.get("gremlin", ""), } for ex in parsed_examples if isinstance(ex, dict) and "description" in ex and "gremlin" in ex @@ -178,7 +217,9 @@ def build_schema(input_text, query_example, few_shot): except json.JSONDecodeError as e: raise gr.Error(f"Query Examples is not in a valid JSON format: {e}") from e - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) try: schema = builder.build_schema().run(context) except Exception as e: diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index 62bcdd9c..138b0d35 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -23,11 +23,12 @@ import 
gradio as gr from hugegraph_llm.config import resource_path, huge_settings, llm_settings from hugegraph_llm.indices.vector_index import VectorIndex -from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.kg_construction_task import KgBuilder -from hugegraph_llm.utils.embedding_utils import get_filename_prefix, get_index_folder_name -from hugegraph_llm.utils.hugegraph_utils import get_hg_client +from hugegraph_llm.models.embeddings.init_embedding import model_map +from hugegraph_llm.flows.scheduler import SchedulerSingleton +from hugegraph_llm.utils.embedding_utils import ( + get_filename_prefix, + get_index_folder_name, +) def read_documents(input_file, input_text): @@ -49,7 +50,9 @@ def read_documents(input_file, input_text): texts.append(text) elif full_path.endswith(".pdf"): # TODO: support PDF file - raise gr.Error("PDF will be supported later! Try to upload text/docx now") + raise gr.Error( + "PDF will be supported later! Try to upload text/docx now" + ) else: raise gr.Error("Please input txt or docx file.") else: @@ -59,33 +62,44 @@ def read_documents(input_file, input_text): # pylint: disable=C0301 def get_vector_index_info(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), "model_name", None)) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) + ) chunk_vector_index = VectorIndex.from_index_file( str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix, - record_miss=False + record_miss=False, ) graph_vid_vector_index = VectorIndex.from_index_file( - str(os.path.join(resource_path, folder_name, "graph_vids")), - filename_prefix + str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix + ) + return json.dumps( + { + "embed_dim": chunk_vector_index.index.d, + "vector_info": { + "chunk_vector_num": chunk_vector_index.index.ntotal, + "graph_vid_vector_num": graph_vid_vector_index.index.ntotal, + "graph_properties_vector_num": len(chunk_vector_index.properties), + }, + }, + ensure_ascii=False, + indent=2, ) - return json.dumps({ - "embed_dim": chunk_vector_index.index.d, - "vector_info": { - "chunk_vector_num": chunk_vector_index.index.ntotal, - "graph_vid_vector_num": graph_vid_vector_index.index.ntotal, - "graph_properties_vector_num": len(chunk_vector_index.properties) - } - }, ensure_ascii=False, indent=2) def clean_vector_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), "model_name", None)) - VectorIndex.clean(str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) + ) + VectorIndex.clean( + str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix + ) gr.Info("Clean vector index successfully!") @@ -93,6 +107,5 @@ def build_vector_index(input_file, input_text): if input_file and input_text: raise gr.Error("Please only choose one between file and 
text.") texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - context = builder.chunk_split(texts, "paragraph", "zh").build_vector_index().run() - return json.dumps(context, ensure_ascii=False, indent=2) + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("build_vector_index", texts)

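Both PropertyGraphExtract and PropertyGraphExtractNode rely on the same parsing contract: the LLM reply must contain a single JSON object with "vertices" and "edges" arrays, and an item is kept only if it is a dict carrying the label/type/properties keys and a label defined in the schema. A self-contained sketch of that contract (the toy schema and fake LLM reply are invented for illustration):

    import json
    import re

    NECESSARY_ITEM_KEYS = {"label", "type", "properties"}

    # Toy schema and fake LLM reply, for illustration only.
    schema = {
        "vertexlabels": [{"name": "person"}],
        "edgelabels": [{"name": "knows"}],
    }
    llm_reply = """Sure, here is the graph:
    {"vertices": [{"label": "person", "type": "vertex", "properties": {"name": "Alice"}}],
     "edges": [{"label": "likes", "type": "edge", "properties": {}}]}"""

    def extract_items(schema, text):
        # Same steps as _extract_and_filter_label: grab the outermost JSON
        # object, then keep only well-formed items with schema-known labels.
        match = re.search(r"({.*})", text, re.DOTALL)
        if not match:
            return []
        graph = json.loads(match.group(1))
        valid = {
            "vertex": {v["name"] for v in schema["vertexlabels"]},
            "edge": {e["name"] for e in schema["edgelabels"]},
        }
        return [
            item
            for key, kind in (("vertices", "vertex"), ("edges", "edge"))
            for item in graph.get(key, [])
            if isinstance(item, dict)
            and NECESSARY_ITEM_KEYS.issubset(item)
            and item["label"] in valid[kind]
        ]

    print(extract_items(schema, llm_reply))
    # Keeps the "person" vertex; the unknown "likes" edge is dropped.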