This is an automated email from the ASF dual-hosted git repository. jin pushed a commit to branch text2gql in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
"""Main entry script for the Gremlin corpus generator.

Generates large numbers of diverse query-description pairs from Gremlin
query templates, to be used as training data for Text-to-Gremlin tasks.
"""

import json
import os
import random
from datetime import datetime

from antlr4 import InputStream, CommonTokenStream
from antlr4.error.ErrorListener import ErrorListener

# Import all our custom modules from the gremlin_base package
from Config import Config
from Schema import Schema
from GremlinBase import GremlinBase
from GremlinParse import Traversal
from TraversalGenerator import TraversalGenerator
from GremlinTransVisitor import GremlinTransVisitor

# Import the ANTLR-generated components
from gremlin.GremlinLexer import GremlinLexer
from gremlin.GremlinParser import GremlinParser


class SyntaxErrorListener(ErrorListener):
    """ANTLR error listener that records the first syntax error encountered."""

    def __init__(self):
        super().__init__()
        self.has_error = False
        self.error_message = ""

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """Called by the ANTLR runtime whenever a syntax error occurs."""
        self.has_error = True
        self.error_message = f"Syntax Error at line {line}, column {column}: {msg}"


def check_gremlin_syntax(query_string: str) -> tuple[bool, str]:
    """Check the syntax of a given Gremlin query.

    Args:
        query_string: The Gremlin query to check.

    Returns:
        A tuple containing:
        - bool: True if syntax is correct, False otherwise.
        - str: An error message if syntax is incorrect, or "Syntax OK" if correct.
    """
    try:
        input_stream = InputStream(query_string)
        lexer = GremlinLexer(input_stream)
        token_stream = CommonTokenStream(lexer)
        parser = GremlinParser(token_stream)

        # Remove the default console error listener so nothing is printed.
        parser.removeErrorListeners()

        # Attach our own listener to capture errors programmatically.
        error_listener = SyntaxErrorListener()
        parser.addErrorListener(error_listener)

        # Attempt to parse the query.
        parser.queryList()

        if error_listener.has_error:
            return (False, error_listener.error_message)
        return (True, "Syntax OK")

    except Exception as e:
        # The ANTLR runtime can raise on badly malformed input; report it
        # as a failed check instead of propagating.
        return (False, f"Parser Exception: {str(e)}")


def generate_corpus_from_template(
    template_string: str,
    config: Config,
    schema: Schema,
    gremlin_base: GremlinBase,
    global_corpus_dict: dict
) -> tuple[int, dict]:
    """Run the complete pipeline for a single Gremlin template string.

    Args:
        template_string: The Gremlin query used as a template.
        config: The loaded Config object.
        schema: The loaded Schema object.
        gremlin_base: The loaded GremlinBase object.
        global_corpus_dict: Global dict storing unique query-description pairs.

    Returns:
        tuple: (number of new unique pairs added to the global corpus,
                processing statistics)
    """
    # Per-template statistics, returned alongside the new-pair count.
    stats = {
        'success': False,
        'error_stage': '',
        'error_message': '',
        'generated_count': 0,
        'new_pairs_count': 0,
        'duplicate_count': 0,
        'syntax_error_count': 0
    }

    try:
        # Parse the template with ANTLR into an AST and extract the recipe.
        visitor = GremlinTransVisitor()
        recipe = visitor.parse_and_visit(template_string)

        if not recipe:
            stats['error_stage'] = 'recipe_extraction'
            stats['error_message'] = 'Recipe extraction failed'
            return 0, stats

        if not hasattr(recipe, 'steps') or not recipe.steps:
            stats['error_stage'] = 'recipe_validation'
            stats['error_message'] = 'Recipe has no steps'
            return 0, stats

        # Generalize the recipe into concrete queries.
        generator = TraversalGenerator(schema, recipe, gremlin_base)
        corpus = generator.generate()

        if not corpus:
            stats['error_stage'] = 'generation'
            stats['error_message'] = 'Generator returned empty corpus'
            return 0, stats

        stats['generated_count'] = len(corpus)

        # Syntax check and global deduplication.
        new_pairs_count = 0
        duplicate_count = 0
        syntax_error_count = 0

        for query, description in corpus:
            # Syntax check first. No try/except is needed here:
            # check_gremlin_syntax converts parser exceptions into a
            # (False, message) result, and dict operations on str keys
            # cannot raise.
            is_valid, _error_msg = check_gremlin_syntax(query)

            if not is_valid:
                syntax_error_count += 1
                continue

            if query not in global_corpus_dict:
                # New, syntactically valid query: add it to the global dict.
                global_corpus_dict[query] = description
                new_pairs_count += 1
            else:
                # Duplicate query: skip it.
                duplicate_count += 1

        # Record the final counters.
        stats['new_pairs_count'] = new_pairs_count
        stats['duplicate_count'] = duplicate_count
        stats['syntax_error_count'] = syntax_error_count
        stats['success'] = True

        # Attach warnings about unusual generation volume.
        if stats['generated_count'] > 5000:
            stats['warning'] = f'由于本条模版的Recipe复杂,生成了大量查询({stats["generated_count"]}条)'
        elif new_pairs_count == 0 and stats['generated_count'] > 0:
            stats['warning'] = f'生成了{stats["generated_count"]}条查询但全部重复'

        return new_pairs_count, stats

    except Exception as e:
        # Catch-all so one bad template cannot abort the whole batch.
        stats['error_stage'] = 'unknown'
        stats['error_message'] = str(e)
        return 0, stats


