This is an automated email from the ASF dual-hosted git repository. JackieTien97 pushed a commit to branch ty/sqlalchemy in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit ee5afb691f8aa12f74876ab63ad9b0d421d6f66c Author: JackieTien97 <[email protected]> AuthorDate: Wed Apr 22 08:11:34 2026 +0800 Rewrite SQLAlchemy dialect to support IoTDB 2.0+ table model The old SQLAlchemy dialect was built for the tree model (path-based schema). This rewrites it to support the table model with standard relational SQL, including: - Column categories (TAG, ATTRIBUTE, FIELD, TIME) via dialect-specific args - DDL generation with CREATE TABLE categories and TTL support - Table model reflection (SHOW TABLES, SHOW COLUMNS FROM) - Updated type mappings (STRING, BLOB, TIMESTAMP, DATE) - Simplified SQL compiler (table model supports standard SQL) - DBAPI layer: add sql_dialect parameter, propagate exceptions Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> --- iotdb-client/client-py/iotdb/dbapi/Connection.py | 6 + iotdb-client/client-py/iotdb/dbapi/Cursor.py | 61 ++--- .../client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py | 67 +++++ .../client-py/iotdb/sqlalchemy/IoTDBDialect.py | 154 ++++++----- .../client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py | 303 +-------------------- .../iotdb/sqlalchemy/IoTDBTypeCompiler.py | 80 +++++- .../client-py/iotdb/sqlalchemy/__init__.py | 4 + .../tests/integration/sqlalchemy/test_dialect.py | 292 ++++++++++++++------ 8 files changed, 482 insertions(+), 485 deletions(-) diff --git a/iotdb-client/client-py/iotdb/dbapi/Connection.py b/iotdb-client/client-py/iotdb/dbapi/Connection.py index aee5520e9af..caec91364a5 100644 --- a/iotdb-client/client-py/iotdb/dbapi/Connection.py +++ b/iotdb-client/client-py/iotdb/dbapi/Connection.py @@ -37,9 +37,15 @@ class Connection(object): zone_id=Session.DEFAULT_ZONE_ID, enable_rpc_compression=False, sqlalchemy_mode=False, + sql_dialect=None, + database=None, ): self.__session = Session(host, port, username, password, fetch_size, zone_id) self.__sqlalchemy_mode = sqlalchemy_mode + if sql_dialect: + self.__session.sql_dialect = sql_dialect + if database: + self.__session.database = database self.__is_close = True try: self.__session.open(enable_rpc_compression) diff --git a/iotdb-client/client-py/iotdb/dbapi/Cursor.py b/iotdb-client/client-py/iotdb/dbapi/Cursor.py index a1d6e2caaba..018ade6e99a 100644 --- a/iotdb-client/client-py/iotdb/dbapi/Cursor.py +++ b/iotdb-client/client-py/iotdb/dbapi/Cursor.py @@ -129,41 +129,32 @@ class Cursor(object): sql_seqs.append(seq) sql = "\n".join(sql_seqs) - try: - data_set = self.__session.execute_statement(sql) - col_names = None - col_types = None - rows = [] - - if data_set: - data = data_set.todf() - - if self.__sqlalchemy_mode and time_index: - time_column = data.columns[0] - time_column_value = data.Time - del data[time_column] - for i in range(len(time_index)): - data.insert(time_index[i], time_names[i], time_column_value) - - col_names = data.columns.tolist() - col_types = data_set.get_column_types() - rows = data.values.tolist() - data_set.close_operation_handle() - - self.__result = { - "col_names": col_names, - "col_types": col_types, - "rows": rows, - "row_count": len(rows), - } - except Exception: - logger.error("failed to execute statement:{}".format(sql)) - self.__result = { - "col_names": None, - "col_types": None, - "rows": [], - "row_count": -1, - } + data_set = self.__session.execute_statement(sql) + col_names = None + col_types = None + rows = [] + + if data_set: + data = data_set.todf() + + if self.__sqlalchemy_mode and time_index: + time_column = data.columns[0] + time_column_value = data.Time + del data[time_column] + for i in range(len(time_index)): + data.insert(time_index[i], time_names[i], time_column_value) + + col_names = data.columns.tolist() + col_types = data_set.get_column_types() + rows = data.values.tolist() + data_set.close_operation_handle() + + self.__result = { + "col_names": col_names, + "col_types": col_types, + "rows": rows, + "row_count": len(rows), + } self.__rows = iter(self.__result["rows"]) def executemany(self, operation, seq_of_parameters=None): diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py new file mode 100644 index 00000000000..0d8f377cb1a --- /dev/null +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from sqlalchemy.sql.compiler import DDLCompiler + + +class IoTDBDDLCompiler(DDLCompiler): + def visit_create_column(self, create, first_pk=False, **kw): + column = create.element + + if column.system: + return None + + category = column.dialect_options["iotdb"].get("category") + + if category and category.upper() == "TIME": + colspec = self.preparer.format_column(column) + " TIME" + return colspec + + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + ) + + if category: + colspec += " " + category.upper() + + return colspec + + def post_create_table(self, table): + ttl = table.dialect_options["iotdb"].get("ttl") + if ttl is not None: + return " WITH (TTL=%d)" % int(ttl) + return "" + + def create_table_constraints(self, table, **kw): + return "" + + def visit_primary_key_constraint(self, constraint, **kw): + return None + + def visit_foreign_key_constraint(self, constraint, **kw): + return None + + def visit_unique_constraint(self, constraint, **kw): + return None + + def visit_check_constraint(self, constraint, **kw): + return None diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py index 912e23e9f7a..44f9e860fea 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py @@ -16,58 +16,64 @@ # under the License. # -from sqlalchemy import types, util +from sqlalchemy import schema as sa_schema, types from sqlalchemy.engine import default from sqlalchemy.sql import text -from sqlalchemy.sql.sqltypes import String from iotdb import dbapi +from .IoTDBDDLCompiler import IoTDBDDLCompiler from .IoTDBIdentifierPreparer import IoTDBIdentifierPreparer from .IoTDBSQLCompiler import IoTDBSQLCompiler from .IoTDBTypeCompiler import IoTDBTypeCompiler -TYPES_MAP = { +IOTDB_CATEGORY_TIME = "TIME" +IOTDB_CATEGORY_TAG = "TAG" +IOTDB_CATEGORY_ATTRIBUTE = "ATTRIBUTE" +IOTDB_CATEGORY_FIELD = "FIELD" + +ischema_names = { "BOOLEAN": types.Boolean, "INT32": types.Integer, "INT64": types.BigInteger, "FLOAT": types.Float, "DOUBLE": types.Float, + "STRING": types.String, "TEXT": types.Text, - "LONG": types.BigInteger, + "BLOB": types.LargeBinary, + "TIMESTAMP": types.DateTime, + "DATE": types.Date, } class IoTDBDialect(default.DefaultDialect): name = "iotdb" - driver = "iotdb-python" + driver = "iotdb" + statement_compiler = IoTDBSQLCompiler - type_compiler = IoTDBTypeCompiler + ddl_compiler = IoTDBDDLCompiler + type_compiler_cls = IoTDBTypeCompiler preparer = IoTDBIdentifierPreparer - convert_unicode = True - supports_unicode_statements = True - supports_unicode_binds = True - supports_simple_order_by_label = False + supports_alter = True supports_schemas = True - supports_right_nested_joins = False - description_encoding = None - - if hasattr(String, "RETURNS_UNICODE"): - returns_unicode_strings = String.RETURNS_UNICODE - else: - - def _check_unicode_returns(self, connection, additional_tests=None): - return True - - _check_unicode_returns = _check_unicode_returns - - def create_connect_args(self, url): - # inherits the docstring from interfaces.Dialect.create_connect_args - opts = url.translate_connect_args() - opts.update(url.query) - opts.update({"sqlalchemy_mode": True}) - return [[], opts] + supports_sequences = False + supports_native_boolean = True + supports_native_enum = False + supports_statement_cache = True + insert_returning = False + update_returning = False + delete_returning = False + supports_default_values = False + supports_empty_insert = False + postfetch_lastrowid = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + construct_arguments = [ + (sa_schema.Column, {"category": None}), + (sa_schema.Table, {"ttl": None}), + ] @classmethod def import_dbapi(cls): @@ -77,8 +83,23 @@ class IoTDBDialect(default.DefaultDialect): def dbapi(cls): return dbapi - def has_schema(self, connection, schema): - return schema in self.get_schema_names(connection) + def create_connect_args(self, url): + opts = url.translate_connect_args() + opts.update(url.query) + opts["sql_dialect"] = "table" + return ([], opts) + + def initialize(self, connection): + pass + + def _get_server_version_info(self, connection): + return None + + def _get_default_schema_name(self, connection): + return None + + def has_schema(self, connection, schema_name, **kw): + return schema_name in self.get_schema_names(connection) def has_table(self, connection, table_name, schema=None, **kw): return table_name in self.get_table_names(connection, schema=schema) @@ -88,22 +109,41 @@ class IoTDBDialect(default.DefaultDialect): return [row[0] for row in cursor.fetchall()] def get_table_names(self, connection, schema=None, **kw): - cursor = connection.execute( - text("SHOW DEVICES %s.**" % (schema or self.default_schema_name)) - ) - return [row[0].replace(schema + ".", "", 1) for row in cursor.fetchall()] + if schema: + connection.execute(text("USE %s" % schema)) + cursor = connection.execute(text("SHOW TABLES")) + return [row[0] for row in cursor.fetchall()] def get_columns(self, connection, table_name, schema=None, **kw): + if schema: + connection.execute(text("USE %s" % schema)) cursor = connection.execute( - text("SHOW TIMESERIES %s.%s.*" % (schema, table_name)) + text("SHOW COLUMNS FROM %s" % table_name) ) - columns = [self._general_time_column_info()] + columns = [] for row in cursor.fetchall(): - columns.append(self._create_column_info(row, schema, table_name)) + col_name = row[0] + col_type_str = row[1] + col_category = row[2] if len(row) > 2 else None + + sa_type = ischema_names.get(col_type_str.upper(), types.UserDefinedType) + + col_info = { + "name": col_name, + "type": sa_type() if isinstance(sa_type, type) else sa_type, + "nullable": True, + "default": None, + } + + if col_category: + col_info["iotdb_category"] = col_category.upper() + + columns.append(col_info) + return columns def get_pk_constraint(self, connection, table_name, schema=None, **kw): - pass + return {"constrained_columns": [], "name": None} def get_foreign_keys(self, connection, table_name, schema=None, **kw): return [] @@ -111,33 +151,11 @@ class IoTDBDialect(default.DefaultDialect): def get_indexes(self, connection, table_name, schema=None, **kw): return [] - @util.memoized_property - def _dialect_specific_select_one(self): - # IoTDB does not support select 1 - # so replace the statement with "show version" - return "SHOW VERSION" - - def _general_time_column_info(self): - """ - Treat Time as a column - """ - return { - "name": "Time", - "type": self._resolve_type("LONG"), - "nullable": False, - "default": None, - } - - def _create_column_info(self, row, schema, table_name): - """ - Generate description information for each column - """ - return { - "name": row[0].replace(schema + "." + table_name + ".", "", 1), - "type": self._resolve_type(row[3]), - "nullable": True, - "default": None, - } - - def _resolve_type(self, type_): - return TYPES_MAP.get(type_, types.UserDefinedType) + def get_view_names(self, connection, schema=None, **kw): + return [] + + def do_commit(self, dbapi_connection): + pass + + def do_rollback(self, dbapi_connection): + pass diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py index 008a314e683..08482ac41b9 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py @@ -17,308 +17,7 @@ # from sqlalchemy.sql.compiler import SQLCompiler -from sqlalchemy.sql.compiler import OPERATORS -from sqlalchemy.sql import operators class IoTDBSQLCompiler(SQLCompiler): - def order_by_clause(self, select, **kw): - """allow dialects to customize how ORDER BY is rendered.""" - - order_by = select._order_by_clause._compiler_dispatch(self, **kw) - if "Time" in order_by: - return " ORDER BY " + order_by.replace('"', "") - else: - return "" - - def group_by_clause(self, select, **kw): - """allow dialects to customize how GROUP BY is rendered.""" - return "" - - def visit_select( - self, - select, - asfrom=False, - parens=True, - fromhints=None, - compound_index=0, - nested_join_translation=False, - select_wraps_for=None, - lateral=False, - **kwargs, - ): - """ - Override this method to solve two problems - 1. IoTDB does not support querying Time as a measurement name (e.g. select Time from root.storagegroup.device) - 2. IoTDB does not support path.measurement format to determine a column (e.g. select root.storagegroup.device.temperature from root.storagegroup.device) - """ - assert select_wraps_for is None, ( - "SQLAlchemy 1.4 requires use of " - "the translate_select_structure hook for structural " - "translations of SELECT objects" - ) - - # initial setup of SELECT. the compile_state_factory may now - # be creating a totally different SELECT from the one that was - # passed in. for ORM use this will convert from an ORM-state - # SELECT to a regular "Core" SELECT. other composed operations - # such as computation of joins will be performed. - - kwargs["within_columns_clause"] = False - - compile_state = select_stmt._compile_state_factory(select_stmt, self, **kwargs) - select_stmt = compile_state.statement - - toplevel = not self.stack - - if toplevel and not self.compile_state: - self.compile_state = compile_state - - is_embedded_select = compound_index is not None or insert_into - - # translate step for Oracle, SQL Server which often need to - # restructure the SELECT to allow for LIMIT/OFFSET and possibly - # other conditions - if self.translate_select_structure: - new_select_stmt = self.translate_select_structure( - select_stmt, asfrom=asfrom, **kwargs - ) - - # if SELECT was restructured, maintain a link to the originals - # and assemble a new compile state - if new_select_stmt is not select_stmt: - compile_state_wraps_for = compile_state - select_wraps_for = select_stmt - select_stmt = new_select_stmt - - compile_state = select_stmt._compile_state_factory( - select_stmt, self, **kwargs - ) - select_stmt = compile_state.statement - - entry = self._default_stack_entry if toplevel else self.stack[-1] - - populate_result_map = need_column_expressions = ( - toplevel - or entry.get("need_result_map_for_compound", False) - or entry.get("need_result_map_for_nested", False) - ) - - # indicates there is a CompoundSelect in play and we are not the - # first select - if compound_index: - populate_result_map = False - - # this was first proposed as part of #3372; however, it is not - # reached in current tests and could possibly be an assertion - # instead. - if not populate_result_map and "add_to_result_map" in kwargs: - del kwargs["add_to_result_map"] - - froms = self._setup_select_stack( - select_stmt, compile_state, entry, asfrom, lateral, compound_index - ) - - column_clause_args = kwargs.copy() - column_clause_args.update( - {"within_label_clause": False, "within_columns_clause": False} - ) - - text = "SELECT " # we're off to a good start ! - - if select_stmt._hints: - hint_text, byfrom = self._setup_select_hints(select_stmt) - if hint_text: - text += hint_text + " " - else: - byfrom = None - - if select_stmt._independent_ctes: - for cte in select_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) - - if select_stmt._prefixes: - text += self._generate_prefixes( - select_stmt, select_stmt._prefixes, **kwargs - ) - - text += self.get_select_precolumns(select_stmt, **kwargs) - # the actual list of columns to print in the SELECT column list. - inner_columns = [ - c - for c in [ - self._label_select_column( - select_stmt, - column, - populate_result_map, - asfrom, - column_clause_args, - name=name, - proxy_name=proxy_name, - fallback_label_name=fallback_label_name, - column_is_repeated=repeated, - need_column_expressions=need_column_expressions, - ) - for ( - name, - proxy_name, - fallback_label_name, - column, - repeated, - ) in compile_state.columns_plus_names - ] - if c is not None - ] - - if populate_result_map and select_wraps_for is not None: - # if this select was generated from translate_select, - # rewrite the targeted columns in the result map - - translate = dict( - zip( - [ - name - for ( - key, - proxy_name, - fallback_label_name, - name, - repeated, - ) in compile_state.columns_plus_names - ], - [ - name - for ( - key, - proxy_name, - fallback_label_name, - name, - repeated, - ) in compile_state_wraps_for.columns_plus_names - ], - ) - ) - - self._result_columns = [ - (key, name, tuple(translate.get(o, o) for o in obj), type_) - for key, name, obj, type_ in self._result_columns - ] - - # change the superset aggregate function name into iotdb aggregate function name - # by matching the head of aggregate function name and replace it. - for i in range(len(inner_columns)): - if inner_columns[i].startswith("max("): - inner_columns[i] = inner_columns[i].replace("max(", "max_value(") - if inner_columns[i].startswith("min("): - inner_columns[i] = inner_columns[i].replace("min(", "min_value(") - if inner_columns[i].startswith("count(DISTINCT"): - inner_columns[i] = inner_columns[i].replace("count(DISTINCT", "count(") - - # IoTDB does not allow to query Time as column, - # need to filter out Time and pass Time and Time's alias to DBAPI separately - # to achieve the query of Time by encoding. - time_column_index = [] - time_column_names = [] - for i in range(len(inner_columns)): - column_strs = ( - inner_columns[i].replace(self.preparer.initial_quote, "").split() - ) - if "Time" in column_strs: - time_column_index.append(str(i)) - time_column_names.append( - column_strs[2] - if OPERATORS[operators.as_] in column_strs - else column_strs[0] - ) - # delete Time column - inner_columns = list( - filter( - lambda x: "Time" - not in x.replace(self.preparer.initial_quote, "").split(), - inner_columns, - ) - ) - - if inner_columns and time_column_index: - inner_columns[-1] = ( - inner_columns[-1] - + " \n FROM Time Index " - + " ".join(time_column_index) - + " \n FROM Time Name " - + " ".join(time_column_names) - ) - - text = self._compose_select_body( - text, - select_stmt, - compile_state, - inner_columns, - froms, - byfrom, - toplevel, - kwargs, - ) - - if select_stmt._statement_hints: - per_dialect = [ - ht - for (dialect_name, ht) in select_stmt._statement_hints - if dialect_name in ("*", self.dialect.name) - ] - if per_dialect: - text += " " + self.get_statement_hint_text(per_dialect) - - # In compound query, CTEs are shared at the compound level - if self.ctes and (not is_embedded_select or toplevel): - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kwargs.get("visiting_cte"), - ) - + text - ) - - if select_stmt._suffixes: - text += " " + self._generate_prefixes( - select_stmt, select_stmt._suffixes, **kwargs - ) - - self.stack.pop(-1) - return text - - def visit_table( - self, - table, - asfrom=False, - iscrud=False, - ashint=False, - fromhints=None, - use_schema=True, - **kwargs, - ): - """ - IoTDB's table does not support quotation marks (e.g. select ** from `root.`) - need to override this method - """ - if asfrom or ashint: - effective_schema = self.preparer.schema_for_object(table) - - if use_schema and effective_schema: - ret = effective_schema + "." + table.name - else: - ret = table.name - if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, fromhints[table], iscrud) - return ret - else: - return "" - - def visit_column( - self, column, add_to_result_map=None, include_table=True, **kwargs - ): - """ - IoTDB's where statement does not support "table".column format(e.g. "table".column > 1) - need to override this method to return the name of column directly - """ - return column.name + pass diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py index 4cfd2480bd4..1fb80185bd5 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py @@ -20,11 +20,17 @@ from sqlalchemy.sql.compiler import GenericTypeCompiler class IoTDBTypeCompiler(GenericTypeCompiler): + def visit_BOOLEAN(self, type_, **kw): + return "BOOLEAN" + def visit_FLOAT(self, type_, **kw): return "FLOAT" + def visit_REAL(self, type_, **kw): + return "FLOAT" + def visit_NUMERIC(self, type_, **kw): - return "INT64" + return "DOUBLE" def visit_DECIMAL(self, type_, **kw): return "DOUBLE" @@ -36,10 +42,76 @@ class IoTDBTypeCompiler(GenericTypeCompiler): return "INT32" def visit_BIGINT(self, type_, **kw): - return "LONG" + return "INT64" def visit_TIMESTAMP(self, type_, **kw): - return "LONG" + return "TIMESTAMP" + + def visit_DATETIME(self, type_, **kw): + return "TIMESTAMP" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TEXT(self, type_, **kw): + return "STRING" + + def visit_VARCHAR(self, type_, **kw): + return "STRING" + + def visit_NVARCHAR(self, type_, **kw): + return "STRING" + + def visit_CHAR(self, type_, **kw): + return "STRING" + + def visit_BLOB(self, type_, **kw): + return "BLOB" + + def visit_BINARY(self, type_, **kw): + return "BLOB" + + def visit_VARBINARY(self, type_, **kw): + return "BLOB" + + def visit_LARGE_BINARY(self, type_, **kw): + return "BLOB" + + def visit_large_binary(self, type_, **kw): + return "BLOB" + + def visit_boolean(self, type_, **kw): + return "BOOLEAN" + + def visit_string(self, type_, **kw): + return "STRING" + + def visit_unicode(self, type_, **kw): + return "STRING" def visit_text(self, type_, **kw): - return "TEXT" + return "STRING" + + def visit_unicode_text(self, type_, **kw): + return "STRING" + + def visit_float(self, type_, **kw): + return "FLOAT" + + def visit_numeric(self, type_, **kw): + return "DOUBLE" + + def visit_integer(self, type_, **kw): + return "INT32" + + def visit_big_integer(self, type_, **kw): + return "INT64" + + def visit_timestamp(self, type_, **kw): + return "TIMESTAMP" + + def visit_datetime(self, type_, **kw): + return "TIMESTAMP" + + def visit_date(self, type_, **kw): + return "DATE" diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py b/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py index 2a1e720805f..7a6b7223a6f 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py @@ -15,3 +15,7 @@ # specific language governing permissions and limitations # under the License. # + +from .IoTDBDialect import IoTDBDialect + +__all__ = ["IoTDBDialect"] diff --git a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py index 9ed0c808917..c5bb2fa3029 100644 --- a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py +++ b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py @@ -16,96 +16,236 @@ # under the License. # -import operator - -from sqlalchemy import create_engine, inspect +from sqlalchemy import ( + create_engine, + inspect, + Column, + Float, + Integer, + BigInteger, + String, + Boolean, + Table, + MetaData, +) from sqlalchemy.dialects import registry -from sqlalchemy.orm import Session from sqlalchemy.sql import text from tests.integration.iotdb_container import IoTDBContainer from urllib.parse import quote_plus as urlquote -final_flag = True -failed_count = 0 - - -def test_fail(): - global failed_count - global final_flag - final_flag = False - failed_count += 1 +TEST_DB = "test_sqlalchemy" -def print_message(message): - print("*********") - print(message) - print("*********") - assert False - -def test_dialect(): +def test_table_model_dialect(): with IoTDBContainer("iotdb:dev") as db: db: IoTDBContainer - password = urlquote("root") + password = urlquote("TimechoDB@2021") host = db.get_container_host_ip() port = db.get_exposed_port(6667) - url = f"iotdb://root:{password}@{host}:{port}" + url = f"iotdb://root:{password}@{host}:{port}/{TEST_DB}" registry.register("iotdb", "iotdb.sqlalchemy.IoTDBDialect", "IoTDBDialect") - eng = create_engine(url) - - with Session(eng) as session: - session.execute(text("create database root.cursor")) - session.execute(text("create database root.cursor_s1")) - session.execute( - text( - "create timeseries root.cursor.device1.temperature with datatype=FLOAT,encoding=RLE" - ) + engine = create_engine(url) + + with engine.connect() as conn: + conn.execute(text("CREATE DATABASE %s" % TEST_DB)) + + _test_ddl(engine) + _test_dml(engine) + _test_reflection(engine) + _test_time_column(engine) + + with engine.connect() as conn: + conn.execute(text("DROP DATABASE %s" % TEST_DB)) + + engine.dispose() + print("All table model dialect tests passed!") + + +def _test_ddl(engine): + metadata = MetaData() + + sensors = Table( + "sensors", + metadata, + Column("region", String, iotdb_category="TAG"), + Column("device_id", String, iotdb_category="TAG"), + Column("model", String, iotdb_category="ATTRIBUTE"), + Column("temperature", Float, iotdb_category="FIELD"), + Column("humidity", Float, iotdb_category="FIELD"), + Column("status", Boolean, iotdb_category="FIELD"), + schema=TEST_DB, + iotdb_ttl=3600000, + ) + + metadata.create_all(engine) + + insp = inspect(engine) + table_names = insp.get_table_names(schema=TEST_DB) + assert "sensors" in table_names, ( + "CREATE TABLE failed: 'sensors' not in %s" % table_names + ) + + sensors.drop(engine) + table_names = insp.get_table_names(schema=TEST_DB) + assert "sensors" not in table_names, ( + "DROP TABLE failed: 'sensors' still in %s" % table_names + ) + + print(" DDL tests passed") + + +def _test_dml(engine): + metadata = MetaData() + sensors = Table( + "sensors_dml", + metadata, + Column("region", String, iotdb_category="TAG"), + Column("device_id", String, iotdb_category="TAG"), + Column("temperature", Float, iotdb_category="FIELD"), + Column("humidity", Float, iotdb_category="FIELD"), + schema=TEST_DB, + ) + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + sensors.insert().values( + region="asia", + device_id="d001", + temperature=25.5, + humidity=60.0, ) - session.execute( - text( - "create timeseries root.cursor.device1.status with datatype=FLOAT,encoding=RLE" - ) + ) + conn.execute( + sensors.insert().values( + region="europe", + device_id="d002", + temperature=18.3, + humidity=75.0, ) - session.execute( - text( - "create timeseries root.cursor.device2.temperature with datatype=FLOAT,encoding=RLE" - ) + ) + + result = conn.execute(sensors.select()).fetchall() + assert len(result) == 2, "INSERT/SELECT failed: expected 2 rows, got %d" % len( + result + ) + + result = conn.execute( + sensors.select().where(sensors.c.region == "asia") + ).fetchall() + assert len(result) == 1, ( + "SELECT WHERE failed: expected 1 row, got %d" % len(result) + ) + + result = conn.execute( + sensors.select().order_by(sensors.c.temperature).limit(1) + ).fetchall() + assert len(result) == 1, "LIMIT failed: expected 1 row, got %d" % len(result) + + conn.execute( + sensors.delete().where(sensors.c.device_id == "d002") + ) + result = conn.execute(sensors.select()).fetchall() + assert len(result) == 1, ( + "DELETE failed: expected 1 row after delete, got %d" % len(result) + ) + + sensors.drop(engine) + print(" DML tests passed") + + +def _test_reflection(engine): + with engine.connect() as conn: + conn.execute(text("USE %s" % TEST_DB)) + conn.execute( + text( + "CREATE TABLE reflect_test (" + "region STRING TAG, " + "device STRING TAG, " + "model STRING ATTRIBUTE, " + "temperature FLOAT FIELD, " + "humidity DOUBLE FIELD" + ")" ) - - insp = inspect(eng) - # test get_schema_names - schema_names = insp.get_schema_names() - if not operator.ge( - schema_names, ["root.__audit", "root.cursor", "root.cursor_s1"] - ): - test_fail() - print_message("Actual result " + str(schema_names)) - print_message("test get_schema_names failed!") - # test get_table_names - table_names = insp.get_table_names("root.cursor") - if not operator.eq(table_names, ["device1", "device2"]): - test_fail() - print_message("Actual result " + str(table_names)) - print_message("test get_table_names failed!") - # test get_columns - columns = insp.get_columns(table_name="device1", schema="root.cursor") - if len(columns) != 3: - test_fail() - print_message("Actual result " + str(columns)) - print_message("test get_columns failed!") - - with Session(eng) as session: - session.execute(text("delete database root.cursor")) - session.execute(text("delete database root.cursor_s1")) - - # close engine - eng.dispose() - - -if final_flag: - print("All executions done!!") -else: - print("Some test failed, please have a check") - print("failed count: ", failed_count) - exit(1) + ) + + insp = inspect(engine) + + schemas = insp.get_schema_names() + assert TEST_DB in schemas, "%s not in schemas: %s" % (TEST_DB, schemas) + + tables = insp.get_table_names(schema=TEST_DB) + assert "reflect_test" in tables, "reflect_test not in tables: %s" % tables + + columns = insp.get_columns(table_name="reflect_test", schema=TEST_DB) + col_names = [c["name"] for c in columns] + assert "region" in col_names, "region not in columns: %s" % col_names + assert "temperature" in col_names, "temperature not in columns: %s" % col_names + + pk = insp.get_pk_constraint(table_name="reflect_test", schema=TEST_DB) + assert pk["constrained_columns"] == [], "Expected empty PK constraint" + + fks = insp.get_foreign_keys(table_name="reflect_test", schema=TEST_DB) + assert fks == [], "Expected empty FK list" + + indexes = insp.get_indexes(table_name="reflect_test", schema=TEST_DB) + assert indexes == [], "Expected empty index list" + + with engine.connect() as conn: + conn.execute(text("USE %s" % TEST_DB)) + conn.execute(text("DROP TABLE reflect_test")) + + print(" Reflection tests passed") + + +def _test_time_column(engine): + metadata = MetaData() + + table_implicit = Table( + "time_implicit", + metadata, + Column("device", String, iotdb_category="TAG"), + Column("value", Float, iotdb_category="FIELD"), + schema=TEST_DB, + ) + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + table_implicit.insert().values(device="d001", value=42.0) + ) + result = conn.execute(table_implicit.select()).fetchall() + assert len(result) == 1, ( + "Implicit TIME insert failed: expected 1 row, got %d" % len(result) + ) + + table_implicit.drop(engine) + + metadata2 = MetaData() + table_explicit = Table( + "time_explicit", + metadata2, + Column("ts", BigInteger, iotdb_category="TIME"), + Column("device", String, iotdb_category="TAG"), + Column("value", Float, iotdb_category="FIELD"), + schema=TEST_DB, + ) + metadata2.create_all(engine) + + with engine.connect() as conn: + conn.execute( + table_explicit.insert().values(ts=1000000, device="d001", value=42.0) + ) + result = conn.execute(table_explicit.select()).fetchall() + assert len(result) == 1, ( + "Explicit TIME insert failed: expected 1 row, got %d" % len(result) + ) + + table_explicit.drop(engine) + print(" TIME column tests passed") + + +if __name__ == "__main__": + test_table_model_dialect()
