This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch text2gql
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git

commit 41ae008c82ccc4bbd4184cd7b723f970aa03cc9e
Author: Lriver <[email protected]>
AuthorDate: Tue Sep 30 22:33:01 2025 +0800

    Add instruct_convert.py: instruction format conversion and train/test set split
---
 .../Vertical_Text2Gremlin/instruct_convert.py      | 104 +++++++++++++++++++++
 1 file changed, 104 insertions(+)

diff --git a/text2gremlin/Vertical_Text2Gremlin/instruct_convert.py b/text2gremlin/Vertical_Text2Gremlin/instruct_convert.py
new file mode 100644
index 00000000..1b5ecd75
--- /dev/null
+++ b/text2gremlin/Vertical_Text2Gremlin/instruct_convert.py
@@ -0,0 +1,104 @@
+import pandas as pd
+import json
+import os
+import random
+
+INPUT_CSV_PATH = 'test_gremlin_qa_dataset.csv'
+OUTPUT_JSON_PATH = 'instruct_data.json'
+CHUNK_SIZE = 200  # number of rows to read from the CSV per chunk
+
+
+TRAIN_RATIO = 0.7  # train/test split ratio
+TRAIN_OUTPUT_PATH = './train_data/train_dataset.json'  # training set output file
+TEST_OUTPUT_PATH = './train_data/test_dataset.json'    # test set output file
+
+def convert_csv_to_json():
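+    """Stream the CSV in chunks and write each question/gremlin_query pair as an instruction-format record into a JSON array."""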
+
+    print("Data conversion: CSV -> JSON")
+    print(f"Input: '{INPUT_CSV_PATH}'")
+    print(f"Output: '{OUTPUT_JSON_PATH}'")
+
+    instruction_text = "You are an expert in Gremlin, the graph database query language. Your task is to accurately convert the user's natural-language question into the corresponding Gremlin query statement."
+    is_first_object = True
+
+    try:
+        with open(OUTPUT_JSON_PATH, 'w', encoding='utf-8') as f:
+            f.write('[\n')
+            csv_reader = pd.read_csv(INPUT_CSV_PATH, chunksize=CHUNK_SIZE, iterator=True)
+            
+            total_rows_processed = 0
+            
+            for i, chunk_df in enumerate(csv_reader):
+                for index, row in chunk_df.iterrows():
+                    if pd.notna(row['question']) and pd.notna(row['gremlin_query']):
+                        
+                        if not is_first_object:
+                            f.write(',\n')
+                        
+                        formatted_data = {
+                            "instruction": instruction_text,
+                            "input": row['question'],
+                            "output": row['gremlin_query']
+                        }
+                        
+                        json_string = json.dumps(formatted_data, ensure_ascii=False, indent=2)
+                        f.write(json_string)
+                        
+                        is_first_object = False
+                        total_rows_processed += 1
+                
+                print(f"  Processed {i+1} chunk(s), {total_rows_processed} rows so far...")
+
+            f.write('\n]')
+            
+            print(f"\nConversion complete! Converted {total_rows_processed} records in total, saved to {OUTPUT_JSON_PATH}")
+            return True
+
+    except FileNotFoundError:
+        print(f"错误: 输入文件未找到 '{INPUT_CSV_PATH}'。请检查文件名和路径。")
+        return False
+    except Exception as e:
+        print(f"发生未知错误: {e}")
+        return False
+
+def split_and_shuffle_dataset():
+    """
+    Randomly shuffle the data and split it into train and test sets by TRAIN_RATIO.
+    """
+    try:
+        print(f"  正在从 '{OUTPUT_JSON_PATH}' 加载数据到内存...")
+        with open(OUTPUT_JSON_PATH, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        print(f"  加载了 {len(data)} 条数据。")
+
+        # Shuffle the data randomly
+        random.shuffle(data)
+
+        # Split the dataset
+        split_index = int(len(data) * TRAIN_RATIO)  # compute the split point
+        train_data = data[:split_index]
+        test_data = data[split_index:]
+        
+        print(f"  数据划分完毕: {len(train_data)} 条训练数据, {len(test_data)} 条测试数据。")
+
+        print(f"  保存训练集到 '{TRAIN_OUTPUT_PATH}'...")
+        with open(TRAIN_OUTPUT_PATH, 'w', encoding='utf-8') as f:
+            json.dump(train_data, f, ensure_ascii=False, indent=2)
+
+        print(f"  保存测试集到 '{TEST_OUTPUT_PATH}'...")
+        with open(TEST_OUTPUT_PATH, 'w', encoding='utf-8') as f:
+            json.dump(test_data, f, ensure_ascii=False, indent=2)
+            
+        print(f"\n 数据集已成功划分为训练集和测试集。")
+
+    except FileNotFoundError:
+        print(f"错误: JSON文件 '{OUTPUT_JSON_PATH}' 未找到。请先确保步骤1成功执行。")
+    except Exception as e:
+        print(f"在划分数据集时发生错误: {e}")
+
+
+if __name__ == '__main__':
+    if convert_csv_to_json():
+        split_and_shuffle_dataset()
+    else:
+        print("\n数据转换失败,停止后续操作。")
\ No newline at end of file
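
For reference, each record appended to instruct_data.json by convert_csv_to_json() follows an instruction/input/output layout. A minimal sketch with a hypothetical question/gremlin_query pair (illustrative values only, not taken from the actual dataset):

    # Hypothetical sample record; the instruction string is truncated here
    sample_record = {
        "instruction": "You are an expert in Gremlin, the graph database query language. ...",
        "input": "Which people does the person named 'Alice' know?",                    # 'question' column (hypothetical)
        "output": "g.V().has('person', 'name', 'Alice').out('knows').values('name')",   # 'gremlin_query' column (hypothetical)
    }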