def generate_corpus_from_templates(templates: list[str],
                                   config_path: str = None,
                                   schema_path: str = None,
                                   data_path: str = None,
                                   output_file: str = "generated_corpus.json") -> dict:
    """Generate the full corpus from a list of Gremlin templates.

    Args:
        templates: List of Gremlin query templates.
        config_path: Path to the config file.
        schema_path: Path to the schema file.
        data_path: Path to the data directory.
        output_file: Name of the output file.

    Returns:
        A dict with generation statistics.
    """
    # --- Setup: Define paths and load dependencies ---
    if not config_path or not schema_path or not data_path:
        # Auto-detect the project root (one level above this base/ directory).
        current_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(current_dir)

        schema_path = schema_path or os.path.join(project_root, 'db_data', 'schema', 'movie_schema.json')
        data_path = data_path or os.path.join(project_root, 'db_data')
        config_path = config_path or os.path.join(project_root, 'config.json')

    if not all(os.path.exists(p) for p in [config_path, schema_path, data_path]):
        raise FileNotFoundError("Could not find necessary config, schema, or data files.")

    # Load all necessary components once.
    config = Config(file_path=config_path)
    schema = Schema(schema_path, data_path)
    gremlin_base = GremlinBase(config)

    # --- Run the generation process for each template with global deduplication ---
    global_corpus_dict = {}  # Dedup store: key is the query, value its description.
    total_new_pairs = 0

    # Aggregated processing statistics across all templates.
    processing_stats = {
        'total_templates': len(templates),
        'successful_templates': 0,
        'failed_templates': 0,
        'failed_details': [],
        'total_generated': 0,
        'total_syntax_errors': 0,
        'total_duplicates': 0
    }

    print(f"🚀 开始处理 {len(templates)} 个模板...")

    for i, template in enumerate(templates, 1):
        try:
            new_pairs_count, template_stats = generate_corpus_from_template(
                template_string=template,
                config=config,
                schema=schema,
                gremlin_base=gremlin_base,
                global_corpus_dict=global_corpus_dict
            )

            total_new_pairs += new_pairs_count

            # Fold this template's stats into the aggregate.
            if template_stats['success']:
                processing_stats['successful_templates'] += 1
                processing_stats['total_generated'] += template_stats['generated_count']
                processing_stats['total_syntax_errors'] += template_stats['syntax_error_count']
                processing_stats['total_duplicates'] += template_stats['duplicate_count']

                # Report the per-template outcome.
                if new_pairs_count == 0 and template_stats['generated_count'] > 0:
                    print(f"[{i}/{len(templates)}] ⚠️ 生成 {template_stats['generated_count']} 条查询但全部重复")
                elif template_stats['generated_count'] > 5000:
                    print(f"[{i}/{len(templates)}] ⚡ 大量生成 {new_pairs_count} 条新查询 (总生成{template_stats['generated_count']}条)")
                else:
                    print(f"[{i}/{len(templates)}] ✅ 成功生成 {new_pairs_count} 条新查询")
            else:
                processing_stats['failed_templates'] += 1
                processing_stats['failed_details'].append({
                    'template_index': i,
                    'template': template[:100] + '...' if len(template) > 100 else template,
                    'error_stage': template_stats['error_stage'],
                    'error_message': template_stats['error_message']
                })
                print(f"[{i}/{len(templates)}] ❌ 处理失败: {template_stats['error_message']}")

        except Exception as e:
            # Unexpected failure on a single template: record it and move on.
            processing_stats['failed_templates'] += 1
            processing_stats['failed_details'].append({
                'template_index': i,
                'template': template[:100] + '...' if len(template) > 100 else template,
                'error_stage': 'unexpected_error',
                'error_message': str(e)
            })
            print(f"[{i}/{len(templates)}] ❌ 意外错误: {str(e)}")
            continue  # Continue with the next template.

    # Convert to a list of (query, description) pairs for downstream processing.
    full_corpus = list(global_corpus_dict.items())

    # --- Save the full corpus to a local file ---
    # Only successfully generated query-description pairs are stored.
    corpus_data = {
        "metadata": {
            "total_templates": len(templates),
            "successful_templates": processing_stats['successful_templates'],
            "failed_templates": processing_stats['failed_templates'],
            "total_unique_queries": len(full_corpus),
            "generation_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        },
        "corpus": [
            {
                "query": query,
                "description": desc
            }
            for query, desc in full_corpus
        ]
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(corpus_data, f, ensure_ascii=False, indent=2)

    # --- Generate statistics and display results ---
    stats = _generate_statistics(templates, full_corpus, output_file)
    stats.update({
        "total_templates": len(templates),
        "successful_templates": processing_stats['successful_templates'],
        "failed_templates": processing_stats['failed_templates'],
        "output_file": output_file
    })
    _display_final_results(full_corpus, stats)

    return {
        "total_templates": len(templates),
        "successful_templates": processing_stats['successful_templates'],
        "failed_templates": processing_stats['failed_templates'],
        "total_unique_queries": len(full_corpus),
        "output_file": output_file,
        "statistics": stats
    }
def _generate_statistics(templates: list, full_corpus: list, output_file: str) -> dict:
    """Compute summary statistics for the generated corpus.

    Note: ``output_file`` is kept for interface compatibility but unused here.
    """
    # Bucket queries by step count, approximated by the number of '.'
    # characters (dots inside string literals are counted too — a rough
    # heuristic, acceptable for reporting purposes).
    length_stats = {}
    for query, _ in full_corpus:
        steps = query.count('.')
        length_stats[steps] = length_stats.get(steps, 0) + 1

    # Bucket queries by CRUD operation type.
    operation_stats = {
        "查询(V/E)": 0,
        "创建(addV/addE)": 0,
        "更新(property)": 0,
        "删除(drop)": 0
    }

    for query, _ in full_corpus:
        if query.startswith('g.V(') or query.startswith('g.E('):
            if '.drop()' in query:
                operation_stats["删除(drop)"] += 1
            elif '.property(' in query:
                operation_stats["更新(property)"] += 1
            else:
                operation_stats["查询(V/E)"] += 1
        elif '.addV(' in query or '.addE(' in query:
            operation_stats["创建(addV/addE)"] += 1

    return {
        "length_stats": length_stats,
        "operation_stats": operation_stats,
        "avg_per_template": len(full_corpus) / len(templates) if templates else 0
    }


def _display_final_results(full_corpus: list, stats: dict):
    """Print the final generation results and statistics."""
    print(f"\n{'='*50}")
    print(f"📊 生成完成统计")
    print(f"{'='*50}")
    print(f"处理的模板数量: {stats.get('total_templates', 0)}")
    print(f"成功处理: {stats.get('successful_templates', 0)}")
    print(f"处理失败: {stats.get('failed_templates', 0)}")
    print(f"生成的独特查询数量: {len(full_corpus)}")
    print(f"语料库已保存到: {stats.get('output_file', 'generated_corpus.json')}")

    # Breakdown by query complexity (step count).
    print(f"\n{'='*50}")
    print("📈 查询复杂度分析:")
    print(f"{'='*50}")

    for steps in sorted(stats['length_stats'].keys()):
        print(f"  {steps}步查询: {stats['length_stats'][steps]} 个")

    # Breakdown by operation type.
    print(f"\n{'='*50}")
    print("🔍 操作类型分析:")
    print(f"{'='*50}")

    for op_type, count in stats['operation_stats'].items():
        percentage = (count / len(full_corpus)) * 100 if full_corpus else 0
        print(f"  {op_type}: {count} 个 ({percentage:.1f}%)")

    print(f"\n{'='*50}")
    print(f"✅ 生成完成!共生成 {len(full_corpus)} 个独特查询")
    print(f"{'='*50}")


if __name__ == '__main__':
    # The former hard-coded example template list (query / create / update /
    # delete samples) was removed; templates are now loaded from a CSV
    # dataset below.

    def load_templates_from_csv(csv_file_path: str) -> tuple[list[str], dict]:
        """Load Gremlin queries from a CSV file for use as templates.

        Args:
            csv_file_path: Path to the CSV file; a 'gremlin_query' column
                is expected.

        Returns:
            tuple: (list of successfully loaded queries, statistics dict)
        """
        import csv

        templates = []
        stats = {
            'total_rows': 0,
            'successful_loads': 0,
            'failed_loads': 0,
            'failed_queries': []
        }

        try:
            with open(csv_file_path, 'r', encoding='utf-8') as file:
                reader = csv.DictReader(file)

                for row_num, row in enumerate(reader, 1):
                    stats['total_rows'] += 1

                    try:
                        # Fetch the gremlin_query column.
                        gremlin_query = row.get('gremlin_query', '').strip()

                        if not gremlin_query:
                            stats['failed_loads'] += 1
                            stats['failed_queries'].append(f"第{row_num}行: 空查询")
                            continue

                        # Strip surrounding double quotes, if present.
                        if gremlin_query.startswith('"') and gremlin_query.endswith('"'):
                            gremlin_query = gremlin_query[1:-1]

                        # Basic sanity check: templates must start with 'g.'.
                        if not gremlin_query.startswith('g.'):
                            stats['failed_loads'] += 1
                            stats['failed_queries'].append(f"第{row_num}行: 格式错误")
                            continue

                        templates.append(gremlin_query)
                        stats['successful_loads'] += 1

                    except Exception as e:
                        stats['failed_loads'] += 1
                        stats['failed_queries'].append(f"第{row_num}行: {str(e)}")
                        continue

        except FileNotFoundError:
            print(f"❌ 错误: 找不到CSV文件: {csv_file_path}")
            return [], stats
        except Exception as e:
            print(f"❌ 读取CSV文件时发生错误: {str(e)}")
            return [], stats

        return templates, stats

    # Load templates from the CSV dataset.
    csv_file_path = "cypher2gremlin_dataset.csv"

    print(f"🔄 从 {csv_file_path} 加载Gremlin查询模板...")
    templates, load_stats = load_templates_from_csv(csv_file_path)

    print(f"📊 CSV加载统计: {load_stats['successful_loads']}/{load_stats['total_rows']} 成功")

    if load_stats['failed_loads'] > 0:
        print(f"⚠️ {load_stats['failed_loads']} 个模板加载失败")

    if not templates:
        print("❌ 没有成功加载任何模板,程序退出")
        # SystemExit instead of the site-provided exit() helper, which is
        # meant for interactive use and may be absent in frozen/embedded runs.
        raise SystemExit(1)

    print(f"✅ 成功加载 {len(templates)} 个模板,开始生成语料库...")

    # Generate the corpus.
    try:
        result = generate_corpus_from_templates(templates)
    except Exception as e:
        print(f"❌ 生成过程中发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
