This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch text2gql
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit ec1bf70030a34f8e9587089831ab0cfcf918c439
Author: Lriver <[email protected]>
AuthorDate: Tue Sep 30 22:32:44 2025 +0800

    Add qa_generalize.py: Seed data generalization using gremlin_checker and llm_handler
---
 .../Vertical_Text2Gremlin/qa_generalize.py | 130 +++++++++++++++++++++
 1 file changed, 130 insertions(+)

diff --git a/text2gremlin/Vertical_Text2Gremlin/qa_generalize.py b/text2gremlin/Vertical_Text2Gremlin/qa_generalize.py
new file mode 100644
index 00000000..f6c20f19
--- /dev/null
+++ b/text2gremlin/Vertical_Text2Gremlin/qa_generalize.py
@@ -0,0 +1,130 @@
+import pandas as pd
+import time
+from typing import List, Dict
+
+from gremlin_checker import check_gremlin_syntax
+from llm_handler import generate_gremlin_variations, generate_texts_for_gremlin
+
+INPUT_CSV_PATH = 'test_gremlin_qa_dataset.csv'  # Seed QA data
+OUTPUT_CSV_PATH = 'augmented_text2gremlin.csv'  # Output path
+CHUNK_SIZE = 100       # Number of rows read from the CSV per chunk
+WRITE_THRESHOLD = 200  # Flush the buffer to disk once this many new records accumulate
+GROUP_SIZE = 5         # Number of same-group questions used as seeds during generalization
+
+
+def save_and_clear_buffer(buffer: List[Dict], is_first_write: bool) -> bool:
+    """
+    Persist the buffered records to the output CSV and clear the buffer.
+    Returns the updated is_first_write flag.
+    """
+    if not buffer:
+        return is_first_write
+
+    print(f"\n--- Buffer reached the threshold, writing {len(buffer)} records ---")
+    df_new = pd.DataFrame(buffer)
+    df_new.to_csv(
+        OUTPUT_CSV_PATH,
+        mode='a',
+        header=is_first_write,  # write the header only on the first flush
+        index=False,
+        encoding='utf-8-sig'
+    )
+    print(f"✅ Saved to {OUTPUT_CSV_PATH}")
+    buffer.clear()
+
+    return False
+
+
+def process_group(gremlin_query: str, group_df: pd.DataFrame) -> List[Dict]:
+    """
+    Process one complete group of seed questions sharing the same Gremlin query.
+    """
+    print("\n" + "=" * 80)
+    print(f"Processing group: {gremlin_query[:100]}...")
+
+    # Extract the seed questions
+    seed_questions = group_df['question'].tolist()[:GROUP_SIZE]
+
+    # Step 1: generate Gremlin variants
+    print(f"  Step 1: Calling the LLM to generate Gremlin variants from {len(seed_questions)} questions")
+    candidate_queries = generate_gremlin_variations(gremlin_query, seed_questions)
+    if not candidate_queries:
+        print("  -> The LLM returned no Gremlin variants, skipping this group.")
+        return []
+
+    print(f"  -> The LLM generated {len(candidate_queries)} candidate queries:")
+
+    # AST-based syntax check
+    valid_queries = []
+    for query in candidate_queries:
+        is_valid, msg = check_gremlin_syntax(query)
+        if is_valid:
+            valid_queries.append(query)
+            print(f"    Syntax OK: {query}")
+        else:
+            print(f"    ❌ Syntax check failed: {query} | Reason: {msg}")
+
+    if not valid_queries:
+        print("  -> No valid Gremlin left after the syntax check, skipping this group.")
+        return []
+
+    # Step 2: generate natural-language questions for each valid query
+    new_data_for_group = []
+    print(f"  Step 2: Generating text for {len(valid_queries)} valid Gremlin queries")
+    for valid_query in valid_queries:
+        generated_texts = generate_texts_for_gremlin(valid_query)
+        if generated_texts:
+            print(f"    -> Generated {len(generated_texts)} questions for query '{valid_query[:80]}...'.")
+            for text in generated_texts:
+                new_data_for_group.append({'question': text, 'gremlin_query': valid_query})
+        time.sleep(1)  # throttle LLM calls
+
+    return new_data_for_group
+
+
+def main():
+    is_first_write = True
+    write_buffer = []
+    # The trailing group of a chunk may be cut off mid-group, so stage it here
+    # and re-check it together with the next chunk.
+    carry_over_df = pd.DataFrame()
+
+    try:
+        csv_reader = pd.read_csv(INPUT_CSV_PATH, chunksize=CHUNK_SIZE, iterator=True)
+
+        for i, chunk_df in enumerate(csv_reader):
+            print("\n" + "#" * 30 + f" Processing chunk {i + 1} " + "#" * 30)
+
+            current_data = pd.concat([carry_over_df, chunk_df], ignore_index=True)
+            if current_data.empty:
+                continue
+
+            # The last Gremlin in the chunk may belong to an incomplete group;
+            # stage those rows and revisit them in the next round.
+            last_query_in_chunk = current_data.iloc[-1]['gremlin_query']
+            carry_over_df = current_data[current_data['gremlin_query'] == last_query_in_chunk].copy()
+            df_to_process = current_data.drop(carry_over_df.index)
+
+            if not df_to_process.empty:
+                grouped = df_to_process.groupby('gremlin_query', sort=False)
+                for gremlin_query, group_df in grouped:
+                    new_data = process_group(gremlin_query, group_df)
+                    if new_data:
+                        write_buffer.extend(new_data)
+
+                    if len(write_buffer) >= WRITE_THRESHOLD:
+                        is_first_write = save_and_clear_buffer(write_buffer, is_first_write)
+
+        print("\n" + "#" * 30 + " Processing the remaining staged data " + "#" * 30)
+        if not carry_over_df.empty:
+            final_grouped = carry_over_df.groupby('gremlin_query', sort=False)
+            for gremlin_query, group_df in final_grouped:
+                new_data = process_group(gremlin_query, group_df)
+                if new_data:
+                    write_buffer.extend(new_data)
+
+        print("\n--- Performing the final write... ---")
+        save_and_clear_buffer(write_buffer, is_first_write)
+
+    except FileNotFoundError:
+        print(f"Error: input file not found: '{INPUT_CSV_PATH}'")
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+
+    print("\nQA generalization finished!")
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
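
Note for reviewers: gremlin_checker and llm_handler are sibling modules that are not part of this diff, so the script above cannot be run from this commit alone. Judging only from the call sites, their interfaces appear to be roughly the following; these are hypothetical stand-in stubs for dry-running the pipeline locally, not the committed implementations.

# Hypothetical stubs inferred from the call sites in qa_generalize.py.
# The real implementations live in gremlin_checker.py / llm_handler.py,
# which are not included in this commit.
from typing import List, Tuple

def check_gremlin_syntax(query: str) -> Tuple[bool, str]:
    """Validate a Gremlin query (the real module appears to do an AST
    check); returns (is_valid, error_message)."""
    ok = bool(query.strip())
    return ok, "" if ok else "empty query"

def generate_gremlin_variations(gremlin_query: str,
                                seed_questions: List[str]) -> List[str]:
    """Ask the LLM for variants of the seed query, given up to GROUP_SIZE
    example questions. Stubbed to echo the seed query for a dry run."""
    return [gremlin_query]

def generate_texts_for_gremlin(gremlin_query: str) -> List[str]:
    """Ask the LLM to write natural-language questions for a query.
    Stubbed with a single placeholder question."""
    return [f"What does this query return: {gremlin_query}?"]

From the column accesses in the script, the seed CSV is expected to contain at least the columns 'question' and 'gremlin_query', with rows sharing the same gremlin_query stored contiguously: both groupby('gremlin_query', sort=False) calls and the carry-over logic at chunk boundaries rely on that adjacency.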
