This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit b90925ad6d01183657fd91b54f8ed56590475da2
Author: Linyu <[email protected]>
AuthorDate: Thu Sep 25 15:09:38 2025 +0800

    refactor: refactor hugegraph-ai to integrate with CGraph & port some usecases in web demo (#49)
---
 .vibedev/spec/hugegraph-llm/fixed_flow/design.md | 643 +++++++++++++++++++++
 .../spec/hugegraph-llm/fixed_flow/requirements.md | 24 +
 .vibedev/spec/hugegraph-llm/fixed_flow/tasks.md | 36 ++
 .../demo/rag_demo/vector_graph_block.py | 21 +-
 .../src/hugegraph_llm/flows/build_schema.py | 71 +++
 .../src/hugegraph_llm/flows/build_vector_index.py | 4 +-
 .../hugegraph_llm/flows/get_graph_index_info.py | 68 +++
 .../src/hugegraph_llm/flows/graph_extract.py | 58 +-
 ...{build_vector_index.py => import_graph_data.py} | 48 +-
 .../{build_vector_index.py => prompt_generate.py} | 40 +-
 hugegraph-llm/src/hugegraph_llm/flows/scheduler.py | 31 +-
 ...ld_vector_index.py => update_vid_embeddings.py} | 44 +-
 hugegraph-llm/src/hugegraph_llm/flows/utils.py | 34 ++
 hugegraph-llm/src/hugegraph_llm/nodes/base_node.py | 71 +++
 .../nodes/document_node/chunk_split.py | 43 ++
 .../nodes/hugegraph_node/commit_to_hugegraph.py | 35 ++
 .../nodes/hugegraph_node/fetch_graph_data.py | 33 ++
 .../hugegraph_llm/nodes/hugegraph_node/schema.py | 74 +++
 .../nodes/index_node/build_semantic_index.py | 34 ++
 .../nodes/index_node/build_vector_index.py | 34 ++
 .../hugegraph_llm/nodes/llm_node/extract_info.py | 52 ++
 .../nodes/llm_node/prompt_generate.py | 59 ++
 .../hugegraph_llm/nodes/llm_node/schema_build.py | 91 +++
 hugegraph-llm/src/hugegraph_llm/nodes/util.py | 27 +
 .../operators/common_op/check_schema.py | 160 -----
 .../operators/document_op/chunk_split.py | 58 --
 .../operators/hugegraph_op/commit_to_hugegraph.py | 127 +++-
 .../operators/hugegraph_op/schema_manager.py | 75 ---
 .../operators/index_op/build_vector_index.py | 48 --
 .../hugegraph_llm/operators/llm_op/info_extract.py | 146 ----
 .../operators/llm_op/property_graph_extract.py | 127 +--
 hugegraph-llm/src/hugegraph_llm/state/ai_state.py | 28 +
 .../src/hugegraph_llm/utils/graph_index_utils.py | 40 ++
 33 files changed, 1707 insertions(+), 777 deletions(-)

diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/design.md b/.vibedev/spec/hugegraph-llm/fixed_flow/design.md
new file mode 100644
index 00000000..c5777236
--- /dev/null
+++ b/.vibedev/spec/hugegraph-llm/fixed_flow/design.md
@@ -0,0 +1,643 @@
+# Hugegraph-ai Fixed Workflow Execution Engine Design Document
+
+## Overview
+
+The HugeGraph fixed workflow execution engine runs predefined workflows. Each workflow maps to one concrete use case in the Web Demo, such as building the vector index or building the graph index. The engine is built on the PyCGraph framework and provides high-performance, reusable pipeline scheduling.
+
+### Design Goals
+
+- **High performance**: pipeline reuse keeps fixed workflows fast
+- **High reliability**: guarantees data consistency and fault recovery, backed by thorough error handling
+- **Easy extension**: new fixed workflows can be added with minimal effort, and dynamic scheduling is supported
+- **Resource efficiency**: pipeline pooling reduces the cost of repeatedly rebuilding graphs
+
+### Tech Stack
+
+- **PyCGraph**: a high-performance, C++-based graph computation framework providing GPipeline and GPipelineManager
+- **Python**: the main development language for the business logic and interface layers
+- **Threading**: concurrent scheduling and thread safety
+
+### Module Layout
+```text
+hugegraph-llm/
+└── src/
+    └── hugegraph_llm/
+        ├── api/          # FastAPI interface layer: rag_api, admin_api, and other services
+        ├── config/       # Configuration management and generation tools
+        ├── demo/         # Gradio Web Demo and related interactive apps
+        ├── document/     # Document processing and chunking utilities
+        ├── enums/        # Enum definitions
+        ├── flows/        # Workflow scheduling and core flows (vector/graph index building, data import, ...)
+        │   ├── __init__.py
+        │   ├── common.py                 # BaseFlow abstract base class
+        │   ├── scheduler.py              # Core scheduler implementation
+        │   ├── build_vector_index.py     # Vector index building workflow
+        │   ├── graph_extract.py          # Graph extraction workflow
+        │   ├── import_graph_data.py      # Graph data import workflow
+        │   ├── update_vid_embeddings.py  # Vid embedding update workflow
+        │   ├── get_graph_index_info.py   # Graph index info retrieval workflow
+        │   ├── build_schema.py           # Schema building workflow
+        │   └── prompt_generate.py        # Prompt generation workflow
+        ├── indices/      # Index implementations (vector, graph, keyword, ...)
+        ├── middleware/   # Middleware and request handling
+        ├── models/       # LLM, embedding, and reranker models
+        ├── nodes/        # Node scheduling layer: operator lifecycle and context management
+        │   ├── base_node.py
+        │   ├── document_node/
+        │   ├── hugegraph_node/
+        │   ├── index_node/
+        │   ├── llm_node/
+        │   └── util.py
+        ├── operators/    # Core operators and tasks (KG construction, GraphRAG, Text2Gremlin, ...)
+        ├── resources/    # Resource files (prompts, examples, Gremlin templates, ...)
+        ├── state/        # State management
+        ├── utils/        # Utilities and shared helpers
+        └── __init__.py   # Package init
+```
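+
+All of these flows share a single entry point: the scheduler singleton. As a usage
+sketch (mirroring the `prompt_generate` call site in the Web Demo patch further
+down; `texts` here is illustrative sample data):
+
+```python
+from hugegraph_llm.flows.scheduler import SchedulerSingleton
+
+# Obtain the process-wide scheduler and run a fixed workflow by name.
+# For "build_vector_index" the extra arguments are the texts to chunk and index.
+texts = ["HugeGraph is a fast and highly scalable graph database."]
+scheduler = SchedulerSingleton.get_instance()
+result_json = scheduler.schedule_flow("build_vector_index", texts)
+```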
+
+## Architecture Design
+
+### Overall Architecture
+
+> The new architecture introduces a Node layer between Flow and Operator. A Node manages its Operator's lifecycle, context binding, parameter-area decoupling, and concurrency safety. Every Flow is assembled from Nodes, and Operators focus solely on business logic; the sketch below illustrates this contract.
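+
+The sketch follows the BaseNode conventions defined in `nodes/base_node.py` later
+in this patch. `WordCount` and `WordCountNode` are hypothetical examples, not part
+of the codebase:
+
+```python
+from PyCGraph import CStatus
+
+from hugegraph_llm.nodes.base_node import BaseNode
+
+
+class WordCount:
+    """Hypothetical Operator: pure business logic, no framework calls."""
+
+    def run(self, data_json: dict) -> dict:
+        chunks = data_json.get("chunks", [])
+        return {"word_count": sum(len(chunk.split()) for chunk in chunks)}
+
+
+class WordCountNode(BaseNode):
+    """The Node owns lifecycle and context; the Operator stays framework-free."""
+
+    def node_init(self):
+        # Validate wk_input here and build run-time dependencies.
+        self.word_count_op = WordCount()
+        return CStatus()
+
+    def operator_schedule(self, data_json):
+        # BaseNode.run() snapshots wkflow_state into data_json, calls this
+        # method, and merges a returned dict back into the shared state.
+        return self.word_count_op.run(data_json)
+```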
+
+#### Architecture Diagram
+
+```mermaid
+graph TB
+    subgraph UserLayer["User Layer"]
+        User["User request"]
+    end
+
+    subgraph SchedulerLayer["Scheduling Layer"]
+        Scheduler["Scheduler"]
+        Singleton["SchedulerSingleton<br/>singleton manager"]
+    end
+
+    subgraph FlowLayer["Workflow Layer"]
+        Pool["pipeline_pool<br/>pipeline pool"]
+        BVI["BuildVectorIndexFlow<br/>vector index building"]
+        GE["GraphExtractFlow<br/>graph extraction"]
+    end
+
+    subgraph PyCGraphLayer["PyCGraph Layer"]
+        Manager1["GPipelineManager<br/>vector index manager"]
+        Manager2["GPipelineManager<br/>graph extraction manager"]
+        Pipeline1["GPipeline<br/>vector index pipeline"]
+        Pipeline2["GPipeline<br/>graph extraction pipeline"]
+    end
+
+    subgraph OperatorLayer["Operator Layer"]
+        ChunkSplit["ChunkSplitNode<br/>document chunking"]
+        BuildVector["BuildVectorIndexNode<br/>vector index building"]
+        SchemaNode["SchemaNode<br/>schema management"]
+        InfoExtract["ExtractNode<br/>information extraction"]
+        PropGraph["Commit2GraphNode<br/>graph data import"]
+        FetchNode["FetchGraphDataNode<br/>graph data fetching"]
+        SemanticIndex["BuildSemanticIndexNode<br/>semantic index building"]
+    end
+
+    subgraph StateLayer["State Layer"]
+        WkInput["wkflow_input<br/>workflow input"]
+        WkState["wkflow_state<br/>workflow state"]
+    end
+
+    User --> Scheduler
+    Scheduler --> Singleton
+    Scheduler --> Pool
+    Pool --> BVI
+    Pool --> GE
+    BVI --> Manager1
+    GE --> Manager2
+    Manager1 --> Pipeline1
+    Manager2 --> Pipeline2
+    Pipeline1 --> ChunkSplit
+    Pipeline1 --> BuildVector
+    Pipeline2 --> SchemaNode
+    Pipeline2 --> ChunkSplit
+    Pipeline2 --> InfoExtract
+    Pipeline2 --> PropGraph
+    Pipeline1 --> WkInput
+    Pipeline1 --> WkState
+    Pipeline2 --> WkInput
+    Pipeline2 --> WkState
+
+    style Scheduler fill:#e1f5fe
+    style Pool fill:#f3e5f5
+    style Manager1 fill:#fff3e0
+    style Manager2 fill:#fff3e0
+    style Pipeline1 fill:#e8f5e8
+    style Pipeline2 fill:#e8f5e8
+```
+
+#### Scheduling Flowchart
+
+```mermaid
+flowchart TD
+    Start([Start]) --> CheckFlow{Is the workflow<br/>supported?}
+    CheckFlow -->|No| Error1[Raise ValueError]
+    CheckFlow -->|Yes| FetchPipeline[Fetch reusable pipeline<br/>from the manager]
+
+    FetchPipeline --> IsNull{Is the pipeline<br/>null?}
+
+    IsNull -->|Yes| BuildNew[Build a new pipeline]
+    BuildNew --> InitPipeline[Initialize the pipeline]
+    InitPipeline --> InitCheck{Initialization<br/>succeeded?}
+    InitCheck -->|No| Error2[Log error and abort]
+    InitCheck -->|Yes| RunPipeline[Run the pipeline]
+    RunPipeline --> RunCheck{Run<br/>succeeded?}
+    RunCheck -->|No| Error3[Log error and abort]
+    RunCheck -->|Yes| PostDeal[Post-process the result]
+    PostDeal --> AddToPool[Add to the reuse pool]
+    AddToPool --> Return[Return the result]
+
+    IsNull -->|No| PrepareInput[Prepare input data]
+    PrepareInput --> RunReused[Run the reused pipeline]
+    RunReused --> ReusedCheck{Run<br/>succeeded?}
+    ReusedCheck -->|No| Error4[Raise RuntimeError]
+    ReusedCheck -->|Yes| PostDealReused[Post-process the result]
+    PostDealReused --> ReleasePipeline[Release the pipeline]
+    ReleasePipeline --> Return
+
+    Error1 --> End([End])
+    Error2 --> End
+    Error3 --> End
+    Error4 --> End
+    Return --> End
+
+    style Start fill:#4caf50
+    style End fill:#f44336
+    style CheckFlow fill:#ff9800
+    style IsNull fill:#ff9800
+    style InitCheck fill:#ff9800
+    style RunCheck fill:#ff9800
+    style ReusedCheck fill:#ff9800
+```
+
+### Core Components
+
+#### 1. Scheduler
+- **Responsibility**: the scheduling hub; maintains the `pipeline_pool` and exposes a unified workflow scheduling interface
+- **Features**:
+  - Supports multiple workflow types (build_vector_index, graph_extract, import_graph_data, update_vid_embeddings, get_graph_index_info, build_schema, prompt_generate, ...)
+  - Pipeline pooling with reuse
+  - Thread-safe singleton pattern
+  - Configurable maximum number of pipelines
+
+#### 2. GPipelineManager (pipeline manager)
+- **Origin**: provided by the PyCGraph framework
+- **Responsibility**: fetching, adding, releasing, and reusing `GPipeline` objects
+- **Features**:
+  - Automatic pipeline lifecycle management
+  - Pipeline reuse and resource reclamation
+  - fetch/add/release operation interfaces
+
+#### 3. BaseFlow (workflow base class)
+- **Responsibility**: abstraction for workflow construction and pre/post-processing
+- **Interface**:
+  - `prepare()`: pre-processing hook that prepares the input data
+  - `build_flow()`: assembles the Nodes and registers their dependencies
+  - `post_deal()`: post-processing hook that shapes the execution result
+- **Implementations**:
+  - `BuildVectorIndexFlow`: vector index building workflow
+  - `GraphExtractFlow`: graph extraction workflow
+  - `ImportGraphDataFlow`: graph data import workflow
+  - `UpdateVidEmbeddingsFlows`: vid embedding update workflow
+  - `GetGraphIndexInfoFlow`: graph index info retrieval workflow
+  - `BuildSchemaFlow`: schema building workflow
+  - `PromptGenerateFlow`: prompt generation workflow
+
+#### 4. Node (node scheduler)
+- **Responsibility**: manages the Operator's lifecycle, including parameter-area binding, context initialization, concurrency safety, and exception handling.
+- **Features**:
+  - Unified lifecycle interface (init, node_init, run, operator_schedule)
+  - Decoupled from Flow/Operator through the parameter areas (wkflow_input/wkflow_state)
+  - The Operator only implements run(data_json); the Node handles scheduling and writes results back
+  - Typical Nodes: ChunkSplitNode, BuildVectorIndexNode, SchemaNode, ExtractNode, Commit2GraphNode, FetchGraphDataNode, BuildSemanticIndexNode, SchemaBuildNode, PromptGenerateNode, ...
+
+#### 5. Operator
+- **Responsibility**: implements a concrete atomic business operation
+- **Features**:
+  - Focuses exclusively on its own business logic
+  - Scheduled uniformly by its Node
+
+#### 6. GPipeline (pipeline instance)
+- **Origin**: provided by the PyCGraph framework
+- **Responsibility**: a concrete pipeline instance holding the parameter areas and the node DAG topology
+- **Parameter areas**:
+  - `wkflow_input`: pipeline run input
+  - `wkflow_state`: pipeline run state and intermediate results
+
+### Core Data Structures
+
+```python
+# Core data structure of the Scheduler
+Scheduler.pipeline_pool: Dict[str, Any] = {
+    "build_vector_index": {
+        "manager": GPipelineManager(),
+        "flow": BuildVectorIndexFlow(),
+    },
+    "graph_extract": {
+        "manager": GPipelineManager(),
+        "flow": GraphExtractFlow(),
+    }
+}
+```
+
+### Scheduling Flow
+
+#### Execution flow of schedule_flow
+
+1. **Workflow validation**: check that `flow` is supported and look up its `manager` and `flow` instances
+
+2. **Pipeline acquisition**: fetch a reusable `GPipeline` via `manager.fetch()`
+
+3. **New pipeline path** (when fetch() returns None):
+   - Call `flow.build_flow(*args, **kwargs)` to build a new pipeline
+   - Call `pipeline.init()` to initialize; on failure, log the error and abort
+   - Call `pipeline.run()` to execute; abort on failure
+   - Call `flow.post_deal(pipeline)` to produce the output
+   - Call `manager.add(pipeline)` to put the pipeline into the reuse pool
+
+4. **Reused pipeline path** (when fetch() returns an existing pipeline):
+   - Get the input object from `pipeline.getGParamWithNoEmpty("wkflow_input")`
+   - Call `flow.prepare(prepared_input, *args, **kwargs)` to refresh the parameters
+   - Call `pipeline.run()` to execute; abort on failure
+   - Call `flow.post_deal(pipeline)` to produce the output
+   - Call `manager.release(pipeline)` to return the pipeline
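+
+Condensed into code, the two paths look roughly like this (a sketch of
+`schedule_flow` distilled from the steps above; logging is omitted):
+
+```python
+def schedule_flow(self, flow_name: str, *args, **kwargs):
+    if flow_name not in self.pipeline_pool:
+        raise ValueError(f"Unsupported workflow {flow_name}")
+    manager = self.pipeline_pool[flow_name]["manager"]
+    flow = self.pipeline_pool[flow_name]["flow"]
+
+    pipeline = manager.fetch()
+    if pipeline is None:
+        # New-pipeline path: build, init, run, post-process, then pool it.
+        pipeline = flow.build_flow(*args, **kwargs)
+        status = pipeline.init()
+        if status.isErr():
+            raise RuntimeError(f"Pipeline init failed: {status.getInfo()}")
+        status = pipeline.run()
+        if status.isErr():
+            raise RuntimeError(f"Pipeline run failed: {status.getInfo()}")
+        result = flow.post_deal(pipeline)
+        manager.add(pipeline)
+    else:
+        # Reuse path: refresh the input area, rerun, then return the pipeline.
+        prepared_input = pipeline.getGParamWithNoEmpty("wkflow_input")
+        flow.prepare(prepared_input, *args, **kwargs)
+        status = pipeline.run()
+        if status.isErr():
+            raise RuntimeError(f"Pipeline run failed: {status.getInfo()}")
+        result = flow.post_deal(pipeline)
+        manager.release(pipeline)
+    return result
+```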
+
+### Concurrency and Reuse Strategy
+
+#### Thread Safety
+- `SchedulerSingleton` uses double-checked locking to guarantee a global singleton
+- The `Scheduler` instance is obtained in a thread-safe way
+
+#### Resource Management
+- Each `flow` type owns an independent `GPipelineManager`
+- The maximum concurrency is bounded jointly by `Scheduler.max_pipeline` and the underlying `GPipelineManager` policy
+- The `fetch/add/release` mechanism reduces the cost of rebuilding graphs
+
+#### Performance Optimization
+- Pipeline reuse suits scenarios where the same workflow runs at high frequency
+- Cuts the time spent on repeated initialization and graph construction
+- Multiple workflow instances can execute concurrently
+
+### Error Handling and Logging
+
+#### Error Detection
+- `Status.isErr()` is checked after `init/run`
+- A `RuntimeError` is raised uniformly, with the detailed `status.getInfo()` recorded
+- Full error stack traces are provided
+
+#### Logging
+- Key operations are recorded through the unified logging system
+- Pipeline execution status and errors are logged
+- Multiple log levels are supported
+
+#### Result Handling
+- `flow.post_deal` converts `wkflow_state` into an externally consumable result (e.g. JSON)
+- A standardized output format is provided
+- Error information is presented in a readable form
+
+### Extension Guide
+
+#### Steps to add a new Node/Operator/Flow
+1. Implement the Operator business logic (e.g. ChunkSplit/BuildVectorIndex/InfoExtract)
+2. Implement the matching Node (subclass BaseNode; it binds the parameter areas and schedules the Operator)
+3. Assemble the Nodes in a Flow and register their dependencies
+4. Register the new Flow with the Scheduler (see the skeleton sketch below)
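+
+Putting the four steps together, a new workflow reduces to a small amount of glue
+code. The skeleton below reuses the hypothetical `WordCountNode` from the earlier
+sketch; the registration snippet assumes it runs inside `Scheduler.__init__`:
+
+```python
+import json
+
+from PyCGraph import GPipeline
+
+from hugegraph_llm.flows.common import BaseFlow
+from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState
+
+
+class WordCountFlow(BaseFlow):
+    def prepare(self, prepared_input: WkFlowInput, texts):
+        prepared_input.texts = texts
+
+    def build_flow(self, texts):
+        pipeline = GPipeline()
+        prepared_input = WkFlowInput()
+        self.prepare(prepared_input, texts)
+        pipeline.createGParam(prepared_input, "wkflow_input")
+        pipeline.createGParam(WkFlowState(), "wkflow_state")
+        # Single-node DAG; pass dependency sets to chain further nodes.
+        pipeline.registerGElement(WordCountNode(), set(), "word_count")
+        return pipeline
+
+    def post_deal(self, pipeline=None):
+        res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
+        return json.dumps(res, ensure_ascii=False, indent=2)
+
+
+# Step 4, inside Scheduler.__init__:
+# self.pipeline_pool["word_count"] = {
+#     "manager": GPipelineManager(),
+#     "flow": WordCountFlow(),
+# }
+```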
+
+#### Input/Output Conventions
+- `wkflow_input` is the uniform input carrier
+- `wkflow_state` is the uniform state and result container
+- A reusable pipeline must be quickly resettable between requests
+
+#### Best Practices
+- Keep Flow classes stateless
+- Use the pipeline reuse mechanism judiciously
+- Provide thorough error handling and logging
+- Follow the unified interface conventions
+
+## Flow Object Design
+
+### BaseFlow Abstract Base Class
+
+```python
+class BaseFlow(ABC):
+    """
+    Base class for flows, defines three interface methods: prepare, build_flow, and post_deal.
+    """
+
+    @abstractmethod
+    def prepare(self, prepared_input: WkFlowInput, *args, **kwargs):
+        """
+        Pre-processing interface.
+        """
+        pass
+
+    @abstractmethod
+    def build_flow(self, *args, **kwargs):
+        """
+        Interface for building the flow.
+        """
+        pass
+
+    @abstractmethod
+    def post_deal(self, *args, **kwargs):
+        """
+        Post-processing interface.
+        """
+        pass
+```
+
+### Interface Notes
+
+Every Flow object implements three core interfaces:
+
+- **prepare**: prepares the input data for the whole workflow and sets its parameters
+- **build_flow**: builds the whole workflow pipeline, registering nodes and dependencies
+- **post_deal**: processes the workflow's execution result into the external output format
+
+### Example Implementations
+
+#### BuildVectorIndexFlow (vector index building workflow)
+
+```python
+class BuildVectorIndexFlow(BaseFlow):
+    def __init__(self):
+        pass
+
+    def prepare(self, prepared_input: WkFlowInput, texts):
+        prepared_input.texts = texts
+        prepared_input.language = "zh"
+        prepared_input.split_type = "paragraph"
+        return
+
+    def build_flow(self, texts):
+        pipeline = GPipeline()
+        # prepare for workflow input
+        prepared_input = WkFlowInput()
+        self.prepare(prepared_input, texts)
+
+        pipeline.createGParam(prepared_input, "wkflow_input")
+        pipeline.createGParam(WkFlowState(), "wkflow_state")
+
+        chunk_split_node = ChunkSplitNode()
+        build_vector_node = BuildVectorIndexNode()
+        pipeline.registerGElement(chunk_split_node, set(), "chunk_split")
+        pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector")
+
+        return pipeline
+
+    def post_deal(self, pipeline=None):
+        res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
+        return json.dumps(res, ensure_ascii=False, indent=2)
+```
+
+#### GraphExtractFlow (graph extraction workflow)
+
+```python
+class GraphExtractFlow(BaseFlow):
+    def __init__(self):
+        pass
+
+    def prepare(self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type):
+        prepared_input.texts = texts
+        prepared_input.language = "zh"
+        prepared_input.split_type = "document"
+        prepared_input.example_prompt = example_prompt
+        prepared_input.schema = schema
+        prepare_schema(prepared_input, schema)
+        return
+
+    def build_flow(self, schema, texts, example_prompt, extract_type):
+        pipeline = GPipeline()
+        prepared_input = WkFlowInput()
+        self.prepare(prepared_input, schema, texts, example_prompt, extract_type)
+
+        pipeline.createGParam(prepared_input, "wkflow_input")
+        pipeline.createGParam(WkFlowState(), "wkflow_state")
+        schema_node = SchemaNode()
+
+        chunk_split_node = ChunkSplitNode()
+        graph_extract_node = ExtractNode()
+
+        pipeline.registerGElement(schema_node, set(), "schema_node")
+        pipeline.registerGElement(chunk_split_node, set(), "chunk_split")
+        pipeline.registerGElement(
+            graph_extract_node, {schema_node, chunk_split_node}, "graph_extract"
+        )
+
+        return pipeline
+
+    def post_deal(self, pipeline=None):
+        res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
+        vertices = res.get("vertices", [])
+        edges = res.get("edges", [])
+        if not vertices and not edges:
+            log.info("Please check the schema.(The schema may not match the Doc)")
+            return json.dumps(
+                {
+                    "vertices": vertices,
+                    "edges": edges,
+                    "warning": "The schema may not match the Doc",
+                },
+                ensure_ascii=False,
+                indent=2,
+            )
+        return json.dumps(
+            {"vertices": vertices, "edges": edges},
+            ensure_ascii=False,
+            indent=2,
+        )
+```
+
+## Node Object Design
+
+### Node Lifecycle
+
+Nodes use GNode as the abstract base class, with a unified lifecycle and status returns. Method responsibilities and conventions:
+
+#### Initialization Phase
+
+- **init()**:
+  - **Responsibility**: node-level initialization (binding the shared context, preparing the parameter areas) so the node has the minimal environment it needs to run
+  - **Convention**: lightweight initialization only, no business logic; the returned status decides whether execution continues
+
+- **node_init()**:
+  - **Responsibility**: parse and validate the inputs for this run (usually from wk_input) and build run-time dependencies (internal configuration, transformers, resource handles)
+  - **Convention**: on missing or invalid input, return an error status and abort the rest of the run; produce no externally visible business result
+
+#### Run Phase
+
+- **run()**:
+  - **Responsibility**: execute the main business flow (pure computation or I/O) and, on completion, write the node's output into the shared state (wkflow_state/context)
+  - **Conventions**:
+    - Call node_init() first and check its returned status
+    - Writes to the shared state must follow the concurrency-safety conventions (lock/unlock)
+    - Report errors through the unified status return; do not let uncaught exceptions reach the orchestration layer
+
+### Input/Output and Context Conventions
+
+- **Input**: pre-populated by the orchestration layer in the parameter area (wk_input); the node reads and validates it in node_init()
+- **Output**: exposed through the shared state container (wkflow_state/context); keys and fields should be stable and predictable for downstream nodes
+
+### Error Handling Conventions
+
+- Success/failure and details are expressed uniformly through status objects; return errors early and avoid further side effects inside run()
+- Predictable validation errors should carry explicit messages to ease debugging and orchestration-layer logging
+
+### Concurrency and Reentrancy Conventions
+
+- Writes to the shared state must happen inside a critical section; whether reads need locking depends on the consistency requirements
+- Nodes should stay free of side effects where possible, or confine them to a controlled scope, so they can be retried and reused
+
+### Testability and Decoupling
+
+- Pure business logic should be decoupled from framework interaction, preferably as unit-testable pure functions or internal methods
+- The node handles only lifecycle orchestration and context reads/writes; concrete strategies and algorithms live in replaceable internal components
+
+### Node Types
+
+#### Document Processing Nodes
+- **ChunkSplitNode**: document chunking node
+  - Function: split the input documents according to the chosen strategy
+  - Input: raw document text
+  - Output: document chunks
+
+#### Index Building Nodes
+- **BuildVectorIndexNode**: vector index building node
+  - Function: build the vector index from document chunks
+  - Input: document chunks
+  - Output: vector index data
+
+#### Schema Management Nodes
+- **SchemaManagerNode**: graph schema management node
+  - Function: fetch the graph schema from HugeGraph
+  - Input: graph name
+  - Output: graph schema definition
+
+- **CheckSchemaNode**: schema validation node
+  - Function: validate a user-defined graph schema
+  - Input: user-defined JSON schema
+  - Output: validated schema definition
+
+#### Graph Extraction Nodes
+- **InfoExtractNode**: information extraction node
+  - Function: extract triples from documents
+  - Input: document chunks and schema definition
+  - Output: extracted triples
+
+- **PropertyGraphExtractNode**: property graph extraction node
+  - Function: extract a property graph structure from documents
+  - Input: document chunks and schema definition
+  - Output: extracted vertices and edges
+
+#### Schema Building Node
+- **SchemaBuildNode**: schema building node
+  - Function: build a graph schema from documents and query examples
+  - Input: document text, query examples, few-shot schema
+  - Output: the constructed graph schema definition
+
+#### Prompt Generation Node
+- **PromptGenerateNode**: prompt generation node
+  - Function: generate a prompt from the source text, scenario, and example name
+  - Input: source text, scenario, example name
+  - Output: the generated prompt
+
+
+## Testing Strategy
+
+### Test Goal
+
+The primary goal of the current testing strategy is to guarantee that the ported workflows produce the same results and exhibit the same program behavior as the pre-port workflows.
+
+### Test Scope
+
+#### 1. Functional Tests
+- **Result consistency**: workflow results under the new architecture must exactly match the original implementation
+- **I/O format validation**: input-parameter handling and output-format conversion are correct
+- **Error handling**: behavior in error scenarios matches expectations
+
+#### 2. Performance Tests
+- **Pipeline reuse**: verify the performance gain of the reuse mechanism
+- **Concurrent execution**: stability and performance of multiple workflows running concurrently
+- **Resource usage**: monitor memory and CPU usage and keep them reasonable
+
+#### 3. Stability Tests
+- **Long-running tests**: system stability over extended runs
+- **Recovery tests**: recovery behavior under failures
+- **Memory-leak tests**: pipeline reuse must not leak memory
+
+### Test Methods
+
+#### 1. Unit Tests
+- Unit tests for every Flow class
+- Unit tests for every Node class
+- Tests for the Scheduler's scheduling logic
+
+#### 2. Integration Tests
+- End-to-end workflow tests
+- Multi-workflow combination tests
+- Integration tests with external systems
+
+#### 3. Performance Benchmarks
+- Establish a performance baseline
+- Compare performance between the old and new architectures
+- Monitor key performance indicators
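+
+As a concrete example of the behavior-parity goal, a consistency test could be
+shaped as follows (a sketch with hypothetical pytest fixtures; `legacy_result_json`
+is assumed to capture the pre-port output for the same input):
+
+```python
+import json
+
+from hugegraph_llm.flows.scheduler import SchedulerSingleton
+
+
+def test_build_vector_index_matches_legacy(legacy_result_json, sample_texts):
+    scheduler = SchedulerSingleton.get_instance()
+    ported = scheduler.schedule_flow("build_vector_index", sample_texts)
+    # Behavioral parity: the ported workflow must produce the same payload
+    # as the pre-port implementation.
+    assert json.loads(ported) == json.loads(legacy_result_json)
+```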
+
+### Test Data
+
+#### 1. Standard Test Data Sets
+- Standardized test documents
+- Standardized graph schema definitions
+- Standardized expected outputs
+
+#### 2. Boundary Test Data
+- Empty-input tests
+- Large-file tests
+- Special-character tests
+- Malformed-input tests
+
+### Test Environments
+
+#### 1. Development Environment
+- Functional verification in the local development environment
+- Fast iteration testing
+
+#### 2. Staging Environment
+- Full tests simulating the production environment
+- Performance and stress tests
+
+#### 3. Production Verification
+- Canary-release verification
+- Production monitoring
+
+### Test Automation
+
+#### 1. CI/CD Integration
+- Automated test pipeline integration
+- Tests triggered on code commit
+- Automatic test reports
+
+#### 2. Regression Tests
+- Regular regression runs
+- New features must not break existing functionality
+- Performance regression detection
+
+### Test Metrics
+
+#### 1. Functional Metrics
+- Test coverage > 90%
+- Functional correctness 100%
+- Error-handling coverage > 95%
+
+#### 2. Performance Metrics
+- Response time improved by > 20%
+- Throughput improved by > 30%
+- Resource usage improved by > 15%
+
+#### 3. Stability Metrics
+- System availability > 99.9%
+- Mean time to recovery < 5 minutes
+- Memory leak rate = 0%
diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md b/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md
new file mode 100644
index 00000000..09502736
--- /dev/null
+++ b/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md
@@ -0,0 +1,24 @@
+## Requirements List
+
+### Core Framework Design
+
+**Core**: design and implementation of `schedule_flow` in the Scheduler class
+
+**Acceptance criteria**:
+1.1. The core framework reuses resources as much as possible and avoids repeated allocation and release
+1.2. Normal request-handling metrics must be met
+1.3. The overall resource ceiling used by the framework must be configurable
+
+### Fixed Workflow Porting
+
+**Core**: port all use cases in the Web Demo
+2.1. It is sufficient that workflows ported onto the core framework behave exactly as they did before the port
+
+**Completed workflow types**:
+- build_vector_index: vector index building workflow
+- graph_extract: graph extraction workflow
+- import_graph_data: graph data import workflow
+- update_vid_embeddings: vid embedding update workflow
+- get_graph_index_info: graph index info retrieval workflow
+- build_schema: schema building workflow
+- prompt_generate: prompt generation workflow
diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md b/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md
new file mode 100644
index 00000000..a84aee2f
--- /dev/null
+++ b/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md
@@ -0,0 +1,36 @@
+# HugeGraph-ai Fixed Workflow Framework Design and Use-Case Porting
+
+This document turns the HugeGraph fixed workflow framework design and use-case porting into a series of executable coding tasks.
+
+## 1. schedule_flow Design and Implementation
+
+- [x] **1.1 Build Scheduler framework 1.0**
+  - Must reuse already-created pipelines (pipeline pooling)
+  - Uses CGraph (a graph-based engine) as the underlying execution engine
+  - Loose coupling between Nodes
+
+- [ ] **1.2 Optimize Scheduler framework resource configuration**
+  - Let users configure the underlying thread-pool parameters
+  - Existing workflows may vary slightly with the input, so the same use case can yield different workflows; how should this be solved?
+  - Decouple Node/Operator: the Node owns lifecycle and context, the Operator owns only business logic
+  - The Flow only assembles Nodes; all business logic sinks into Node/Operator
+  - The Scheduler supports registering multiple Flow types, with a more flexible registration mechanism
+
+- [ ] **1.3 Optimize Scheduler framework resource usage**
+  - Control the number of pipelines each PipelineManager holds based on load, with dynamic scaling
+  - The Node layer supports automatic parameter-area binding and concurrency safety
+  - The Operator only implements run(data_json); the Node handles scheduling and result write-back
+
+## 2. Fixed Workflow Use-Case Porting
+
+- [x] **2.1 Port the build_vector_index workflow**
+- [x] **2.2 Port the graph_extract workflow**
+- [x] **2.3 Port the import_graph_data workflow**
+  - Implement the import_graph_data workflow on the Node/Operator mechanism
+- [x] **2.4 Port the update_vid_embeddings workflow**
+  - Implement the update_vid_embeddings workflow on the Node/Operator mechanism
+- [x] **2.5 Port the get_graph_index_info workflow**
+- [x] **2.6 Port the build_schema workflow**
+  - Implement the build_schema workflow on the Node/Operator mechanism
+- [x] **2.7 Port the prompt_generate workflow**
+  - Implement the prompt_generate workflow on the Node/Operator mechanism
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py
index 9897f420..4aa47694 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py
@@ -26,8 +26,7 @@ import gradio as gr
 from hugegraph_llm.config import huge_settings
 from hugegraph_llm.config import prompt
 from hugegraph_llm.config import resource_path
-from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.operators.llm_op.prompt_generate import PromptGenerate
+from hugegraph_llm.flows.scheduler import SchedulerSingleton
 from hugegraph_llm.utils.graph_index_utils import (
     get_graph_index_info,
     clean_all_graph_index,
@@ -61,7 +60,7 @@ def store_prompt(doc, schema, example_prompt):
 
 def generate_prompt_for_ui(source_text, scenario, example_name):
     """
-    Handles the UI logic for generating a new prompt. It calls the PromptGenerate operator.
+    Handles the UI logic for generating a new prompt using the new workflow architecture.
     """
     if not all([source_text, scenario, example_name]):
         gr.Warning(
@@ -69,19 +68,13 @@ def generate_prompt_for_ui(source_text, scenario, example_name):
         )
         return gr.update()
     try:
-        prompt_generator = PromptGenerate(llm=LLMs().get_chat_llm())
-        context = {
-            "source_text": source_text,
-            "scenario": scenario,
-            "example_name": example_name,
-        }
-        result_context = prompt_generator.run(context)
-        # Presents the result of generating prompt
-        generated_prompt = result_context.get(
-            "generated_extract_prompt", "Generation failed. Please check the logs."
+ # using new architecture + scheduler = SchedulerSingleton.get_instance() + result = scheduler.schedule_flow( + "prompt_generate", source_text, scenario, example_name ) gr.Info("Prompt generated successfully!") - return generated_prompt + return result except Exception as e: log.error("Error generating Prompt: %s", e, exc_info=True) raise gr.Error(f"Error generating Prompt: {e}") from e diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py new file mode 100644 index 00000000..6bbcb851 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.llm_node.schema_build import SchemaBuildNode +from hugegraph_llm.utils.log import log + +import json +from PyCGraph import GPipeline + + +class BuildSchemaFlow(BaseFlow): + def __init__(self): + pass + + def prepare( + self, + prepared_input: WkFlowInput, + texts=None, + query_examples=None, + few_shot_schema=None, + ): + prepared_input.texts = texts + # Optional fields packed into wk_input for SchemaBuildNode + # Keep raw values; node will parse if strings + prepared_input.query_examples = query_examples + prepared_input.few_shot_schema = few_shot_schema + return + + def build_flow(self, texts=None, query_examples=None, few_shot_schema=None): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare( + prepared_input, + texts=texts, + query_examples=query_examples, + few_shot_schema=few_shot_schema, + ) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + schema_build_node = SchemaBuildNode() + pipeline.registerGElement(schema_build_node, set(), "schema_build") + + return pipeline + + def post_deal(self, pipeline=None): + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + if "schema" not in state_json: + return "" + res = state_json["schema"] + try: + formatted_schema = json.dumps(res, ensure_ascii=False, indent=2) + return formatted_schema + except (TypeError, ValueError) as e: + log.error("Failed to format schema: %s", e) + return str(res) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py index f1ee8c1c..9a07b5db 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py @@ -14,13 +14,13 @@ # limitations under the License. 
from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode +from hugegraph_llm.nodes.index_node.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowInput import json from PyCGraph import GPipeline -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowState diff --git a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py new file mode 100644 index 00000000..fa10d019 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from hugegraph_llm.config import huge_settings, llm_settings, resource_path +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.indices.vector_index import VectorIndex +from hugegraph_llm.models.embeddings.init_embedding import model_map +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode +from PyCGraph import GPipeline +from hugegraph_llm.utils.embedding_utils import ( + get_filename_prefix, + get_index_folder_name, +) + + +class GetGraphIndexInfoFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + return + + def build_flow(self, *args, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, *args, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + fetch_node = FetchGraphDataNode() + pipeline.registerGElement(fetch_node, set(), "fetch_node") + return pipeline + + def post_deal(self, pipeline=None): + graph_summary_info = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, + model_map.get(llm_settings.embedding_type, None), + ) + try: + vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) + except FileNotFoundError: + return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) + graph_summary_info["vid_index"] = { + "embed_dim": vector_index.index.d, + "num_vectors": vector_index.index.ntotal, + "num_vids": len(vector_index.properties), + } + return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) diff --git 
a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index f1a6c5f6..1b0c9825 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -16,14 +16,10 @@ import json from PyCGraph import GPipeline from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.llm_node.extract_info import ExtractNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.operators.common_op.check_schema import CheckSchemaNode -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManagerNode -from hugegraph_llm.operators.llm_op.info_extract import InfoExtractNode -from hugegraph_llm.operators.llm_op.property_graph_extract import ( - PropertyGraphExtractNode, -) from hugegraph_llm.utils.log import log @@ -31,21 +27,6 @@ class GraphExtractFlow(BaseFlow): def __init__(self): pass - def _import_schema( - self, - from_hugegraph=None, - from_extraction=None, - from_user_defined=None, - ): - if from_hugegraph: - return SchemaManagerNode() - elif from_user_defined: - return CheckSchemaNode() - elif from_extraction: - raise NotImplementedError("Not implemented yet") - else: - raise ValueError("No input data / invalid schema type") - def prepare( self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type ): @@ -55,17 +36,7 @@ class GraphExtractFlow(BaseFlow): prepared_input.split_type = "document" prepared_input.example_prompt = example_prompt prepared_input.schema = schema - schema = schema.strip() - if schema.startswith("{"): - try: - schema = json.loads(schema) - prepared_input.schema = schema - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", schema) - prepared_input.graph_name = schema + prepared_input.extract_type = extract_type return def build_flow(self, schema, texts, example_prompt, extract_type): @@ -76,27 +47,10 @@ class GraphExtractFlow(BaseFlow): pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") - schema = schema.strip() - schema_node = None - if schema.startswith("{"): - try: - schema = json.loads(schema) - schema_node = self._import_schema(from_user_defined=schema) - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. 
Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", schema) - schema_node = self._import_schema(from_hugegraph=schema) + schema_node = SchemaNode() chunk_split_node = ChunkSplitNode() - graph_extract_node = None - if extract_type == "triples": - graph_extract_node = InfoExtractNode() - elif extract_type == "property_graph": - graph_extract_node = PropertyGraphExtractNode() - else: - raise ValueError(f"Unsupported extract_type: {extract_type}") + graph_extract_node = ExtractNode() pipeline.registerGElement(schema_node, set(), "schema_node") pipeline.registerGElement(chunk_split_node, set(), "chunk_split") pipeline.registerGElement( diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py similarity index 50% copy from hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py copy to hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py index f1ee8c1c..5581ef10 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py @@ -13,43 +13,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.state.ai_state import WkFlowInput - import json -from PyCGraph import GPipeline -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode -from hugegraph_llm.state.ai_state import WkFlowState +import gradio as gr +from PyCGraph import GPipeline +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.hugegraph_node.commit_to_hugegraph import Commit2GraphNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.log import log -class BuildVectorIndexFlow(BaseFlow): +class ImportGraphDataFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, texts): - prepared_input.texts = texts - prepared_input.language = "zh" - prepared_input.split_type = "paragraph" + def prepare(self, prepared_input: WkFlowInput, data, schema): + try: + data_json = json.loads(data.strip()) if isinstance(data, str) else data + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON for 'data': {e.msg}") from e + log.debug( + "Import graph data (truncated): %s", + (data[:512] + "...") + if isinstance(data, str) and len(data) > 512 + else (data if isinstance(data, str) else "<obj>"), + ) + prepared_input.data_json = data_json + prepared_input.schema = schema return - def build_flow(self, texts): + def build_flow(self, data, schema): pipeline = GPipeline() - # prepare for workflow input prepared_input = WkFlowInput() - self.prepare(prepared_input, texts) + # prepare input data + self.prepare(prepared_input, data, schema) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") - chunk_split_node = ChunkSplitNode() - build_vector_node = BuildVectorIndexNode() - pipeline.registerGElement(chunk_split_node, set(), "chunk_split") - pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector") + schema_node = SchemaNode() + commit_node = Commit2GraphNode() + pipeline.registerGElement(schema_node, set(), "schema_node") + 
pipeline.registerGElement(commit_node, {schema_node}, "commit_node") return pipeline def post_deal(self, pipeline=None): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + gr.Info("Import graph data successfully!") return json.dumps(res, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py similarity index 56% copy from hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py copy to hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py index f1ee8c1c..aece6bd6 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py @@ -14,42 +14,50 @@ # limitations under the License. from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.llm_node.prompt_generate import PromptGenerateNode from hugegraph_llm.state.ai_state import WkFlowInput -import json from PyCGraph import GPipeline -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowState -class BuildVectorIndexFlow(BaseFlow): +class PromptGenerateFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, texts): - prepared_input.texts = texts - prepared_input.language = "zh" - prepared_input.split_type = "paragraph" + def prepare(self, prepared_input: WkFlowInput, source_text, scenario, example_name): + """ + Prepare input data for PromptGenerate workflow + """ + prepared_input.source_text = source_text + prepared_input.scenario = scenario + prepared_input.example_name = example_name return - def build_flow(self, texts): + def build_flow(self, source_text, scenario, example_name): + """ + Build the PromptGenerate workflow + """ pipeline = GPipeline() - # prepare for workflow input + # Prepare workflow input prepared_input = WkFlowInput() - self.prepare(prepared_input, texts) + self.prepare(prepared_input, source_text, scenario, example_name) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") - chunk_split_node = ChunkSplitNode() - build_vector_node = BuildVectorIndexNode() - pipeline.registerGElement(chunk_split_node, set(), "chunk_split") - pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector") + # Create PromptGenerate node + prompt_generate_node = PromptGenerateNode() + pipeline.registerGElement(prompt_generate_node, set(), "prompt_generate") return pipeline def post_deal(self, pipeline=None): + """ + Process the execution result of PromptGenerate workflow + """ res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - return json.dumps(res, ensure_ascii=False, indent=2) + return res.get( + "generated_extract_prompt", "Generation failed. Please check the logs." 
+ ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py index b096310d..559540ce 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -15,10 +15,15 @@ import threading from typing import Dict, Any -from PyCGraph import GPipelineManager +from PyCGraph import GPipeline, GPipelineManager from hugegraph_llm.flows.build_vector_index import BuildVectorIndexFlow from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.flows.graph_extract import GraphExtractFlow +from hugegraph_llm.flows.import_graph_data import ImportGraphDataFlow +from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlows +from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow +from hugegraph_llm.flows.build_schema import BuildSchemaFlow +from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow from hugegraph_llm.utils.log import log @@ -37,6 +42,26 @@ class Scheduler: "manager": GPipelineManager(), "flow": GraphExtractFlow(), } + self.pipeline_pool["import_graph_data"] = { + "manager": GPipelineManager(), + "flow": ImportGraphDataFlow(), + } + self.pipeline_pool["update_vid_embeddings"] = { + "manager": GPipelineManager(), + "flow": UpdateVidEmbeddingsFlows(), + } + self.pipeline_pool["get_graph_index_info"] = { + "manager": GPipelineManager(), + "flow": GetGraphIndexInfoFlow(), + } + self.pipeline_pool["build_schema"] = { + "manager": GPipelineManager(), + "flow": BuildSchemaFlow(), + } + self.pipeline_pool["prompt_generate"] = { + "manager": GPipelineManager(), + "flow": PromptGenerateFlow(), + } self.max_pipeline = max_pipeline # TODO: Implement Agentic Workflow @@ -46,9 +71,9 @@ class Scheduler: def schedule_flow(self, flow: str, *args, **kwargs): if flow not in self.pipeline_pool: raise ValueError(f"Unsupported workflow {flow}") - manager = self.pipeline_pool[flow]["manager"] + manager: GPipelineManager = self.pipeline_pool[flow]["manager"] flow: BaseFlow = self.pipeline_pool[flow]["flow"] - pipeline = manager.fetch() + pipeline: GPipeline = manager.fetch() if pipeline is None: # call coresponding flow_func to create new workflow pipeline = flow.build_flow(*args, **kwargs) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py similarity index 52% copy from hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py copy to hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py index f1ee8c1c..b3f0d992 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py @@ -13,43 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.state.ai_state import WkFlowInput - -import json -from PyCGraph import GPipeline - -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode +from PyCGraph import CStatus, GPipeline +from hugegraph_llm.flows.common import BaseFlow, WkFlowInput +from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode +from hugegraph_llm.nodes.index_node.build_semantic_index import BuildSemanticIndexNode from hugegraph_llm.state.ai_state import WkFlowState -class BuildVectorIndexFlow(BaseFlow): - def __init__(self): - pass - - def prepare(self, prepared_input: WkFlowInput, texts): - prepared_input.texts = texts - prepared_input.language = "zh" - prepared_input.split_type = "paragraph" - return +class UpdateVidEmbeddingsFlows(BaseFlow): + def prepare(self, prepared_input: WkFlowInput): + return CStatus() - def build_flow(self, texts): + def build_flow(self): pipeline = GPipeline() - # prepare for workflow input prepared_input = WkFlowInput() - self.prepare(prepared_input, texts) + # prepare input data + self.prepare(prepared_input) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") - chunk_split_node = ChunkSplitNode() - build_vector_node = BuildVectorIndexNode() - pipeline.registerGElement(chunk_split_node, set(), "chunk_split") - pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector") + fetch_node = FetchGraphDataNode() + build_node = BuildSemanticIndexNode() + pipeline.registerGElement(fetch_node, set(), "fetch_node") + pipeline.registerGElement(build_node, {fetch_node}, "build_node") return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - return json.dumps(res, ensure_ascii=False, indent=2) + removed_num = res.get("removed_vid_vector_num", 0) + added_num = res.get("added_vid_vector_num", 0) + return f"Removed {removed_num} vectors, added {added_num} vectors." diff --git a/hugegraph-llm/src/hugegraph_llm/flows/utils.py b/hugegraph-llm/src/hugegraph_llm/flows/utils.py new file mode 100644 index 00000000..b4ba05c8 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/utils.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +from hugegraph_llm.state.ai_state import WkFlowInput +from hugegraph_llm.utils.log import log + + +def prepare_schema(prepared_input: WkFlowInput, schema): + schema = schema.strip() + if schema.startswith("{"): + try: + schema = json.loads(schema) + prepared_input.schema = schema + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", schema) + prepared_input.graph_name = schema + return diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py new file mode 100644 index 00000000..0ea0675c --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import GNode, CStatus +from hugegraph_llm.nodes.util import init_context +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BaseNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + """ + Node initialization method, can be overridden by subclasses. + Returns a CStatus object indicating whether initialization succeeded. + """ + return CStatus() + + def run(self): + """ + Main logic for node execution, can be overridden by subclasses. + Returns a CStatus object indicating whether execution succeeded. + """ + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + data_json = self.context.to_json() + finally: + self.context.unlock() + + try: + res = self.operator_schedule(data_json) + except Exception as exc: + import traceback + + node_info = f"Node type: {type(self).__name__}, Node object: {self}" + err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}" + return CStatus(-1, err_msg) + + self.context.lock() + try: + if isinstance(res, dict): + self.context.assign_from_json(res) + finally: + self.context.unlock() + return CStatus() + + def operator_schedule(self, data_json): + """ + Interface for scheduling the operator, can be overridden by subclasses. + Returns a CStatus object indicating whether scheduling succeeded. + """ + pass diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py new file mode 100644 index 00000000..4c5acbe9 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hugegraph_llm.nodes.base_node import BaseNode +from PyCGraph import CStatus +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class ChunkSplitNode(BaseNode): + chunk_split_op: ChunkSplit + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + if ( + self.wk_input.texts is None + or self.wk_input.language is None + or self.wk_input.split_type is None + ): + return CStatus(-1, "Error occurs when prepare for workflow input") + texts = self.wk_input.texts + language = self.wk_input.language + split_type = self.wk_input.split_type + if isinstance(texts, str): + texts = [texts] + self.chunk_split_op = ChunkSplit(texts, split_type, language) + return CStatus() + + def operator_schedule(self, data_json): + return self.chunk_split_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py new file mode 100644 index 00000000..b576e817 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class Commit2GraphNode(BaseNode): + commit_to_graph_op: Commit2Graph + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + data_json = self.wk_input.data_json if self.wk_input.data_json else None + if data_json: + self.context.assign_from_json(data_json) + self.commit_to_graph_op = Commit2Graph() + return CStatus() + + def operator_schedule(self, data_json): + return self.commit_to_graph_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py new file mode 100644 index 00000000..b2434e52 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.hugegraph_utils import get_hg_client + + +class FetchGraphDataNode(BaseNode): + fetch_graph_data_op: FetchGraphData + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.fetch_graph_data_op = FetchGraphData(get_hg_client()) + return CStatus() + + def operator_schedule(self, data_json): + return self.fetch_graph_data_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py new file mode 100644 index 00000000..71c490b2 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.common_op.check_schema import CheckSchema +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.log import log + + +class SchemaNode(BaseNode): + schema_manager: SchemaManager + check_schema: CheckSchema + context: WkFlowState = None + wk_input: WkFlowInput = None + + schema = None + + def _import_schema( + self, + from_hugegraph=None, + from_extraction=None, + from_user_defined=None, + ): + if from_hugegraph: + return SchemaManager(from_hugegraph) + elif from_user_defined: + return CheckSchema(from_user_defined) + elif from_extraction: + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("No input data / invalid schema type") + + def node_init(self): + self.schema = self.wk_input.schema + self.schema = self.schema.strip() + if self.schema.startswith("{"): + try: + schema = json.loads(self.schema) + self.check_schema = self._import_schema(from_user_defined=schema) + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", self.schema) + self.schema_manager = self._import_schema(from_hugegraph=self.schema) + return CStatus() + + def operator_schedule(self, data_json): + print(f"check data json {data_json}") + if self.schema.startswith("{"): + try: + return self.check_schema.run(data_json) + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", self.schema) + return self.schema_manager.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py new file mode 100644 index 00000000..ab31fa39 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BuildSemanticIndexNode(BaseNode): + build_semantic_index_op: BuildSemanticIndex + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.build_semantic_index_op = BuildSemanticIndex(get_embedding(llm_settings)) + return CStatus() + + def operator_schedule(self, data_json): + return self.build_semantic_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py new file mode 100644 index 00000000..cf2f9b67 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BuildVectorIndexNode(BaseNode): + build_vector_index_op: BuildVectorIndex + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.build_vector_index_op = BuildVectorIndex(get_embedding(llm_settings)) + return CStatus() + + def operator_schedule(self, data_json): + return self.build_vector_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py new file mode 100644 index 00000000..8bceed80 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.

from PyCGraph import CStatus
from hugegraph_llm.config import llm_settings
from hugegraph_llm.models.llms.init_llm import get_chat_llm
from hugegraph_llm.nodes.base_node import BaseNode
from hugegraph_llm.operators.llm_op.info_extract import InfoExtract
from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState


class ExtractNode(BaseNode):
    property_graph_extract: PropertyGraphExtract
    info_extract: InfoExtract
    context: WkFlowState = None
    wk_input: WkFlowInput = None

    extract_type: str = None

    def node_init(self):
        llm = get_chat_llm(llm_settings)
        if self.wk_input.example_prompt is None:
            return CStatus(-1, "Missing required workflow input: example_prompt")
        example_prompt = self.wk_input.example_prompt
        self.extract_type = self.wk_input.extract_type
        if self.extract_type == "triples":
            self.info_extract = InfoExtract(llm, example_prompt)
        elif self.extract_type == "property_graph":
            self.property_graph_extract = PropertyGraphExtract(llm, example_prompt)
        else:
            return CStatus(-1, f"Unsupported extract_type: {self.extract_type}")
        return CStatus()

    def operator_schedule(self, data_json):
        if self.extract_type == "triples":
            return self.info_extract.run(data_json)
        # node_init() has already rejected anything other than the two
        # supported types, so this branch must be "property_graph".
        return self.property_graph_extract.run(data_json)
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py
new file mode 100644
index 00000000..317f9e6a
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
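`ExtractNode` is deliberately a thin router: `node_init()` instantiates exactly one of the two legacy operators based on `wk_input.extract_type`, and `operator_schedule()` just delegates to it. The same pattern in isolation, with hypothetical stub classes standing in for `InfoExtract` and `PropertyGraphExtract`:

```python
class TriplesExtractStub:
    """Hypothetical stand-in for InfoExtract."""

    def run(self, data):
        return {"triples": [], "source": data}


class PropertyGraphExtractStub:
    """Hypothetical stand-in for PropertyGraphExtract."""

    def run(self, data):
        return {"vertices": [], "edges": [], "source": data}


OPERATORS = {"triples": TriplesExtractStub, "property_graph": PropertyGraphExtractStub}


def build_extractor(extract_type: str):
    # Reject unknown types early, as ExtractNode.node_init() does
    if extract_type not in OPERATORS:
        raise ValueError(f"Unsupported extract_type: {extract_type}")
    return OPERATORS[extract_type]()


extractor = build_extractor("triples")
print(extractor.run({"chunks": ["HugeGraph is a graph database."]}))
```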
+

from PyCGraph import CStatus
from hugegraph_llm.config import llm_settings
from hugegraph_llm.models.llms.init_llm import get_chat_llm
from hugegraph_llm.nodes.base_node import BaseNode
from hugegraph_llm.operators.llm_op.prompt_generate import PromptGenerate
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState


class PromptGenerateNode(BaseNode):
    prompt_generate: PromptGenerate
    context: WkFlowState = None
    wk_input: WkFlowInput = None

    def node_init(self):
        """Validate the required inputs and initialize the PromptGenerate operator."""
        llm = get_chat_llm(llm_settings)
        if not all(
            [
                self.wk_input.source_text,
                self.wk_input.scenario,
                self.wk_input.example_name,
            ]
        ):
            return CStatus(
                -1,
                "Missing required parameters: source_text, scenario, or example_name",
            )

        self.prompt_generate = PromptGenerate(llm)
        context = {
            "source_text": self.wk_input.source_text,
            "scenario": self.wk_input.scenario,
            "example_name": self.wk_input.example_name,
        }
        self.context.assign_from_json(context)
        return CStatus()

    def operator_schedule(self, data_json):
        """Run the PromptGenerate operator on the shared workflow context."""
        return self.prompt_generate.run(data_json)
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py
new file mode 100644
index 00000000..a28b4134
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
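`PromptGenerateNode` validates its three inputs up front, then seeds the shared workflow state through `WkFlowState.assign_from_json` (added to `ai_state.py` later in this diff). A sketch of that seeding step, with a stub standing in for `WkFlowState`:

```python
class StubState:
    """Hypothetical stand-in for WkFlowState; mirrors its assign_from_json()."""

    def assign_from_json(self, data_json: dict):
        for k, v in data_json.items():
            setattr(self, k, v)


def seed_prompt_context(state, source_text, scenario, example_name):
    # Same all-or-nothing validation as PromptGenerateNode.node_init()
    if not all([source_text, scenario, example_name]):
        raise ValueError(
            "Missing required parameters: source_text, scenario, or example_name"
        )
    state.assign_from_json(
        {
            "source_text": source_text,
            "scenario": scenario,
            "example_name": example_name,
        }
    )


state = StubState()
seed_prompt_context(state, "Some raw text", "movie knowledge graph", "default")
print(state.scenario)  # -> "movie knowledge graph"
```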
+

import json

from PyCGraph import CStatus
from hugegraph_llm.nodes.base_node import BaseNode
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState
from hugegraph_llm.models.llms.init_llm import get_chat_llm
from hugegraph_llm.config import llm_settings
from hugegraph_llm.operators.llm_op.schema_build import SchemaBuilder
from hugegraph_llm.utils.log import log


class SchemaBuildNode(BaseNode):
    schema_builder: SchemaBuilder
    context: WkFlowState = None
    wk_input: WkFlowInput = None

    def node_init(self):
        llm = get_chat_llm(llm_settings)
        self.schema_builder = SchemaBuilder(llm)

        # Normalize wk_input.texts (a str or list[str]) into raw_texts
        raw_texts = []
        if self.wk_input.texts:
            if isinstance(self.wk_input.texts, list):
                raw_texts = [t for t in self.wk_input.texts if isinstance(t, str)]
            elif isinstance(self.wk_input.texts, str):
                raw_texts = [self.wk_input.texts]

        # query_examples: a raw JSON string of [{"description": ..., "gremlin": ...}]
        query_examples = []
        if self.wk_input.query_examples:
            try:
                parsed_examples = json.loads(self.wk_input.query_examples)
                # Keep only well-formed examples carrying both fields
                query_examples = [
                    {
                        "description": ex.get("description", ""),
                        "gremlin": ex.get("gremlin", ""),
                    }
                    for ex in parsed_examples
                    if isinstance(ex, dict) and "description" in ex and "gremlin" in ex
                ]
            except json.JSONDecodeError as e:
                return CStatus(-1, f"Query Examples is not in a valid JSON format: {e}")

        # few_shot_schema: a raw JSON string describing an example schema
        few_shot_schema = {}
        if self.wk_input.few_shot_schema:
            try:
                few_shot_schema = json.loads(self.wk_input.few_shot_schema)
            except json.JSONDecodeError as e:
                return CStatus(
                    -1, f"Few Shot Schema is not in a valid JSON format: {e}"
                )

        _context_payload = {
            "raw_texts": raw_texts,
            "query_examples": query_examples,
            "few_shot_schema": few_shot_schema,
        }
        self.context.assign_from_json(_context_payload)

        return CStatus()

    def operator_schedule(self, data_json):
        try:
            schema_result = self.schema_builder.run(data_json)
            return {"schema": schema_result}
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Surface the failure as a string so the web demo can display it
            log.error("Failed to generate schema: %s", e)
            return {"schema": f"Schema generation failed: {e}"}
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/util.py b/hugegraph-llm/src/hugegraph_llm/nodes/util.py
new file mode 100644
index 00000000..60bdc2e8
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/util.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
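`SchemaBuildNode` normalizes three loosely typed inputs before handing off to `SchemaBuilder`; in particular, `query_examples` entries survive only if they carry both a `description` and a `gremlin` field. The same validation extracted into a standalone helper:

```python
import json


def parse_query_examples(raw: str) -> list:
    """Keep only well-formed {description, gremlin} entries, as SchemaBuildNode does."""
    if not raw:
        return []
    parsed = json.loads(raw)  # raises json.JSONDecodeError on malformed input
    return [
        {"description": ex.get("description", ""), "gremlin": ex.get("gremlin", "")}
        for ex in parsed
        if isinstance(ex, dict) and "description" in ex and "gremlin" in ex
    ]


examples = '[{"description": "all persons", "gremlin": "g.V().hasLabel(\'person\')"}]'
print(parse_query_examples(examples))
```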
+ +from PyCGraph import CStatus + + +def init_context(obj) -> CStatus: + try: + obj.context = obj.getGParamWithNoEmpty("wkflow_state") + obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") + if obj.context is None or obj.wk_input is None: + return CStatus(-1, "Required workflow parameters not found") + return CStatus() + except Exception as e: + return CStatus(-1, f"Failed to initialize context: {str(e)}") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index 7a533517..c1c74203 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -20,12 +20,8 @@ from typing import Any, Optional, Dict from hugegraph_llm.enums.property_cardinality import PropertyCardinality from hugegraph_llm.enums.property_data_type import PropertyDataType -from hugegraph_llm.operators.util import init_context from hugegraph_llm.utils.log import log -from PyCGraph import GNode, CStatus -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState - def log_and_raise(message: str) -> None: log.warning(message) @@ -174,159 +170,3 @@ class CheckSchema: } ) property_label_set.add(prop) - - -class CheckSchemaNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - if self.wk_input.schema is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.data = self.wk_input.schema - return CStatus() - - def run(self) -> CStatus: - # init workflow input - sts = self.node_init() - if sts.isErr(): - return sts - # 1. Validate the schema structure - self.context.lock() - schema = self.data or self.context.schema - self._validate_schema(schema) - # 2. Process property labels and also create a set for it - property_labels, property_label_set = self._process_property_labels(schema) - # 3. Process properties in given vertex/edge labels - self._process_vertex_labels(schema, property_labels, property_label_set) - self._process_edge_labels(schema, property_labels, property_label_set) - # 4. Update schema with processed pks - schema["propertykeys"] = property_labels - self.context.schema = schema - self.context.unlock() - return CStatus() - - def _validate_schema(self, schema: Dict[str, Any]) -> None: - check_type(schema, dict, "Input data is not a dictionary.") - if "vertexlabels" not in schema or "edgelabels" not in schema: - log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.") - check_type( - schema["vertexlabels"], list, "'vertexlabels' in input data is not a list." - ) - check_type( - schema["edgelabels"], list, "'edgelabels' in input data is not a list." 
- ) - - def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): - property_labels = schema.get("propertykeys", []) - check_type( - property_labels, - list, - "'propertykeys' in input data is not of correct type.", - ) - property_label_set = {label["name"] for label in property_labels} - return property_labels, property_label_set - - def _process_vertex_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: - for vertex_label in schema["vertexlabels"]: - self._validate_vertex_label(vertex_label) - properties = vertex_label["properties"] - primary_keys = self._process_keys( - vertex_label, "primary_keys", properties[:1] - ) - if len(primary_keys) == 0: - log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.") - vertex_label["primary_keys"] = primary_keys - nullable_keys = self._process_keys( - vertex_label, "nullable_keys", properties[1:] - ) - vertex_label["nullable_keys"] = nullable_keys - self._add_missing_properties( - properties, property_labels, property_label_set - ) - - def _process_edge_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: - for edge_label in schema["edgelabels"]: - self._validate_edge_label(edge_label) - properties = edge_label.get("properties", []) - self._add_missing_properties( - properties, property_labels, property_label_set - ) - - def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: - check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.") - if "name" not in vertex_label: - log_and_raise("VertexLabel in input data does not contain 'name'.") - check_type( - vertex_label["name"], str, "'name' in vertex_label is not of correct type." - ) - if "properties" not in vertex_label: - log_and_raise("VertexLabel in input data does not contain 'properties'.") - check_type( - vertex_label["properties"], - list, - "'properties' in vertex_label is not of correct type.", - ) - if len(vertex_label["properties"]) == 0: - log_and_raise("'properties' in vertex_label is empty.") - - def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: - check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.") - if ( - "name" not in edge_label - or "source_label" not in edge_label - or "target_label" not in edge_label - ): - log_and_raise( - "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." - ) - check_type( - edge_label["name"], str, "'name' in edge_label is not of correct type." - ) - check_type( - edge_label["source_label"], - str, - "'source_label' in edge_label is not of correct type.", - ) - check_type( - edge_label["target_label"], - str, - "'target_label' in edge_label is not of correct type.", - ) - - def _process_keys( - self, label: Dict[str, Any], key_type: str, default_keys: list - ) -> list: - keys = label.get(key_type, default_keys) - check_type( - keys, list, f"'{key_type}' in {label['name']} is not of correct type." 
- ) - new_keys = [key for key in keys if key in label["properties"]] - return new_keys - - def _add_missing_properties( - self, properties: list, property_labels: list, property_label_set: set - ) -> None: - for prop in properties: - if prop not in property_label_set: - property_labels.append( - { - "name": prop, - "data_type": PropertyDataType.DEFAULT.value, - "cardinality": PropertyCardinality.DEFAULT.value, - } - ) - property_label_set.add(prop) - - def get_result(self): - self.context.lock() - res = self.context.to_json() - self.context.unlock() - return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py index d779a40a..c31e77af 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py @@ -19,8 +19,6 @@ from typing import Literal, Dict, Any, Optional, Union, List from langchain_text_splitters import RecursiveCharacterTextSplitter -from hugegraph_llm.operators.util import init_context -from PyCGraph import GNode, CStatus # Constants LANGUAGE_ZH = "zh" @@ -30,62 +28,6 @@ SPLIT_TYPE_PARAGRAPH = "paragraph" SPLIT_TYPE_SENTENCE = "sentence" -class ChunkSplitNode(GNode): - def init(self): - return init_context(self) - - def node_init(self): - if ( - self.wk_input.texts is None - or self.wk_input.language is None - or self.wk_input.split_type is None - ): - return CStatus(-1, "Error occurs when prepare for workflow input") - texts = self.wk_input.texts - language = self.wk_input.language - split_type = self.wk_input.split_type - if isinstance(texts, str): - texts = [texts] - self.texts = texts - self.separators = self._get_separators(language) - self.text_splitter = self._get_text_splitter(split_type) - return CStatus() - - def _get_separators(self, language: str) -> List[str]: - if language == LANGUAGE_ZH: - return ["\n\n", "\n", "。", ",", ""] - if language == LANGUAGE_EN: - return ["\n\n", "\n", ".", ",", " ", ""] - raise ValueError("language must be zh or en") - - def _get_text_splitter(self, split_type: str): - if split_type == SPLIT_TYPE_DOCUMENT: - return lambda text: [text] - if split_type == SPLIT_TYPE_PARAGRAPH: - return RecursiveCharacterTextSplitter( - chunk_size=500, chunk_overlap=30, separators=self.separators - ).split_text - if split_type == SPLIT_TYPE_SENTENCE: - return RecursiveCharacterTextSplitter( - chunk_size=50, chunk_overlap=0, separators=self.separators - ).split_text - raise ValueError("Type must be document, paragraph or sentence") - - def run(self): - sts = self.node_init() - if sts.isErr(): - return sts - all_chunks = [] - for text in self.texts: - chunks = self.text_splitter(text) - all_chunks.extend(chunks) - - self.context.lock() - self.context.chunks = all_chunks - self.context.unlock() - return CStatus() - - class ChunkSplit: def __init__( self, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 5cc846d2..9eec04f7 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -40,15 +40,19 @@ class Commit2Graph: schema = data.get("schema") vertices = data.get("vertices", []) edges = data.get("edges", []) - + print(f"get schema {schema}") if not vertices and not edges: - log.critical("(Loading) Both vertices and edges are 
empty. Please check the input data again.") + log.critical( + "(Loading) Both vertices and edges are empty. Please check the input data again." + ) raise ValueError("Both vertices and edges input are empty.") if not schema: # TODO: ensure the function works correctly (update the logic later) self.schema_free_mode(data.get("triples", [])) - log.warning("Using schema_free mode, could try schema_define mode for better effect!") + log.warning( + "Using schema_free mode, could try schema_define mode for better effect!" + ) else: self.init_schema_if_need(schema) self.load_into_graph(vertices, edges, schema) @@ -64,7 +68,9 @@ class Commit2Graph: # list or set default_value = [] input_properties[key] = default_value - log.warning("Property '%s' missing in vertex, set to '%s' for now", key, default_value) + log.warning( + "Property '%s' missing in vertex, set to '%s' for now", key, default_value + ) def _handle_graph_creation(self, func, *args, **kwargs): try: @@ -78,29 +84,42 @@ class Commit2Graph: def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many-statements # pylint: disable=R0912 (too-many-branches) - vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]} + vertex_label_map = { + v_label["name"]: v_label for v_label in schema["vertexlabels"] + } edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]} - property_label_map = {p_label["name"]: p_label for p_label in schema["propertykeys"]} + property_label_map = { + p_label["name"]: p_label for p_label in schema["propertykeys"] + } for vertex in vertices: input_label = vertex["label"] # 1. ensure the input_label in the graph schema if input_label not in vertex_label_map: - log.critical("(Input) VertexLabel %s not found in schema, skip & need check it!", input_label) + log.critical( + "(Input) VertexLabel %s not found in schema, skip & need check it!", + input_label, + ) continue input_properties = vertex["properties"] vertex_label = vertex_label_map[input_label] primary_keys = vertex_label["primary_keys"] nullable_keys = vertex_label.get("nullable_keys", []) - non_null_keys = [key for key in vertex_label["properties"] if key not in nullable_keys] + non_null_keys = [ + key for key in vertex_label["properties"] if key not in nullable_keys + ] has_problem = False # 2. Handle primary-keys mode vertex for pk in primary_keys: if not input_properties.get(pk): if len(primary_keys) == 1: - log.error("Primary-key '%s' missing in vertex %s, skip it & need check it again", pk, vertex) + log.error( + "Primary-key '%s' missing in vertex %s, skip it & need check it again", + pk, + vertex, + ) has_problem = True break # TODO: transform to Enum first (better in earlier step) @@ -110,14 +129,20 @@ class Commit2Graph: input_properties[pk] = default_value_map(data_type) else: input_properties[pk] = [] - log.warning("Primary-key '%s' missing in vertex %s, mark empty & need check it again!", pk, vertex) + log.warning( + "Primary-key '%s' missing in vertex %s, mark empty & need check it again!", + pk, + vertex, + ) if has_problem: continue # 3. Ensure all non-nullable props are set for key in non_null_keys: if key not in input_properties: - self._set_default_property(key, input_properties, property_label_map) + self._set_default_property( + key, input_properties, property_label_map + ) # 4. 
Check all data type value is right for key, value in input_properties.items(): @@ -125,14 +150,19 @@ class Commit2Graph: data_type = property_label_map[key]["data_type"] cardinality = property_label_map[key]["cardinality"] if not self._check_property_data_type(data_type, cardinality, value): - log.error("Property type/format '%s' is not correct, skip it & need check it again", key) + log.error( + "Property type/format '%s' is not correct, skip it & need check it again", + key, + ) has_problem = True break if has_problem: continue # TODO: we could try batch add vertices first, setback to single-mode if failed - vid = self._handle_graph_creation(self.client.graph().addVertex, input_label, input_properties).id + vid = self._handle_graph_creation( + self.client.graph().addVertex, input_label, input_properties + ).id vertex["id"] = vid for edge in edges: @@ -142,11 +172,16 @@ class Commit2Graph: properties = edge["properties"] if label not in edge_label_map: - log.critical("(Input) EdgeLabel %s not found in schema, skip & need check it!", label) + log.critical( + "(Input) EdgeLabel %s not found in schema, skip & need check it!", + label, + ) continue # TODO: we could try batch add edges first, setback to single-mode if failed - self._handle_graph_creation(self.client.graph().addEdge, label, start, end, properties) + self._handle_graph_creation( + self.client.graph().addEdge, label, start, end, properties + ) def init_schema_if_need(self, schema: dict): properties = schema["propertykeys"] @@ -170,19 +205,27 @@ class Commit2Graph: source_vertex_label = edge["source_label"] target_vertex_label = edge["target_label"] properties = edge["properties"] - self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( - target_vertex_label - ).properties(*properties).nullableKeys(*properties).ifNotExist().create() + self.schema.edgeLabel(edge_label).sourceLabel( + source_vertex_label + ).targetLabel(target_vertex_label).properties(*properties).nullableKeys( + *properties + ).ifNotExist().create() def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() - self.schema.vertexLabel("vertex").useCustomizeStringId().properties("name").ifNotExist().create() - self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties( + self.schema.vertexLabel("vertex").useCustomizeStringId().properties( "name" ).ifNotExist().create() + self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel( + "vertex" + ).properties("name").ifNotExist().create() - self.schema.indexLabel("vertexByName").onV("vertex").by("name").secondary().ifNotExist().create() - self.schema.indexLabel("edgeByName").onE("edge").by("name").secondary().ifNotExist().create() + self.schema.indexLabel("vertexByName").onV("vertex").by( + "name" + ).secondary().ifNotExist().create() + self.schema.indexLabel("edgeByName").onE("edge").by( + "name" + ).secondary().ifNotExist().create() for item in data: s, p, o = (element.strip() for element in item) @@ -196,8 +239,12 @@ class Commit2Graph: data_type = PropertyDataType(prop["data_type"]) cardinality = PropertyCardinality(prop["cardinality"]) except ValueError: - log.critical("Invalid data type %s / cardinality %s for property %s, skip & should check it again", - prop["data_type"], prop["cardinality"], name) + log.critical( + "Invalid data type %s / cardinality %s for property %s, skip & should check it again", + prop["data_type"], + prop["cardinality"], + name, + ) return property_key = self.schema.propertyKey(name) @@ -231,7 
+278,9 @@ class Commit2Graph: log.warning("UUID type is not supported, use text instead") property_key.asText() else: - log.error("Unknown data type %s for property_key %s", data_type, property_key) + log.error( + "Unknown data type %s for property_key %s", data_type, property_key + ) def _set_property_cardinality(self, property_key, cardinality): if cardinality == PropertyCardinality.SINGLE: @@ -241,10 +290,17 @@ class Commit2Graph: elif cardinality == PropertyCardinality.SET: property_key.valueSet() else: - log.error("Unknown cardinality %s for property_key %s", cardinality, property_key) - - def _check_property_data_type(self, data_type: str, cardinality: str, value) -> bool: - if cardinality in (PropertyCardinality.LIST.value, PropertyCardinality.SET.value): + log.error( + "Unknown cardinality %s for property_key %s", cardinality, property_key + ) + + def _check_property_data_type( + self, data_type: str, cardinality: str, value + ) -> bool: + if cardinality in ( + PropertyCardinality.LIST.value, + PropertyCardinality.SET.value, + ): return self._check_collection_data_type(data_type, value) return self._check_single_data_type(data_type, value) @@ -259,14 +315,21 @@ class Commit2Graph: def _check_single_data_type(self, data_type: str, value) -> bool: if data_type == PropertyDataType.BOOLEAN.value: return isinstance(value, bool) - if data_type in (PropertyDataType.BYTE.value, PropertyDataType.INT.value, PropertyDataType.LONG.value): + if data_type in ( + PropertyDataType.BYTE.value, + PropertyDataType.INT.value, + PropertyDataType.LONG.value, + ): return isinstance(value, int) if data_type in (PropertyDataType.FLOAT.value, PropertyDataType.DOUBLE.value): return isinstance(value, float) if data_type in (PropertyDataType.TEXT.value, PropertyDataType.UUID.value): return isinstance(value, str) # TODO: check ok below - if data_type == PropertyDataType.DATE.value: # the format should be "yyyy-MM-dd" + if ( + data_type == PropertyDataType.DATE.value + ): # the format should be "yyyy-MM-dd" import re - return isinstance(value, str) and re.match(r'^\d{4}-\d{2}-\d{2}$', value) + + return isinstance(value, str) and re.match(r"^\d{4}-\d{2}-\d{2}$", value) raise ValueError(f"Unknown/Unsupported data type: {data_type}") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index 670c18b4..c4e2124c 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -17,12 +17,8 @@ from typing import Dict, Any, Optional from hugegraph_llm.config import huge_settings -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from pyhugegraph.client import PyHugeClient -from PyCGraph import GNode, CStatus - class SchemaManager: def __init__(self, graph_name: str): @@ -74,74 +70,3 @@ class SchemaManager: # TODO: enhance the logic here context["simple_schema"] = self.simple_schema(schema) return context - - -class SchemaManagerNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - if self.wk_input.graph_name is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - graph_name = self.wk_input.graph_name - self.graph_name = graph_name - self.client = PyHugeClient( - url=huge_settings.graph_url, - graph=self.graph_name, - 
user=huge_settings.graph_user, - pwd=huge_settings.graph_pwd, - graphspace=huge_settings.graph_space, - ) - self.schema = self.client.schema() - return CStatus() - - def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: - mini_schema = {} - - # Add necessary vertexlabels items (3) - if "vertexlabels" in schema: - mini_schema["vertexlabels"] = [] - for vertex in schema["vertexlabels"]: - new_vertex = { - key: vertex[key] - for key in ["id", "name", "properties"] - if key in vertex - } - mini_schema["vertexlabels"].append(new_vertex) - - # Add necessary edgelabels items (4) - if "edgelabels" in schema: - mini_schema["edgelabels"] = [] - for edge in schema["edgelabels"]: - new_edge = { - key: edge[key] - for key in ["name", "source_label", "target_label", "properties"] - if key in edge - } - mini_schema["edgelabels"].append(new_edge) - - return mini_schema - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - schema = self.schema.getSchema() - if not schema["vertexlabels"] and not schema["edgelabels"]: - raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") - - self.context.lock() - self.context.schema = schema - # TODO: enhance the logic here - self.context.simple_schema = self.simple_schema(schema) - self.context.unlock() - return CStatus() - - def get_result(self): - self.context.lock() - res = self.context.to_json() - self.context.unlock() - return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index ee89d330..5cdad031 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -30,54 +30,6 @@ from hugegraph_llm.utils.embedding_utils import ( ) from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.embeddings.init_embedding import get_embedding -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from PyCGraph import GNode, CStatus - - -class BuildVectorIndexNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - self.embedding = get_embedding(llm_settings) - self.folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) - self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) - self.filename_prefix = get_filename_prefix( - llm_settings.embedding_type, getattr(self.embedding, "model_name", None) - ) - self.vector_index = VectorIndex.from_index_file( - self.index_dir, self.filename_prefix - ) - return CStatus() - - def run(self): - # init workflow input - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - try: - if self.context.chunks is None: - raise ValueError("chunks not found in context.") - chunks = self.context.chunks - finally: - self.context.unlock() - chunks_embedding = [] - log.debug("Building vector index for %s chunks...", len(chunks)) - # TODO: use async_get_texts_embedding instead of single sync method - chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks)) - if len(chunks_embedding) > 0: - self.vector_index.add(chunks_embedding, chunks) - self.vector_index.to_index_file(self.index_dir, self.filename_prefix) - return CStatus() - class BuildVectorIndex: def __init__(self, embedding: BaseEmbedding): diff 
--git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 15a8fdda..571ffde5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -18,16 +18,10 @@ import re from typing import List, Any, Dict, Optional -from hugegraph_llm.config import llm_settings from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.llms.init_llm import get_chat_llm -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from PyCGraph import GNode, CStatus - SCHEMA_EXAMPLE_PROMPT = """## Main Task Extract Triples from the given text and graph schema @@ -213,143 +207,3 @@ class InfoExtract: if self.valid(edge["start"]) and self.valid(edge["end"]) ] return graph - - -class InfoExtractNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - self.llm = get_chat_llm(llm_settings) - if self.wk_input.example_prompt is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.example_prompt = self.wk_input.example_prompt - return CStatus() - - def extract_triples_by_regex_with_schema(self, schema, text): - text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") - pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" - matches = re.findall(pattern, text) - - vertices_dict = {v["id"]: v for v in self.context.vertices} - for match in matches: - s, p, o, label = [item.strip() for item in match] - if None in [label, s, p, o]: - continue - # TODO: use a more efficient way to compare the extract & input property - p_lower = p.lower() - for vertex in schema["vertices"]: - if vertex["vertex_label"] == label and any( - pp.lower() == p_lower for pp in vertex["properties"] - ): - id = f"{label}-{s}" - if id not in vertices_dict: - vertices_dict[id] = { - "id": id, - "name": s, - "label": label, - "properties": {p: o}, - } - else: - vertices_dict[id]["properties"].update({p: o}) - break - for edge in schema["edges"]: - if edge["edge_label"] == label: - source_label = edge["source_vertex_label"] - source_id = f"{source_label}-{s}" - if source_id not in vertices_dict: - vertices_dict[source_id] = { - "id": source_id, - "name": s, - "label": source_label, - "properties": {}, - } - target_label = edge["target_vertex_label"] - target_id = f"{target_label}-{o}" - if target_id not in vertices_dict: - vertices_dict[target_id] = { - "id": target_id, - "name": o, - "label": target_label, - "properties": {}, - } - self.context.edges.append( - { - "start": source_id, - "end": target_id, - "type": label, - "properties": {}, - } - ) - break - self.context.vertices = list(vertices_dict.values()) - - def extract_triples_by_regex(self, text): - text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") - pattern = r"\((.*?), (.*?), (.*?)\)" - self.context.triples += re.findall(pattern, text) - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - if self.context.chunks is None: - self.context.unlock() - raise ValueError("parameter required by extract node not found in context.") - schema = self.context.schema - chunks = self.context.chunks - - if schema: - self.context.vertices = [] - 
self.context.edges = [] - else: - self.context.triples = [] - - self.context.unlock() - - for sentence in chunks: - proceeded_chunk = self.extract_triples_by_llm(schema, sentence) - log.debug( - "[Legacy] %s input: %s \n output:%s", - self.__class__.__name__, - sentence, - proceeded_chunk, - ) - if schema: - self.extract_triples_by_regex_with_schema(schema, proceeded_chunk) - else: - self.extract_triples_by_regex(proceeded_chunk) - - if self.context.call_count: - self.context.call_count += len(chunks) - else: - self.context.call_count = len(chunks) - self._filter_long_id() - return CStatus() - - def extract_triples_by_llm(self, schema, chunk) -> str: - prompt = generate_extract_triple_prompt(chunk, schema) - if self.example_prompt is not None: - prompt = self.example_prompt + prompt - return self.llm.generate(prompt=prompt) - - # TODO: make 'max_length' be a configurable param in settings.py/settings.cfg - def valid(self, element_id: str, max_length: int = 256) -> bool: - if len(element_id.encode("utf-8")) >= max_length: - log.warning("Filter out GraphElementID too long: %s", element_id) - return False - return True - - def _filter_long_id(self): - self.context.vertices = [ - vertex for vertex in self.context.vertices if self.valid(vertex["id"]) - ] - self.context.edges = [ - edge - for edge in self.context.edges - if self.valid(edge["start"]) and self.valid(edge["end"]) - ] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index 6e492b8f..79fb33b4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -21,16 +21,11 @@ import json import re from typing import List, Any, Dict -from hugegraph_llm.config import llm_settings, prompt +from hugegraph_llm.config import prompt from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.llms.init_llm import get_chat_llm -from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput -from PyCGraph import GNode, CStatus - # TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. # Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on # prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. @@ -182,123 +177,3 @@ class PropertyGraphExtract: "Invalid property graph JSON! Please check the extracted JSON data carefully" ) return items - - -class PropertyGraphExtractNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - self.NECESSARY_ITEM_KEYS = {"label", "type", "properties"} # pylint: disable=invalid-name - return init_context(self) - - def node_init(self): - self.llm = get_chat_llm(llm_settings) - if self.wk_input.example_prompt is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.example_prompt = self.wk_input.example_prompt - return CStatus() - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - try: - if self.context.schema is None or self.context.chunks is None: - raise ValueError( - "parameter required by extract node not found in context." 
- ) - schema = self.context.schema - chunks = self.context.chunks - if self.context.vertices is None: - self.context.vertices = [] - if self.context.edges is None: - self.context.edges = [] - finally: - self.context.unlock() - - items = [] - for chunk in chunks: - proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) - log.debug( - "[LLM] %s input: %s \n output:%s", - self.__class__.__name__, - chunk, - proceeded_chunk, - ) - items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) - items = filter_item(schema, items) - self.context.lock() - try: - for item in items: - if item["type"] == "vertex": - self.context.vertices.append(item) - elif item["type"] == "edge": - self.context.edges.append(item) - finally: - self.context.unlock() - self.context.call_count = (self.context.call_count or 0) + len(chunks) - return CStatus() - - def extract_property_graph_by_llm(self, schema, chunk): - prompt = generate_extract_property_graph_prompt(chunk, schema) - if self.example_prompt is not None: - prompt = self.example_prompt + prompt - return self.llm.generate(prompt=prompt) - - def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: - # Use regex to extract a JSON object with curly braces - json_match = re.search(r"({.*})", text, re.DOTALL) - if not json_match: - log.critical( - "Invalid property graph! No JSON object found, " - "please check the output format example in prompt." - ) - return [] - json_str = json_match.group(1).strip() - - items = [] - try: - property_graph = json.loads(json_str) - # Expect property_graph to be a dict with keys "vertices" and "edges" - if not ( - isinstance(property_graph, dict) - and "vertices" in property_graph - and "edges" in property_graph - ): - log.critical( - "Invalid property graph format; expecting 'vertices' and 'edges'." - ) - return items - - # Create sets for valid vertex and edge labels based on the schema - vertex_label_set = {vertex["name"] for vertex in schema["vertexlabels"]} - edge_label_set = {edge["name"] for edge in schema["edgelabels"]} - - def process_items(item_list, valid_labels, item_type): - for item in item_list: - if not isinstance(item, dict): - log.warning( - "Invalid property graph item type '%s'.", type(item) - ) - continue - if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): - log.warning("Invalid item keys '%s'.", item.keys()) - continue - if item["label"] not in valid_labels: - log.warning( - "Invalid %s label '%s' has been ignored.", - item_type, - item["label"], - ) - continue - items.append(item) - - process_items(property_graph["vertices"], vertex_label_set, "vertex") - process_items(property_graph["edges"], edge_label_set, "edge") - except json.JSONDecodeError: - log.critical( - "Invalid property graph JSON! 
Please check the extracted JSON data carefully"
-            )
-        return items
diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
index 0543aa2b..6d3418c0 100644
--- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
+++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
@@ -25,6 +25,14 @@ class WkFlowInput(GParam):
     example_prompt: str = None  # needed by graph information extraction
     schema: str = None  # Schema information required by SchemaNode
     graph_name: str = None
+    data_json = None
+    extract_type = None
+    query_examples = None
+    few_shot_schema = None
+    # Fields related to PromptGenerate
+    source_text: str = None  # Original text
+    scenario: str = None  # Scenario description
+    example_name: str = None  # Example name

     def reset(self, _: CStatus) -> None:
         self.texts = None
@@ -33,6 +41,14 @@ class WkFlowInput(GParam):
         self.example_prompt = None
         self.schema = None
         self.graph_name = None
+        self.data_json = None
+        self.extract_type = None
+        self.query_examples = None
+        self.few_shot_schema = None
+        # PromptGenerate related configuration
+        self.source_text = None
+        self.scenario = None
+        self.example_name = None


 class WkFlowState(GParam):
@@ -49,6 +65,8 @@ class WkFlowState(GParam):
     graph_result = None
     keywords_embeddings = None

+    generated_extract_prompt: Optional[str] = None
+
     def setup(self):
         self.schema = None
         self.simple_schema = None
@@ -63,6 +81,8 @@ class WkFlowState(GParam):
         self.graph_result = None
         self.keywords_embeddings = None

+        self.generated_extract_prompt = None
+
         return CStatus()

     def to_json(self):
@@ -79,3 +99,11 @@ class WkFlowState(GParam):
             for k, v in self.__dict__.items()
             if not k.startswith("_") and v is not None
         }
+
+    # Spread each key of a JSON object onto this state as an attribute
+    def assign_from_json(self, data_json: dict):
+        """
+        Assigns each key in the input json object as a member variable of WkFlowState.
+ """ + for k, v in data_json.items(): + setattr(self, k, v) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index f61b5f84..ccace69f 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -36,6 +36,15 @@ from ..operators.kg_construction_task import KgBuilder def get_graph_index_info(): + try: + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("get_graph_index_info") + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def get_graph_index_info_old(): builder = KgBuilder( LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() ) @@ -150,6 +159,15 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: def update_vid_embedding(): + scheduler = SchedulerSingleton.get_instance() + try: + return scheduler.schedule_flow("update_vid_embeddings") + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def update_vid_embedding_old(): builder = KgBuilder( LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() ) @@ -166,6 +184,18 @@ def update_vid_embedding(): def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: + try: + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("import_graph_data", data, schema) + except Exception as e: # pylint: disable=W0718 + log.error(e) + traceback.print_exc() + # Note: can't use gr.Error here + gr.Warning(str(e) + " Please check the graph data format/type carefully.") + return data + + +def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) @@ -190,6 +220,16 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: def build_schema(input_text, query_example, few_shot): + scheduler = SchedulerSingleton.get_instance() + try: + return scheduler.schedule_flow( + "build_schema", input_text, query_example, few_shot + ) + except (TypeError, ValueError) as e: + raise gr.Error(f"Schema generation failed: {e}") + + +def build_schema_old(input_text, query_example, few_shot): context = { "raw_texts": [input_text] if input_text else [], "query_examples": [],
