This is an automated email from the ASF dual-hosted git repository. jin pushed a commit to branch text2gql in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
# Source commit: 84f7a417b1fc342630498959c1d9b28ca3894566 (branch text2gql,
# apache/incubator-hugegraph-ai) — text2gremlin/AST_Text2Gremlin/base/Schema.py

"""
Graph-database schema management module.

Parses the schema definition (JSON) and the CSV data files, providing the
query generator with graph-structure information and real data instances.
"""

import json
import os
import random
from io import StringIO
from typing import Any, Dict, List, Tuple

import pandas as pd


class Schema:
    def __init__(self, schema_file: str, data_dir: str):
        """Load the JSON schema and remember where the CSV data lives.

        Args:
            schema_file: Path to the JSON schema definition file.
            data_dir: Directory containing the CSV data files referenced by
                the schema's ``files`` section.
        """
        self.data_dir = data_dir
        self.vertices: Dict[str, Dict[str, Any]] = {}
        self.edges: Dict[str, Dict[str, Any]] = {}
        # Lazily-populated caches of parsed CSV data, keyed by label.
        self.vertex_data: Dict[str, pd.DataFrame] = {}
        self.edge_data: Dict[str, pd.DataFrame] = {}

        with open(schema_file, 'r', encoding='utf-8') as f:
            schema_data = json.load(f)

        # 1. Parse the schema definition (vertex/edge labels and properties).
        for item in schema_data.get('schema', []):
            label = item['label']
            if item['type'] == 'VERTEX':
                self.vertices[label] = {
                    'primary': item.get('primary', None),
                    'properties': {
                        prop['name']: {'type': prop['type'],
                                       'optional': prop.get('optional', False)}
                        for prop in item.get('properties', [])
                    },
                }
            elif item['type'] == 'EDGE':
                # Endpoint labels are filled in below from the ``files`` section.
                self.edges[label] = {
                    'source': None, 'destination': None,
                    'properties': {
                        prop['name']: {'type': prop['type'],
                                       'optional': prop.get('optional', False)}
                        for prop in item.get('properties', [])
                    },
                }

        # 2. Parse the ``files`` section: file path, number of header rows,
        #    and the source/destination labels of each edge.
        self.vertex_files: Dict[str, Dict] = {}
        self.edge_files: Dict[str, Dict] = {}
        for file_info in schema_data.get('files', []):
            label = file_info['label']
            path = os.path.join(self.data_dir, file_info['path'])
            header_rows = file_info.get('header', 1)  # header row count; default 1

            file_details = {'path': path, 'header_rows': header_rows}

            # A file entry describes an edge iff it names both endpoints.
            is_edge = 'SRC_ID' in file_info and 'DST_ID' in file_info
            if is_edge:
                self.edge_files[label] = file_details
                if label in self.edges:
                    self.edges[label]['source'] = file_info['SRC_ID']
                    self.edges[label]['destination'] = file_info['DST_ID']
            else:
                self.vertex_files[label] = file_details

    def _parse_custom_csv(self, file_path: str, header_line_index: int) -> pd.DataFrame:
        """Parse a CSV file whose header spans ``header_line_index`` rows.

        Column names are taken from the *last* header row, whose cells are
        formatted ``name:type``; everything after the header rows is data.
        Returns an empty DataFrame when the file is missing or malformed.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Column names come from the last header line ("name:type" cells).
            header_line = lines[header_line_index - 1]
            column_defs = header_line.strip().split(',')
            column_names = [d.split(':')[0] for d in column_defs]

            # Data starts right after the header rows.
            data_lines = lines[header_line_index:]

            if not data_lines:
                return pd.DataFrame(columns=column_names)

            # Feed the remaining lines to pandas through an in-memory buffer.
            csv_data = StringIO("".join(data_lines))
            df = pd.read_csv(csv_data, header=None, names=column_names)
            return df

        except (FileNotFoundError, IndexError) as e:
            print(f"警告: 读取或解析文件失败: {file_path}, 错误: {e}")
            return pd.DataFrame()

    def _load_vertex_data(self, label: str):
        """Load and cache the CSV data for a vertex label (no-op if cached)."""
        if label not in self.vertex_data and label in self.vertex_files:
            file_info = self.vertex_files[label]
            self.vertex_data[label] = self._parse_custom_csv(file_info['path'], file_info['header_rows'])

    def _load_edge_data(self, label: str):
        """Load and cache the CSV data for an edge label (no-op if cached)."""
        if label not in self.edge_data and label in self.edge_files:
            file_info = self.edge_files[label]
            self.edge_data[label] = self._parse_custom_csv(file_info['path'], file_info['header_rows'])

    # --- Schema query methods ---
    def get_vertex_labels(self) -> List[str]:
        """Return all vertex labels defined in the schema."""
        return list(self.vertices.keys())

    def get_edge_labels(self) -> List[str]:
        """Return all edge labels defined in the schema."""
        return list(self.edges.keys())

    def get_properties_with_type(self, label: str) -> List[Dict[str, str]]:
        """Return ``[{'name': ..., 'type': ...}]`` for a vertex or edge label.

        Vertices are checked first; an empty vertex property dict falls
        through to the edge lookup.
        """
        props_dict = self.vertices.get(label, {}).get('properties', {}) or self.edges.get(label, {}).get('properties', {})
        return [{'name': name, 'type': meta['type']} for name, meta in props_dict.items()]

    def get_valid_steps(self, current_label: str, element_type: str = 'vertex') -> List[Dict]:
        """List traversal/property steps that are valid from ``current_label``.

        Only vertex elements are supported; anything else yields ``[]``.
        Each entry is ``{'step': name, 'params': [...candidates...]}``.
        """
        if element_type == 'vertex':
            if current_label not in self.vertices:
                return []
            valid_steps = []
            # out(): edges whose source is the current vertex label.
            outgoing = [l for l, e in self.edges.items() if e['source'] == current_label]
            if outgoing:
                valid_steps.append({'step': 'out', 'params': outgoing})
            # in(): edges whose destination is the current vertex label.
            incoming = [l for l, e in self.edges.items() if e['destination'] == current_label]
            if incoming:
                valid_steps.append({'step': 'in', 'params': incoming})
            props = self.get_properties_with_type(current_label)
            if props:
                valid_steps.append({'step': 'properties', 'params': [p['name'] for p in props]})
                valid_steps.append({'step': 'has', 'params': props})
            return valid_steps
        return []

    def get_step_result_label(self, start_label: str, step: Dict) -> Tuple[str, str]:
        """Return ``(label, element_type)`` produced by applying ``step``.

        NOTE(review): this reads the singular key ``'param'`` while
        ``get_valid_steps`` emits plural ``'params'`` — presumably callers
        pass a concrete step with one chosen param; confirm with callers.
        Returns ``(None, None)`` for unknown steps.
        """
        step_name, step_param = step.get('step'), step.get('param')
        if step_name == 'out':
            return self.edges[step_param]['destination'], 'vertex'
        if step_name == 'in':
            return self.edges[step_param]['source'], 'vertex'
        if step_name in ['properties', 'has', 'values']:
            return start_label, 'vertex'
        return None, None

    def get_vertex_creation_info(self, label: str) -> Dict:
        """Return the primary key and required (non-optional) properties."""
        if label not in self.vertices:
            return {}
        schema_info = self.vertices[label]
        required = [name for name, meta in schema_info['properties'].items() if not meta['optional']]
        return {'primary': schema_info.get('primary'), 'required': required}

    def get_edge_creation_info(self, label: str) -> Tuple[str, str]:
        """Return ``(source_label, destination_label)`` for an edge label."""
        if label in self.edges:
            return (self.edges[label]['source'], self.edges[label]['destination'])
        return (None, None)

    def get_updatable_properties(self, label: str) -> List[Dict[str, str]]:
        """Return vertex properties that may be updated (primary key excluded)."""
        if label not in self.vertices:
            return []
        schema_info = self.vertices[label]
        primary_key = schema_info.get('primary')
        return [{'name': name, 'type': meta['type']} for name, meta in schema_info['properties'].items() if name != primary_key]

    def get_instance(self, label: str) -> Dict:
        """Return a single random instance (kept for backward compatibility)."""
        instances = self.get_instances(label, count=1)
        return instances[0] if instances else {}

    def get_instances(self, label: str, count: int = None) -> List[Dict]:
        """Return randomly sampled data instances for a label.

        Args:
            label: Vertex or edge label to sample from.
            count: Number of instances to return; when ``None`` a random
                count between 2 and 5 is chosen. Capped at the number of
                rows actually available.

        Returns:
            A list of record dicts (possibly empty if no data exists).
        """
        is_edge = label in self.edges
        data_cache = self.edge_data if is_edge else self.vertex_data
        load_func = self._load_edge_data if is_edge else self._load_vertex_data

        if label not in data_cache:
            load_func(label)

        df = data_cache.get(label)
        if df is None or df.empty:
            return []

        # No explicit count: pick a small random batch of 2-5 instances.
        if count is None:
            count = random.randint(2, 5)

        # Never ask for more rows than the data actually has.
        actual_count = min(count, len(df))

        sampled_df = df.sample(actual_count)
        return sampled_df.to_dict('records')


# --- Standalone test entry point ---
if __name__ == "__main__":
    base_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(base_dir)
    schema_path = os.path.join(project_root, 'db_data', 'schema', 'movie_schema.json')
    # data_path points at the parent directory holding movie/raw_data.
    data_path = os.path.join(project_root, 'db_data')

    if not os.path.exists(schema_path) or not os.path.exists(data_path):
        print("错误: 找不到 Schema 或数据文件,请检查路径。")
    else:
        schema = Schema(schema_path, data_path)
        print("\n--- Schema 初始化成功 (已修复CSV读取逻辑) ---")

        print("\n--- 测试数据实例获取 ---")
        random_person = schema.get_instance('person')
        print(f"随机获取一个 'person' 实例: {random_person}")

        random_user = schema.get_instance('user')
        print(f"随机获取一个 'user' 实例: {random_user}")

        # Verify that the 'name' property can be read back correctly.
        if random_person and 'name' in random_person:
            print(f"成功读取到 'person' 的 name: {random_person['name']}")
        else:
            print("错误: 未能从 'person' 实例中读取到 name 属性。")
