This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch text2gql
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit b69ad852c4fc2061bfe2b162954dcd5c6937e321
Author: Lriver <[email protected]>
AuthorDate: Tue Sep 30 21:05:27 2025 +0800

    feat: add recursive backtracking traversal generator for diverse query variants from Recipe
---
 .../AST_Text2Gremlin/base/TraversalGenerator.py | 481 +++++++++++++++++++++
 1 file changed, 481 insertions(+)

diff --git a/text2gremlin/AST_Text2Gremlin/base/TraversalGenerator.py b/text2gremlin/AST_Text2Gremlin/base/TraversalGenerator.py
new file mode 100644
index 00000000..4a1fc5a1
--- /dev/null
+++ b/text2gremlin/AST_Text2Gremlin/base/TraversalGenerator.py
@@ -0,0 +1,481 @@
+"""
+Core engine for generating Gremlin queries.
+
+Uses a recursive backtracking algorithm to generate a large number of diverse
+Gremlin queries, together with their Chinese descriptions, from a structured recipe.
+"""
+
+import os
+import random
+import string
+from typing import List, Dict, Any, Tuple, Set
+
+# Import the core data structures and the Schema defined in this package
+from Schema import Schema
+from GremlinParse import Traversal, Step
+from GremlinExpr import Predicate
+from GremlinBase import GremlinBase
+from Config import Config
+
+class TraversalGenerator:
+    def __init__(self, schema: Schema, recipe: Traversal, gremlin_base: GremlinBase):
+        self.schema = schema
+        self.recipe = recipe
+        self.gremlin_base = gremlin_base
+        self.generated_pairs: Set[Tuple[str, str]] = set()
+
+    def generate(self) -> List[Tuple[str, str]]:
+        """
+        Main entry point; starts the recursive generation process.
+
+        Returns:
+            A deduplicated list of (query, description) pairs.
+        """
+        self.generated_pairs.clear()
+        self._recursive_generate(
+            recipe_steps=self.recipe.steps,
+            current_query="g",
+            current_desc="从图中开始",
+            current_label=None,
+            current_type='graph'
+        )
+        return list(self.generated_pairs)
+
+    def _get_random_value(self, label: str, prop_info: Dict, for_update: bool = False) -> Any:
+        """Generate a value appropriate for the property type, preferring a real instance value when available."""
+        prop_name, prop_type = prop_info['name'], prop_info['type']
+        instance = self.schema.get_instance(label)
+        if instance and prop_name in instance and not for_update:
+            value = instance.get(prop_name)
+            if value is not None: return value
+        if prop_type == 'STRING':
+            return ''.join(random.choices(string.ascii_letters, k=random.randint(5, 8)))
+        if prop_type in ['INT32', 'INT64']:
+            return random.randint(1, 10000)
+        return "default_value"
+
+    def _get_multiple_values(self, label: str, prop_info: Dict, for_update: bool = False) -> List[Any]:
+        """Fetch several real data values for a property, used to produce multiple query variants."""
+        prop_name, prop_type = prop_info['name'], prop_info['type']
+        instances = self.schema.get_instances(label)
+
+        values = []
+        for instance in instances:
+            if instance and prop_name in instance and not for_update:
+                value = instance.get(prop_name)
+                if value is not None:
+                    values.append(value)
+
+        # If no real data is available, fall back to a few random values
+        if not values:
+            if prop_type == 'STRING':
+                values = [''.join(random.choices(string.ascii_letters, k=random.randint(5, 8))) for _ in range(random.randint(2, 5))]
+            elif prop_type in ['INT32', 'INT64']:
+                values = [random.randint(1, 10000) for _ in range(random.randint(2, 5))]
+            else:
+                values = ["default_value"]
+
+        return values
+
+    def _recursive_generate(self, recipe_steps: List[Step], current_query: str, current_desc: str, current_label: str, current_type: str):
+        """Core recursive generator: depth-first search with backtracking over the recipe steps."""
+        if not recipe_steps:
+            # When the end of the recipe is reached, apply random enhancements with 20% probability
+            if random.random() < 0.2:
+                enhanced_queries = self._apply_random_enhancements(current_query, current_desc, current_label, current_type)
+                for enhanced_query, enhanced_desc in enhanced_queries:
+                    self.generated_pairs.add((enhanced_query, enhanced_desc))
+            return
+
+        step_recipe = recipe_steps[0]
+        remaining_steps = recipe_steps[1:]
+
+        # Collect every legal way to instantiate this step in the current state
+        options = self._get_valid_options_for_step(step_recipe, current_label, current_type)
+
+        # Explore each legal option and keep descending
+        for option in options:
+            next_query = current_query + option['query_part']
+            next_desc = current_desc + option['desc_part']
+
+            # Save the intermediate result as well
+            self.generated_pairs.add((next_query, next_desc))
+
+            if option['new_type'] != 'none':
+                # Continue the recursion
+                self._recursive_generate(remaining_steps, next_query, next_desc, option['new_label'], option['new_type'])
+
+            # With a small probability (10%), also apply random enhancements at intermediate steps
+            if remaining_steps and random.random() < 0.1:
+                enhanced_queries = self._apply_random_enhancements(next_query, next_desc, option['new_label'], option['new_type'])
+                for enhanced_query, enhanced_desc in enhanced_queries:
+                    # Continue executing the remaining steps on the enhanced query
+                    self._recursive_generate(remaining_steps, enhanced_query, enhanced_desc, option['new_label'], option['new_type'])
+
+    def _get_valid_options_for_step(self, step_recipe: Step, current_label: str, current_type: str) -> List[Dict]:
+        """Return every legal instantiation option for a single recipe step."""
+        step_name = step_recipe.name.lower()
+        step_params = step_recipe.params
+        options = []
+
+        # --- Start steps ---
+        if current_type == 'graph':
+            if step_name == 'v':
+                # V() only starts the traversal and does not fix a label;
+                # label filtering is handled by a subsequent hasLabel step
+                if step_params:
+                    # If V() has parameters (IDs), use them directly
+                    ids = ', '.join([repr(p) for p in step_params])
+                    options.append({
+                        'query_part': f".V({ids})",
+                        'desc_part': f"查找ID为{ids}的顶点",
+                        'new_label': None, 'new_type': 'vertex'
+                    })
+                else:
+                    # V() without parameters returns all vertices
+                    options.append({
+                        'query_part': ".V()",
+                        'desc_part': "查找所有顶点",
+                        'new_label': None, 'new_type': 'vertex'
+                    })
+            elif step_name == 'addv':
+                label = step_params[0]
+                creation_info = self.schema.get_vertex_creation_info(label)
+                query_part = f".addV('{label}')"
+                desc_part = f"添加一个'{self.gremlin_base.get_schema_desc(label)}'顶点"
+                for prop_name in creation_info.get('required', []):
+                    prop_info = next(p for p in self.schema.get_properties_with_type(label) if p['name'] == prop_name)
+                    prop_value = self._get_random_value(label, prop_info, for_update=True)
+                    query_part += f".property('{prop_name}', {repr(prop_value)})"
+                    desc_part += f",并设置其'{self.gremlin_base.get_schema_desc(prop_name)}'为'{prop_value}'"
+                options.append({'query_part': query_part, 'desc_part': desc_part, 'new_label': label, 'new_type': 'vertex'})
+            return options
+
+        # --- Subsequent steps ---
+        if step_name in ['out', 'in', 'both']:
+            if current_label:  # Navigation is only possible when the current label is known
+                valid_steps = self.schema.get_valid_steps(current_label, current_type)
+                possible_edges = next((s['params'] for s in valid_steps if s['step'] == step_name), [])
+
+                if step_params and step_params[0] in possible_edges:
+                    # Prefer the edge specified in the recipe
+                    edge = step_params[0]
+                    new_label, new_type = self.schema.get_step_result_label(current_label, {'step': step_name, 'param': edge})
+                    desc_map = {'out': '出边', 'in': '入边', 'both': '双向边'}
+                    options.append({
+                        'query_part': f".{step_name}('{edge}')",
+                        'desc_part': f",然后沿着'{self.gremlin_base.get_schema_desc(edge)}'的{desc_map[step_name]}找到'{self.gremlin_base.get_schema_desc(new_label)}'顶点",
+                        'new_label': new_label, 'new_type': new_type
+                    })
+
+                    # Also generate variants for the other possible edges (for generalization)
+                    for other_edge in possible_edges:
+                        if other_edge != edge:  # Avoid duplicates
+                            new_label, new_type = self.schema.get_step_result_label(current_label, {'step': step_name, 'param': other_edge})
+                            options.append({
+                                'query_part': f".{step_name}('{other_edge}')",
+                                'desc_part': f",然后沿着'{self.gremlin_base.get_schema_desc(other_edge)}'的{desc_map[step_name]}找到'{self.gremlin_base.get_schema_desc(new_label)}'顶点",
+                                'new_label': new_label, 'new_type': new_type
+                            })
+                else:
+                    # If no edge is specified or the specified edge is invalid, try all possible edges
+                    for edge in possible_edges:
+                        new_label, new_type = self.schema.get_step_result_label(current_label, {'step': step_name, 'param': edge})
+                        desc_map = {'out': '出边', 'in': '入边', 'both': '双向边'}
+                        options.append({
+                            'query_part': f".{step_name}('{edge}')",
+                            'desc_part': f",然后沿着'{self.gremlin_base.get_schema_desc(edge)}'的{desc_map[step_name]}找到'{self.gremlin_base.get_schema_desc(new_label)}'顶点",
+                            'new_label': new_label, 'new_type': new_type
+                        })
+
+        elif step_name == 'has':
+            prop_name = step_params[0]
+
+            if current_label:
+                # If the current label is known, use several real data values of that label
+                prop_info = next((p for p in self.schema.get_properties_with_type(current_label) if p['name'] == prop_name), None)
+                if prop_info:
+                    # Fetch multiple real data values
+                    values = self._get_multiple_values(current_label, prop_info)
+                    for value in values:
+                        options.append({
+                            'query_part': f".has('{prop_name}', {repr(value)})",
+                            'desc_part': f",其'{self.gremlin_base.get_schema_desc(prop_name)}'属性为'{value}'",
+                            'new_label': current_label, 'new_type': current_type
+                        })
+            else:
+                # If the current label is unknown, try every label that has this property
+                for label in self.schema.get_vertex_labels():
+                    prop_info = next((p for p in self.schema.get_properties_with_type(label) if p['name'] == prop_name), None)
+                    if prop_info:
+                        # Fetch multiple real data values
+                        values = self._get_multiple_values(label, prop_info)
+                        for value in values:
+                            options.append({
+                                'query_part': f".has('{prop_name}', {repr(value)})",
+                                'desc_part': f",其'{self.gremlin_base.get_schema_desc(prop_name)}'属性为'{value}'",
+                                'new_label': label, 'new_type': current_type
+                            })
+
+        elif step_name == 'property':  # Update
+            prop_name = step_params[0]
+            prop_info = next((p for p in self.schema.get_updatable_properties(current_label) if p['name'] == prop_name), None)
+            if prop_info:
+                # For updates, several different values can be used to produce multiple variants
+                values = self._get_multiple_values(current_label, prop_info, for_update=True)
+                for value in values:
+                    options.append({
+                        'query_part': f".property('{prop_name}', {repr(value)})",
+                        'desc_part': f",并将其'{self.gremlin_base.get_schema_desc(prop_name)}'属性更新为'{value}'",
+                        'new_label': current_label, 'new_type': current_type
+                    })
+
+        elif step_name == 'limit':
+            num = step_params[0] if step_params else random.randint(1, 10)
+            options.append({
+                'query_part': f".limit({num})",
+                'desc_part': f",并只取前{num}个结果",
+                'new_label': current_label, 'new_type': current_type
+            })
+
+        elif step_name == 'haslabel':
+            # Handle the hasLabel step
+            if step_params:
+                # If the recipe specifies a label, prefer it
+                target_label = step_params[0]
+                if target_label in self.schema.get_vertex_labels():
+                    options.append({
+                        'query_part': f".hasLabel('{target_label}')",
+                        'desc_part': f",过滤出'{self.gremlin_base.get_schema_desc(target_label)}'类型的顶点",
+                        'new_label': target_label, 'new_type': current_type
+                    })
+
+                # Also generate variants for the other labels (for generalization)
+                for label in self.schema.get_vertex_labels():
+                    if label != target_label:  # Avoid duplicates
+                        options.append({
+                            'query_part': f".hasLabel('{label}')",
+                            'desc_part': f",过滤出'{self.gremlin_base.get_schema_desc(label)}'类型的顶点",
+                            'new_label': label, 'new_type': current_type
+                        })
+            else:
+                # If no label is specified, try all possible labels
+                for label in self.schema.get_vertex_labels():
+                    options.append({
+                        'query_part': f".hasLabel('{label}')",
+                        'desc_part': f",过滤出'{self.gremlin_base.get_schema_desc(label)}'类型的顶点",
+                        'new_label': label, 'new_type': current_type
+                    })
+
+        elif step_name == 'drop':
+            options.append({'query_part': ".drop()", 'desc_part': ",并删除它", 'new_label': None, 'new_type': 'none'})
+
+        return options
+
+    def _apply_random_enhancements(self, query: str, desc: str, current_label: str, current_type: str) -> List[Tuple[str, str]]:
+        """
+        Apply random enhancements to a query, adding generic filtering and limiting steps.
+
+        Args:
+            query: the current query string
+            desc: the current description string
+            current_label: the current label
+            current_type: the current type
+
+        Returns:
+            A list of enhanced (query, description) pairs.
+        """
+        enhanced_queries = []
+
+        # Inspect the current query state to decide which enhancements are applicable
+        if self._is_terminal_step(query):
+            # Do not enhance queries that already end in a terminal step
+            return enhanced_queries
+
+        if self._is_element_stream(query, current_type):
+            # Element stream: limits, deduplication, ordering, etc. can be added
+            enhanced_queries.extend(self._add_element_stream_enhancements(query, desc, current_label, current_type))
+
+        elif self._is_value_stream(query):
+            # Value stream: deduplication, ordering and limits can be added
+            enhanced_queries.extend(self._add_value_stream_enhancements(query, desc))
+
+        return enhanced_queries
+
+    def _is_terminal_step(self, query: str) -> bool:
+        """Check whether the query ends with a terminal step."""
+        terminal_steps = ['.count()', '.sum()', '.mean()', '.min()', '.max()', '.drop()', '.iterate()']
+        return any(query.endswith(step) for step in terminal_steps)
+
+    def _is_element_stream(self, query: str, current_type: str) -> bool:
+        """Check whether the current result is an element stream (a stream of vertices or edges)."""
+        # If the current type is vertex or edge and the query is not a value stream, it is an element stream
+        return current_type in ['vertex', 'edge'] and not self._is_value_stream(query)
+
+    def _is_value_stream(self, query: str) -> bool:
+        """Check whether the current result is a value stream."""
+        value_steps = ['.values(', '.valueMap(', '.id()', '.label()', '.key()']
+        return any(step in query for step in value_steps)
+
+    def _add_element_stream_enhancements(self, query: str, desc: str, current_label: str, current_type: str) -> List[Tuple[str, str]]:
+        """Add enhancements for element streams."""
+        enhancements = []
+
+        # 1. Result limit - limit(n)
+        if random.random() < 0.4:  # 40% probability of adding limit
+            # Use a wide but still reasonable range
+            if random.random() < 0.7:  # 70% probability of a common value
+                limit_num = random.choice([1, 3, 5, 10, 20, 50, 100])
+            else:  # 30% probability of a random value
+                limit_num = random.randint(1, 200)  # keep within a reasonable range
+            enhanced_query = f"{query}.limit({limit_num})"
+            enhanced_desc = f"{desc},并只取前{limit_num}个结果"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 2. Range limit - range(low, high)
+        if random.random() < 0.2:  # 20% probability of adding range
+            if random.random() < 0.6:  # 60% probability of a common value
+                low = random.choice([0, 1, 5, 10, 20])
+                high = low + random.choice([5, 10, 15, 20, 30])
+            else:  # 40% probability of a random value
+                low = random.randint(0, 50)
+                high = low + random.randint(5, 100)
+            enhanced_query = f"{query}.range({low}, {high})"
+            enhanced_desc = f"{desc},并获取第{low}到{high}个结果"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 3. Random sampling - sample(n)
+        if random.random() < 0.3:  # 30% probability of adding sample
+            if random.random() < 0.8:  # 80% probability of a common value
+                sample_num = random.choice([1, 2, 3, 5, 10])
+            else:  # 20% probability of a random value
+                sample_num = random.randint(1, 50)
+            enhanced_query = f"{query}.sample({sample_num})"
+            enhanced_desc = f"{desc},并随机抽取{sample_num}个"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 4. Deduplication - dedup()
+        if random.random() < 0.3:  # 30% probability of adding dedup
+            enhanced_query = f"{query}.dedup()"
+            enhanced_desc = f"{desc},并去除重复项"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 5. Simple ordering - order().by() on a single sortable property (kept simple to avoid nested by clauses)
+        if current_label and random.random() < 0.2:  # 20% probability of adding a simple ordering
+            # Only attempt ordering when the current label is known
+            sortable_props = self._get_sortable_properties(current_label)
+            if sortable_props:
+                prop = random.choice(sortable_props)
+                order_dir = random.choice(['asc', 'desc'])
+                enhanced_query = f"{query}.order().by('{prop}', {order_dir})"
+                enhanced_desc = f"{desc},并按{self.gremlin_base.get_schema_desc(prop)}{'升序' if order_dir == 'asc' else '降序'}排列"
+                enhancements.append((enhanced_query, enhanced_desc))
+
+        return enhancements
+
+    def _add_value_stream_enhancements(self, query: str, desc: str) -> List[Tuple[str, str]]:
+        """Add enhancements for value streams."""
+        enhancements = []
+
+        # 1. Result limit
+        if random.random() < 0.4:
+            if random.random() < 0.7:  # 70% probability of a common value
+                limit_num = random.choice([1, 3, 5, 10, 20])
+            else:  # 30% probability of a random value
+                limit_num = random.randint(1, 100)
+            enhanced_query = f"{query}.limit({limit_num})"
+            enhanced_desc = f"{desc},并只取前{limit_num}个值"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 2. Deduplication
+        if random.random() < 0.4:
+            enhanced_query = f"{query}.dedup()"
+            enhanced_desc = f"{desc},并去除重复的值"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        # 3. Ordering (value streams can be ordered directly, no by clause needed)
+        if random.random() < 0.3:
+            enhanced_query = f"{query}.order()"
+            enhanced_desc = f"{desc},并按字母/数字顺序排列"
+            enhancements.append((enhanced_query, enhanced_desc))
+
+        return enhancements
+
+    def _get_sortable_properties(self, label: str) -> List[str]:
+        """Return the sortable properties of the given label."""
+        if not label:
+            return []
+
+        try:
+            properties = self.schema.get_properties_with_type(label)
+            # Treat numeric and string properties as sortable
+            sortable = []
+            for prop in properties:
+                if prop['type'] in ['STRING', 'INT32', 'INT64', 'FLOAT', 'DOUBLE']:
+                    sortable.append(prop['name'])
+            return sortable
+        except Exception:
+            return []
+
+# --- Standalone test entry point for this module ---
+if __name__ == "__main__":
+    # base_dir = os.path.dirname(os.path.abspath(__file__))
+    # project_root = os.path.dirname(base_dir)
+    # schema_path = os.path.join(project_root, 'db_data', 'schema', 'movie_schema.json')
+    # data_path = os.path.join(project_root, 'db_data', 'movie', 'raw_data')
+    # config_path = os.path.join(project_root, 'config.json')
+    project_root = "/root/lzj/ospp/AST_Text2Gremlin"
+    schema_path = os.path.join(project_root, 'db_data', 'schema', 'movie_schema.json')
+    data_path = os.path.join(project_root, 'db_data')
+    config_path = "/root/lzj/ospp/AST_Text2Gremlin/config.json"
+
+    if not all(os.path.exists(p) for p in [schema_path, data_path, config_path]):
+        print("Error: required files not found, please check the paths.")
+    else:
+        config = Config(file_path=config_path)
+        schema = Schema(schema_path, data_path)
+        gremlin_base = GremlinBase(config)
+
+        print("\n--- [final recursive version] Testing the 'query' recipe: g.V().has('name', '...').out('acted_in') ---")
+        query_recipe = Traversal()
+        query_recipe.add_step(Step('V'))
+        query_recipe.add_step(Step('has', params=['name', 'some_value']))
+        query_recipe.add_step(Step('out', params=['acted_in']))
+
+        generator = TraversalGenerator(schema, query_recipe, gremlin_base)
+        generated_queries = generator.generate()
+        print(f"The 'query' recipe generated {len(generated_queries)} distinct corpus pairs. Sample output:")
+        for i, (q, d) in enumerate(random.sample(generated_queries, min(5, len(generated_queries)))):
+            print(f"  Example {i+1}:\n    Query: {q}\n    Description: {d}")
+
+        print("\n--- [final recursive version] Testing the 'add' recipe: g.addV('person') ---")
+        add_recipe = Traversal()
+        add_recipe.add_step(Step('addV', params=['person']))
+
+        generator = TraversalGenerator(schema, add_recipe, gremlin_base)
+        generated_adds = generator.generate()
+        print(f"The 'add' recipe generated {len(generated_adds)} distinct corpus pairs. Sample output:")
+        for i, (q, d) in enumerate(random.sample(generated_adds, min(5, len(generated_adds)))):
+            print(f"  Example {i+1}:\n    Query: {q}\n    Description: {d}")
+
+        print("\n--- [final recursive version] Testing the 'update' recipe: g.V().property('born', ...) ---")
+        update_recipe = Traversal()
+        update_recipe.add_step(Step('V'))
+        update_recipe.add_step(Step('property', params=['born', 1960]))
+
+        generator = TraversalGenerator(schema, update_recipe, gremlin_base)
+        generated_updates = generator.generate()
+        print(f"The 'update' recipe generated {len(generated_updates)} distinct corpus pairs. Sample output:")
+        for i, (q, d) in enumerate(random.sample(generated_updates, min(5, len(generated_updates)))):
+            print(f"  Example {i+1}:\n    Query: {q}\n    Description: {d}")
+
+        print("\n--- [final recursive version] Testing the 'delete' recipe: g.V().has('name', '...').drop() ---")
+        drop_recipe = Traversal()
+        drop_recipe.add_step(Step('V'))
+        drop_recipe.add_step(Step('has', params=['name', 'some_value']))
+        drop_recipe.add_step(Step('drop'))
+
+        generator = TraversalGenerator(schema, drop_recipe, gremlin_base)
+        generated_drops = generator.generate()
+        print(f"The 'delete' recipe generated {len(generated_drops)} distinct corpus pairs. Sample output:")
+        for i, (q, d) in enumerate(random.sample(generated_drops, min(5, len(generated_drops)))):
+            print(f"  Example {i+1}:\n    Query: {q}\n    Description: {d}")
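
As a minimal sketch of the recursive backtracking idea implemented by _recursive_generate above: each recipe step expands into several concrete instantiations, every prefix is recorded as a (query, description) pair, and the recursion backtracks to try the remaining options. The option table below is invented for illustration and stands in for the Schema-driven lookups; it is not code from this commit.

    # Simplified, self-contained illustration of the expansion pattern.
    # The OPTIONS table is hypothetical; the real module derives it from Schema.
    from typing import Dict, List, Set, Tuple

    OPTIONS: Dict[str, List[Tuple[str, str]]] = {
        "V":   [(".V()", "查找所有顶点")],
        "has": [(".has('name', 'Tom')", ",其'名字'属性为'Tom'"),
                (".has('name', 'Amy')", ",其'名字'属性为'Amy'")],
        "out": [(".out('acted_in')", ",然后沿着'出演'的出边找到'电影'顶点")],
    }

    def expand(steps: List[str], query: str, desc: str, out: Set[Tuple[str, str]]) -> None:
        """Depth-first expansion: record every prefix, then recurse on the remaining steps."""
        if not steps:
            return
        for query_part, desc_part in OPTIONS.get(steps[0], []):
            next_query, next_desc = query + query_part, desc + desc_part
            out.add((next_query, next_desc))               # keep intermediate results too
            expand(steps[1:], next_query, next_desc, out)  # backtrack after returning

    pairs: Set[Tuple[str, str]] = set()
    expand(["V", "has", "out"], "g", "从图中开始", pairs)
    for q, d in sorted(pairs):
        print(q, "->", d)

Run standalone, this prints the five prefix variants of g.V().has('name', ...).out('acted_in'); the committed generator applies the same pattern but draws its options, labels and values from the graph schema and real instance data, and additionally mixes in the random limit/range/sample/dedup/order enhancements.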
