This is an automated email from the ASF dual-hosted git repository.

JackieTien97 pushed a commit to branch ty/sqlalchemy
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit ee5afb691f8aa12f74876ab63ad9b0d421d6f66c
Author: JackieTien97 <[email protected]>
AuthorDate: Wed Apr 22 08:11:34 2026 +0800

    Rewrite SQLAlchemy dialect to support IoTDB 2.0+ table model
    
    The old SQLAlchemy dialect was built for the tree model (path-based
    schema). This rewrites it to support the table model with standard
    relational SQL, including:
    - Column categories (TAG, ATTRIBUTE, FIELD, TIME) via dialect-specific args
    - DDL generation with CREATE TABLE categories and TTL support
    - Table model reflection (SHOW TABLES, SHOW COLUMNS FROM)
    - Updated type mappings (STRING, BLOB, TIMESTAMP, DATE)
    - Simplified SQL compiler (table model supports standard SQL)
    - DBAPI layer: add sql_dialect and database connection parameters;
      propagate statement errors instead of logging and swallowing them
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
---
 iotdb-client/client-py/iotdb/dbapi/Connection.py   |   6 +
 iotdb-client/client-py/iotdb/dbapi/Cursor.py       |  61 ++---
 .../client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py |  67 +++++
 .../client-py/iotdb/sqlalchemy/IoTDBDialect.py     | 154 ++++++-----
 .../client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py | 303 +--------------------
 .../iotdb/sqlalchemy/IoTDBTypeCompiler.py          |  80 +++++-
 .../client-py/iotdb/sqlalchemy/__init__.py         |   4 +
 .../tests/integration/sqlalchemy/test_dialect.py   | 292 ++++++++++++++------
 8 files changed, 482 insertions(+), 485 deletions(-)

diff --git a/iotdb-client/client-py/iotdb/dbapi/Connection.py 
b/iotdb-client/client-py/iotdb/dbapi/Connection.py
index aee5520e9af..caec91364a5 100644
--- a/iotdb-client/client-py/iotdb/dbapi/Connection.py
+++ b/iotdb-client/client-py/iotdb/dbapi/Connection.py
@@ -37,9 +37,15 @@ class Connection(object):
         zone_id=Session.DEFAULT_ZONE_ID,
         enable_rpc_compression=False,
         sqlalchemy_mode=False,
+        sql_dialect=None,  # e.g. "table"; falsy keeps the Session's setting
+        database=None,  # session database; falsy keeps the Session's setting
     ):
         self.__session = Session(host, port, username, password, fetch_size, 
zone_id)
         self.__sqlalchemy_mode = sqlalchemy_mode
+        if sql_dialect:  # applied before open() below; skipped when falsy
+            self.__session.sql_dialect = sql_dialect
+        if database:  # applied before open() below; skipped when falsy
+            self.__session.database = database
         self.__is_close = True
         try:
             self.__session.open(enable_rpc_compression)
diff --git a/iotdb-client/client-py/iotdb/dbapi/Cursor.py 
b/iotdb-client/client-py/iotdb/dbapi/Cursor.py
index a1d6e2caaba..018ade6e99a 100644
--- a/iotdb-client/client-py/iotdb/dbapi/Cursor.py
+++ b/iotdb-client/client-py/iotdb/dbapi/Cursor.py
@@ -129,41 +129,32 @@ class Cursor(object):
                     sql_seqs.append(seq)
             sql = "\n".join(sql_seqs)
 
-        try:
-            data_set = self.__session.execute_statement(sql)
-            col_names = None
-            col_types = None
-            rows = []
-
-            if data_set:
-                data = data_set.todf()
-
-                if self.__sqlalchemy_mode and time_index:
-                    time_column = data.columns[0]
-                    time_column_value = data.Time
-                    del data[time_column]
-                    for i in range(len(time_index)):
-                        data.insert(time_index[i], time_names[i], 
time_column_value)
-
-                col_names = data.columns.tolist()
-                col_types = data_set.get_column_types()
-                rows = data.values.tolist()
-                data_set.close_operation_handle()
-
-            self.__result = {
-                "col_names": col_names,
-                "col_types": col_types,
-                "rows": rows,
-                "row_count": len(rows),
-            }
-        except Exception:
-            logger.error("failed to execute statement:{}".format(sql))
-            self.__result = {
-                "col_names": None,
-                "col_types": None,
-                "rows": [],
-                "row_count": -1,
-            }
+        data_set = self.__session.execute_statement(sql)
+        col_names = None
+        col_types = None
+        rows = []
+
+        if data_set:
+            try:  # always release the operation handle, even when this raises
+                data = data_set.todf()
+                if self.__sqlalchemy_mode and time_index:
+                    time_column = data.columns[0]
+                    time_column_value = data.Time
+                    del data[time_column]
+                    for index, name in zip(time_index, time_names):
+                        data.insert(index, name, time_column_value)
+                col_names = data.columns.tolist()
+                col_types = data_set.get_column_types()
+                rows = data.values.tolist()
+            finally:
+                data_set.close_operation_handle()
+
+        self.__result = {
+            "col_names": col_names,
+            "col_types": col_types,
+            "rows": rows,
+            "row_count": len(rows),
+        }
         self.__rows = iter(self.__result["rows"])
 
     def executemany(self, operation, seq_of_parameters=None):
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py 
b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py
new file mode 100644
index 00000000000..0d8f377cb1a
--- /dev/null
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDDLCompiler.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from sqlalchemy import exc
+from sqlalchemy.sql.compiler import DDLCompiler
+
+# Column categories accepted for the ``iotdb_category`` column option.
+_IOTDB_COLUMN_CATEGORIES = ("TIME", "TAG", "ATTRIBUTE", "FIELD")
+
+
+class IoTDBDDLCompiler(DDLCompiler):
+    """DDL compiler for the IoTDB table model.
+
+    Renders the ``iotdb_category`` column option and the ``iotdb_ttl``
+    table option, and emits no constraint DDL.
+    """
+
+    def visit_create_column(self, create, first_pk=False, **kw):
+        """Render one column definition of a CREATE TABLE statement.
+
+        A ``TIME`` column renders as ``<name> TIME`` without a data type;
+        any other column renders as ``<name> <type>``, followed by its
+        category when one is given. System columns render nothing.
+
+        Raises:
+            sqlalchemy.exc.CompileError: if the category is not one of
+                TIME, TAG, ATTRIBUTE or FIELD (case-insensitive).
+        """
+        column = create.element
+
+        if column.system:
+            return None
+
+        category = column.dialect_options["iotdb"].get("category")
+        if category:
+            category = category.upper()
+            # Fail at compile time instead of sending invalid DDL to the server.
+            if category not in _IOTDB_COLUMN_CATEGORIES:
+                raise exc.CompileError(
+                    "Invalid iotdb_category %r for column %r; expected one of %s"
+                    % (category, column.name, ", ".join(_IOTDB_COLUMN_CATEGORIES))
+                )
+
+        colspec = self.preparer.format_column(column)
+        if category == "TIME":
+            return colspec + " TIME"
+
+        colspec += " " + self.dialect.type_compiler_instance.process(
+            column.type, type_expression=column
+        )
+        if category:
+            colspec += " " + category
+        return colspec
+
+    def post_create_table(self, table):
+        """Return the ``WITH (TTL=<int>)`` suffix when ``iotdb_ttl`` is set."""
+        ttl = table.dialect_options["iotdb"].get("ttl")
+        if ttl is not None:
+            return " WITH (TTL=%d)" % int(ttl)
+        return ""
+
+    # No constraint DDL is emitted: CREATE TABLE gets no constraint clause
+    # and the individual constraint visitors render nothing.
+    def create_table_constraints(self, table, **kw):
+        return ""
+
+    def visit_primary_key_constraint(self, constraint, **kw):
+        return None
+
+    def visit_foreign_key_constraint(self, constraint, **kw):
+        return None
+
+    def visit_unique_constraint(self, constraint, **kw):
+        return None
+
+    def visit_check_constraint(self, constraint, **kw):
+        return None
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py 
b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py
index 912e23e9f7a..44f9e860fea 100644
--- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py
@@ -16,58 +16,64 @@
 # under the License.
 #
 
-from sqlalchemy import types, util
+from sqlalchemy import schema as sa_schema, types
 from sqlalchemy.engine import default
 from sqlalchemy.sql import text
-from sqlalchemy.sql.sqltypes import String
 
 from iotdb import dbapi
 
+from .IoTDBDDLCompiler import IoTDBDDLCompiler
 from .IoTDBIdentifierPreparer import IoTDBIdentifierPreparer
 from .IoTDBSQLCompiler import IoTDBSQLCompiler
 from .IoTDBTypeCompiler import IoTDBTypeCompiler
 
-TYPES_MAP = {
+IOTDB_CATEGORY_TIME = "TIME"
+IOTDB_CATEGORY_TAG = "TAG"
+IOTDB_CATEGORY_ATTRIBUTE = "ATTRIBUTE"
+IOTDB_CATEGORY_FIELD = "FIELD"
+
+ischema_names = {
     "BOOLEAN": types.Boolean,
     "INT32": types.Integer,
     "INT64": types.BigInteger,
     "FLOAT": types.Float,
-    "DOUBLE": types.Float,
+    "DOUBLE": types.Double,  # Float would re-render reflected DOUBLE as FLOAT
+    "STRING": types.String,
     "TEXT": types.Text,
-    "LONG": types.BigInteger,
+    "BLOB": types.LargeBinary,
+    "TIMESTAMP": types.DateTime,
+    "DATE": types.Date,
 }
 
 
 class IoTDBDialect(default.DefaultDialect):
     name = "iotdb"
-    driver = "iotdb-python"
+    driver = "iotdb"
+
     statement_compiler = IoTDBSQLCompiler
-    type_compiler = IoTDBTypeCompiler
+    ddl_compiler = IoTDBDDLCompiler
+    type_compiler_cls = IoTDBTypeCompiler  # SQLAlchemy 2.0 attribute name
     preparer = IoTDBIdentifierPreparer
-    convert_unicode = True
 
-    supports_unicode_statements = True
-    supports_unicode_binds = True
-    supports_simple_order_by_label = False
+    supports_alter = True
     supports_schemas = True
-    supports_right_nested_joins = False
-    description_encoding = None
-
-    if hasattr(String, "RETURNS_UNICODE"):
-        returns_unicode_strings = String.RETURNS_UNICODE
-    else:
-
-        def _check_unicode_returns(self, connection, additional_tests=None):
-            return True
-
-        _check_unicode_returns = _check_unicode_returns
-
-    def create_connect_args(self, url):
-        # inherits the docstring from interfaces.Dialect.create_connect_args
-        opts = url.translate_connect_args()
-        opts.update(url.query)
-        opts.update({"sqlalchemy_mode": True})
-        return [[], opts]
+    supports_sequences = False
+    supports_native_boolean = True
+    supports_native_enum = False
+    supports_statement_cache = True
+    insert_returning = False
+    update_returning = False
+    delete_returning = False
+    supports_default_values = False
+    supports_empty_insert = False
+    postfetch_lastrowid = False
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+
+    construct_arguments = [
+        (sa_schema.Column, {"category": None}),  # Column(..., iotdb_category=...)
+        (sa_schema.Table, {"ttl": None}),  # Table(..., iotdb_ttl=...)
+    ]
 
     @classmethod
     def import_dbapi(cls):
@@ -77,8 +83,23 @@ class IoTDBDialect(default.DefaultDialect):
     def dbapi(cls):
         return dbapi
 
-    def has_schema(self, connection, schema):
-        return schema in self.get_schema_names(connection)
+    def create_connect_args(self, url):
+        """Map *url* to DBAPI connect kwargs, forcing the table SQL dialect."""
+        connect_kwargs = dict(url.translate_connect_args(), **url.query)
+        connect_kwargs["sql_dialect"] = "table"
+        return ([], connect_kwargs)
+
+    def initialize(self, connection):
+        """No-op; DefaultDialect.initialize() is not invoked."""
+
+    def _get_server_version_info(self, connection):
+        return None
+
+    def _get_default_schema_name(self, connection):
+        return None
+
+    def has_schema(self, connection, schema_name, **kw):
+        return schema_name in self.get_schema_names(connection)
 
     def has_table(self, connection, table_name, schema=None, **kw):
         return table_name in self.get_table_names(connection, schema=schema)
@@ -88,22 +109,50 @@ class IoTDBDialect(default.DefaultDialect):
         return [row[0] for row in cursor.fetchall()]
 
     def get_table_names(self, connection, schema=None, **kw):
-        cursor = connection.execute(
-            text("SHOW DEVICES %s.**" % (schema or self.default_schema_name))
-        )
-        return [row[0].replace(schema + ".", "", 1) for row in 
cursor.fetchall()]
+        # "SHOW TABLES FROM <db>" is used rather than "USE <db>" so that
+        # reflection does not switch the connection's current database.
+        sql = "SHOW TABLES"
+        if schema:
+            sql += " FROM %s" % self.identifier_preparer.quote(schema)
+        cursor = connection.execute(text(sql))
+        return [row[0] for row in cursor.fetchall()]
 
     def get_columns(self, connection, table_name, schema=None, **kw):
+        """Reflect columns via ``SHOW COLUMNS FROM [<schema>.]<table>``.
+
+        Rows are read as (name, data type[, category]). Unrecognized data
+        types reflect as ``NullType``. A category is reported both under the
+        ``iotdb_category`` key and as the ``iotdb_category`` dialect option,
+        so that reflected ``Table`` objects carry it on their columns.
+        """
+        preparer = self.identifier_preparer
+        qualified_name = preparer.quote(table_name)
+        if schema:
+            qualified_name = "%s.%s" % (preparer.quote(schema), qualified_name)
         cursor = connection.execute(
-            text("SHOW TIMESERIES %s.%s.*" % (schema, table_name))
+            text("SHOW COLUMNS FROM %s" % qualified_name)
         )
-        columns = [self._general_time_column_info()]
+        columns = []
         for row in cursor.fetchall():
-            columns.append(self._create_column_info(row, schema, table_name))
+            col_name, col_type_str = row[0], row[1]
+            col_category = row[2] if len(row) > 2 else None
+            sa_type = ischema_names.get(col_type_str.upper(), types.NullType)
+            col_info = {
+                "name": col_name,
+                "type": sa_type(),
+                "nullable": True,
+                "default": None,
+            }
+            if col_category:
+                category = col_category.upper()
+                col_info["iotdb_category"] = category
+                # Picked up by Table reflection as Column(iotdb_category=...).
+                col_info["dialect_options"] = {"iotdb_category": category}
+            columns.append(col_info)
         return columns
 
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
-        pass
+        return {"constrained_columns": [], "name": None}
 
     def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         return []
@@ -111,33 +151,11 @@ class IoTDBDialect(default.DefaultDialect):
     def get_indexes(self, connection, table_name, schema=None, **kw):
         return []
 
-    @util.memoized_property
-    def _dialect_specific_select_one(self):
-        # IoTDB does not support select 1
-        # so replace the statement with "show version"
-        return "SHOW VERSION"
-
-    def _general_time_column_info(self):
-        """
-        Treat Time as a column
-        """
-        return {
-            "name": "Time",
-            "type": self._resolve_type("LONG"),
-            "nullable": False,
-            "default": None,
-        }
-
-    def _create_column_info(self, row, schema, table_name):
-        """
-        Generate description information for each column
-        """
-        return {
-            "name": row[0].replace(schema + "." + table_name + ".", "", 1),
-            "type": self._resolve_type(row[3]),
-            "nullable": True,
-            "default": None,
-        }
-
-    def _resolve_type(self, type_):
-        return TYPES_MAP.get(type_, types.UserDefinedType)
+    def get_view_names(self, connection, schema=None, **kw):
+        return []  # views are not reflected
+
+    def do_commit(self, dbapi_connection):
+        pass  # no-op: dbapi_connection.commit() is not called
+
+    def do_rollback(self, dbapi_connection):
+        pass  # no-op: dbapi_connection.rollback() is not called
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py 
b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
index 008a314e683..08482ac41b9 100644
--- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBSQLCompiler.py
@@ -17,308 +17,7 @@
 #
 
 from sqlalchemy.sql.compiler import SQLCompiler
-from sqlalchemy.sql.compiler import OPERATORS
-from sqlalchemy.sql import operators
 
 
 class IoTDBSQLCompiler(SQLCompiler):
-    def order_by_clause(self, select, **kw):
-        """allow dialects to customize how ORDER BY is rendered."""
-
-        order_by = select._order_by_clause._compiler_dispatch(self, **kw)
-        if "Time" in order_by:
-            return " ORDER BY " + order_by.replace('"', "")
-        else:
-            return ""
-
-    def group_by_clause(self, select, **kw):
-        """allow dialects to customize how GROUP BY is rendered."""
-        return ""
-
-    def visit_select(
-        self,
-        select,
-        asfrom=False,
-        parens=True,
-        fromhints=None,
-        compound_index=0,
-        nested_join_translation=False,
-        select_wraps_for=None,
-        lateral=False,
-        **kwargs,
-    ):
-        """
-        Override this method to solve two problems
-        1. IoTDB does not support querying Time as a measurement name (e.g. 
select Time from root.storagegroup.device)
-        2. IoTDB does not support path.measurement format to determine a 
column (e.g. select root.storagegroup.device.temperature from 
root.storagegroup.device)
-        """
-        assert select_wraps_for is None, (
-            "SQLAlchemy 1.4 requires use of "
-            "the translate_select_structure hook for structural "
-            "translations of SELECT objects"
-        )
-
-        # initial setup of SELECT.  the compile_state_factory may now
-        # be creating a totally different SELECT from the one that was
-        # passed in.  for ORM use this will convert from an ORM-state
-        # SELECT to a regular "Core" SELECT.  other composed operations
-        # such as computation of joins will be performed.
-
-        kwargs["within_columns_clause"] = False
-
-        compile_state = select_stmt._compile_state_factory(select_stmt, self, 
**kwargs)
-        select_stmt = compile_state.statement
-
-        toplevel = not self.stack
-
-        if toplevel and not self.compile_state:
-            self.compile_state = compile_state
-
-        is_embedded_select = compound_index is not None or insert_into
-
-        # translate step for Oracle, SQL Server which often need to
-        # restructure the SELECT to allow for LIMIT/OFFSET and possibly
-        # other conditions
-        if self.translate_select_structure:
-            new_select_stmt = self.translate_select_structure(
-                select_stmt, asfrom=asfrom, **kwargs
-            )
-
-            # if SELECT was restructured, maintain a link to the originals
-            # and assemble a new compile state
-            if new_select_stmt is not select_stmt:
-                compile_state_wraps_for = compile_state
-                select_wraps_for = select_stmt
-                select_stmt = new_select_stmt
-
-                compile_state = select_stmt._compile_state_factory(
-                    select_stmt, self, **kwargs
-                )
-                select_stmt = compile_state.statement
-
-        entry = self._default_stack_entry if toplevel else self.stack[-1]
-
-        populate_result_map = need_column_expressions = (
-            toplevel
-            or entry.get("need_result_map_for_compound", False)
-            or entry.get("need_result_map_for_nested", False)
-        )
-
-        # indicates there is a CompoundSelect in play and we are not the
-        # first select
-        if compound_index:
-            populate_result_map = False
-
-        # this was first proposed as part of #3372; however, it is not
-        # reached in current tests and could possibly be an assertion
-        # instead.
-        if not populate_result_map and "add_to_result_map" in kwargs:
-            del kwargs["add_to_result_map"]
-
-        froms = self._setup_select_stack(
-            select_stmt, compile_state, entry, asfrom, lateral, compound_index
-        )
-
-        column_clause_args = kwargs.copy()
-        column_clause_args.update(
-            {"within_label_clause": False, "within_columns_clause": False}
-        )
-
-        text = "SELECT "  # we're off to a good start !
-
-        if select_stmt._hints:
-            hint_text, byfrom = self._setup_select_hints(select_stmt)
-            if hint_text:
-                text += hint_text + " "
-        else:
-            byfrom = None
-
-        if select_stmt._independent_ctes:
-            for cte in select_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kwargs)
-
-        if select_stmt._prefixes:
-            text += self._generate_prefixes(
-                select_stmt, select_stmt._prefixes, **kwargs
-            )
-
-        text += self.get_select_precolumns(select_stmt, **kwargs)
-        # the actual list of columns to print in the SELECT column list.
-        inner_columns = [
-            c
-            for c in [
-                self._label_select_column(
-                    select_stmt,
-                    column,
-                    populate_result_map,
-                    asfrom,
-                    column_clause_args,
-                    name=name,
-                    proxy_name=proxy_name,
-                    fallback_label_name=fallback_label_name,
-                    column_is_repeated=repeated,
-                    need_column_expressions=need_column_expressions,
-                )
-                for (
-                    name,
-                    proxy_name,
-                    fallback_label_name,
-                    column,
-                    repeated,
-                ) in compile_state.columns_plus_names
-            ]
-            if c is not None
-        ]
-
-        if populate_result_map and select_wraps_for is not None:
-            # if this select was generated from translate_select,
-            # rewrite the targeted columns in the result map
-
-            translate = dict(
-                zip(
-                    [
-                        name
-                        for (
-                            key,
-                            proxy_name,
-                            fallback_label_name,
-                            name,
-                            repeated,
-                        ) in compile_state.columns_plus_names
-                    ],
-                    [
-                        name
-                        for (
-                            key,
-                            proxy_name,
-                            fallback_label_name,
-                            name,
-                            repeated,
-                        ) in compile_state_wraps_for.columns_plus_names
-                    ],
-                )
-            )
-
-            self._result_columns = [
-                (key, name, tuple(translate.get(o, o) for o in obj), type_)
-                for key, name, obj, type_ in self._result_columns
-            ]
-
-        # change the superset aggregate function name into iotdb aggregate 
function name
-        # by matching the head of aggregate function name and replace it.
-        for i in range(len(inner_columns)):
-            if inner_columns[i].startswith("max("):
-                inner_columns[i] = inner_columns[i].replace("max(", 
"max_value(")
-            if inner_columns[i].startswith("min("):
-                inner_columns[i] = inner_columns[i].replace("min(", 
"min_value(")
-            if inner_columns[i].startswith("count(DISTINCT"):
-                inner_columns[i] = inner_columns[i].replace("count(DISTINCT", 
"count(")
-
-        # IoTDB does not allow to query Time as column,
-        # need to filter out Time and pass Time and Time's alias to DBAPI 
separately
-        # to achieve the query of Time by encoding.
-        time_column_index = []
-        time_column_names = []
-        for i in range(len(inner_columns)):
-            column_strs = (
-                inner_columns[i].replace(self.preparer.initial_quote, 
"").split()
-            )
-            if "Time" in column_strs:
-                time_column_index.append(str(i))
-                time_column_names.append(
-                    column_strs[2]
-                    if OPERATORS[operators.as_] in column_strs
-                    else column_strs[0]
-                )
-        # delete Time column
-        inner_columns = list(
-            filter(
-                lambda x: "Time"
-                not in x.replace(self.preparer.initial_quote, "").split(),
-                inner_columns,
-            )
-        )
-
-        if inner_columns and time_column_index:
-            inner_columns[-1] = (
-                inner_columns[-1]
-                + " \n FROM Time Index "
-                + " ".join(time_column_index)
-                + " \n FROM Time Name "
-                + " ".join(time_column_names)
-            )
-
-        text = self._compose_select_body(
-            text,
-            select_stmt,
-            compile_state,
-            inner_columns,
-            froms,
-            byfrom,
-            toplevel,
-            kwargs,
-        )
-
-        if select_stmt._statement_hints:
-            per_dialect = [
-                ht
-                for (dialect_name, ht) in select_stmt._statement_hints
-                if dialect_name in ("*", self.dialect.name)
-            ]
-            if per_dialect:
-                text += " " + self.get_statement_hint_text(per_dialect)
-
-        # In compound query, CTEs are shared at the compound level
-        if self.ctes and (not is_embedded_select or toplevel):
-            nesting_level = len(self.stack) if not toplevel else None
-            text = (
-                self._render_cte_clause(
-                    nesting_level=nesting_level,
-                    visiting_cte=kwargs.get("visiting_cte"),
-                )
-                + text
-            )
-
-        if select_stmt._suffixes:
-            text += " " + self._generate_prefixes(
-                select_stmt, select_stmt._suffixes, **kwargs
-            )
-
-        self.stack.pop(-1)
-        return text
-
-    def visit_table(
-        self,
-        table,
-        asfrom=False,
-        iscrud=False,
-        ashint=False,
-        fromhints=None,
-        use_schema=True,
-        **kwargs,
-    ):
-        """
-        IoTDB's table does not support quotation marks (e.g. select ** from 
`root.`)
-        need to override this method
-        """
-        if asfrom or ashint:
-            effective_schema = self.preparer.schema_for_object(table)
-
-            if use_schema and effective_schema:
-                ret = effective_schema + "." + table.name
-            else:
-                ret = table.name
-            if fromhints and table in fromhints:
-                ret = self.format_from_hint_text(ret, table, fromhints[table], 
iscrud)
-            return ret
-        else:
-            return ""
-
-    def visit_column(
-        self, column, add_to_result_map=None, include_table=True, **kwargs
-    ):
-        """
-        IoTDB's where statement does not support "table".column format(e.g. 
"table".column > 1)
-        need to override this method to return the name of column directly
-        """
-        return column.name
+    """Uses SQLAlchemy's default SQL compilation without overrides."""
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py 
b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py
index 4cfd2480bd4..1fb80185bd5 100644
--- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBTypeCompiler.py
@@ -20,11 +20,17 @@ from sqlalchemy.sql.compiler import GenericTypeCompiler
 
 
 class IoTDBTypeCompiler(GenericTypeCompiler):
+    def visit_BOOLEAN(self, type_, **kw):
+        return "BOOLEAN"
+
     def visit_FLOAT(self, type_, **kw):
         return "FLOAT"
 
+    def visit_REAL(self, type_, **kw):
+        return "FLOAT"
+
     def visit_NUMERIC(self, type_, **kw):
-        return "INT64"
+        return "DOUBLE"
 
     def visit_DECIMAL(self, type_, **kw):
         return "DOUBLE"
@@ -36,10 +42,32 @@ class IoTDBTypeCompiler(GenericTypeCompiler):
         return "INT32"
 
     def visit_BIGINT(self, type_, **kw):
-        return "LONG"
+        return "INT64"
 
     def visit_TIMESTAMP(self, type_, **kw):
-        return "LONG"
+        return "TIMESTAMP"
+
+    def visit_DATE(self, type_, **kw):
+        return "DATE"
+
+    def visit_BLOB(self, type_, **kw):
+        return "BLOB"
+
+    def visit_integer(self, type_, **kw):
+        return "INT32"
 
     def visit_text(self, type_, **kw):
-        return "TEXT"
+        return "STRING"
+
+    # Remaining type names share an IoTDB type with a visitor defined above,
+    # so they are bound to that visitor instead of repeating its body.
+    visit_boolean = visit_BOOLEAN
+    visit_float = visit_FLOAT
+    visit_numeric = visit_NUMERIC
+    visit_big_integer = visit_BIGINT
+    visit_DATETIME = visit_datetime = visit_timestamp = visit_TIMESTAMP
+    visit_date = visit_DATE
+    visit_BINARY = visit_VARBINARY = visit_LARGE_BINARY = visit_BLOB
+    visit_large_binary = visit_BLOB
+    visit_TEXT = visit_VARCHAR = visit_NVARCHAR = visit_CHAR = visit_text
+    visit_string = visit_unicode = visit_unicode_text = visit_text
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py 
b/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py
index 2a1e720805f..7a6b7223a6f 100644
--- a/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py
+++ b/iotdb-client/client-py/iotdb/sqlalchemy/__init__.py
@@ -15,3 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
+# The dialect class is deliberately not re-exported from this package:
+# binding the name "IoTDBDialect" here would shadow the IoTDBDialect
+# submodule, and SQLAlchemy's registry.register("iotdb",
+# "iotdb.sqlalchemy.IoTDBDialect", "IoTDBDialect") resolves that dotted
+# path with getattr() starting from the package.
diff --git 
a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py 
b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py
index 9ed0c808917..c5bb2fa3029 100644
--- a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py
+++ b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py
@@ -16,96 +16,236 @@
 # under the License.
 #
 
-import operator
-
-from sqlalchemy import create_engine, inspect
+from sqlalchemy import (
+    create_engine,
+    inspect,
+    Column,
+    Float,
+    Integer,
+    BigInteger,
+    String,
+    Boolean,
+    Table,
+    MetaData,
+)
 from sqlalchemy.dialects import registry
-from sqlalchemy.orm import Session
 from sqlalchemy.sql import text
 from tests.integration.iotdb_container import IoTDBContainer
 from urllib.parse import quote_plus as urlquote
 
-final_flag = True
-failed_count = 0
-
-
-def test_fail():
-    global failed_count
-    global final_flag
-    final_flag = False
-    failed_count += 1
 
+# Database the integration test creates, places every table in, and drops.
+TEST_DB = "test_sqlalchemy"
 
-def print_message(message):
-    print("*********")
-    print(message)
-    print("*********")
-    assert False
 
-
-def test_dialect():
+def test_table_model_dialect():
+    """Run the SQLAlchemy dialect integration checks against an IoTDB container.
+
+    Creates ``TEST_DB``, runs the DDL, DML, reflection and TIME-column
+    sub-tests through a single engine, then drops the database.
+    """

     with IoTDBContainer("iotdb:dev") as db:
         db: IoTDBContainer
-        password = urlquote("root")
+        password = urlquote("TimechoDB@2021")
         host = db.get_container_host_ip()
         port = db.get_exposed_port(6667)
-        url = f"iotdb://root:{password}@{host}:{port}"
+        url = f"iotdb://root:{password}@{host}:{port}/{TEST_DB}"
         registry.register("iotdb", "iotdb.sqlalchemy.IoTDBDialect", "IoTDBDialect")
-        eng = create_engine(url)
-
-        with Session(eng) as session:
-            session.execute(text("create database root.cursor"))
-            session.execute(text("create database root.cursor_s1"))
-            session.execute(
-                text(
-                    "create timeseries root.cursor.device1.temperature with datatype=FLOAT,encoding=RLE"
-                )
+        engine = create_engine(url)
+
+        try:
+            with engine.connect() as conn:
+                conn.execute(text("CREATE DATABASE %s" % TEST_DB))
+
+            _test_ddl(engine)
+            _test_dml(engine)
+            _test_reflection(engine)
+            _test_time_column(engine)
+
+            with engine.connect() as conn:
+                conn.execute(text("DROP DATABASE %s" % TEST_DB))
+        finally:
+            # Release the engine's pooled connections even when a sub-test
+            # assertion fails, before the container context exits.
+            engine.dispose()
+
+        print("All table model dialect tests passed!")
+
+
+def _test_ddl(engine):
+    """Check CREATE TABLE / DROP TABLE for a table using IoTDB dialect options.
+
+    The table declares a per-column ``iotdb_category`` (TAG, ATTRIBUTE, FIELD)
+    and a table-level ``iotdb_ttl``; creation and removal are verified by
+    reflecting the table names of ``TEST_DB``.
+    """
+    metadata = MetaData()
+
+    sensors = Table(
+        "sensors",
+        metadata,
+        Column("region", String, iotdb_category="TAG"),
+        Column("device_id", String, iotdb_category="TAG"),
+        Column("model", String, iotdb_category="ATTRIBUTE"),
+        Column("temperature", Float, iotdb_category="FIELD"),
+        Column("humidity", Float, iotdb_category="FIELD"),
+        Column("status", Boolean, iotdb_category="FIELD"),
+        schema=TEST_DB,
+        iotdb_ttl=3600000,
+    )
+
+    metadata.create_all(engine)
+
+    table_names = inspect(engine).get_table_names(schema=TEST_DB)
+    assert "sensors" in table_names, (
+        "CREATE TABLE failed: 'sensors' not in %s" % table_names
+    )
+
+    sensors.drop(engine)
+    # Reflect through a new Inspector: an Inspector can cache reflection
+    # results in its ``info_cache``, so reusing the one from before the DROP
+    # may return the stale pre-DROP table list.
+    table_names = inspect(engine).get_table_names(schema=TEST_DB)
+    assert "sensors" not in table_names, (
+        "DROP TABLE failed: 'sensors' still in %s" % table_names
+    )
+
+    print("  DDL tests passed")
+
+
+def _test_dml(engine):
+    """Exercise Core INSERT, SELECT (WHERE, ORDER BY + LIMIT) and DELETE.
+
+    Two rows are inserted into a TAG/FIELD table in ``TEST_DB``; the row count
+    returned after each statement is asserted, then the table is dropped.
+    """
+    metadata = MetaData()
+    sensors = Table(
+        "sensors_dml",
+        metadata,
+        Column("region", String, iotdb_category="TAG"),
+        Column("device_id", String, iotdb_category="TAG"),
+        Column("temperature", Float, iotdb_category="FIELD"),
+        Column("humidity", Float, iotdb_category="FIELD"),
+        schema=TEST_DB,
+    )
+    metadata.create_all(engine)
+
+    with engine.connect() as conn:
+        conn.execute(
+            sensors.insert().values(
+                region="asia",
+                device_id="d001",
+                temperature=25.5,
+                humidity=60.0,
             )
-            session.execute(
-                text(
-                    "create timeseries root.cursor.device1.status with datatype=FLOAT,encoding=RLE"
-                )
+        )
+        conn.execute(
+            sensors.insert().values(
+                region="europe",
+                device_id="d002",
+                temperature=18.3,
+                humidity=75.0,
             )
-            session.execute(
-                text(
-                    "create timeseries root.cursor.device2.temperature with datatype=FLOAT,encoding=RLE"
-                )
+        )
+
+        result = conn.execute(sensors.select()).fetchall()
+        assert len(result) == 2, (
+            "INSERT/SELECT failed: expected 2 rows, got %d" % len(result)
+        )
+
+        result = conn.execute(
+            sensors.select().where(sensors.c.region == "asia")
+        ).fetchall()
+        assert len(result) == 1, (
+            "SELECT WHERE failed: expected 1 row, got %d" % len(result)
+        )
+
+        result = conn.execute(
+            sensors.select().order_by(sensors.c.temperature).limit(1)
+        ).fetchall()
+        assert len(result) == 1, (
+            "LIMIT failed: expected 1 row, got %d" % len(result)
+        )
+
+        conn.execute(
+            sensors.delete().where(sensors.c.device_id == "d002")
+        )
+        result = conn.execute(sensors.select()).fetchall()
+        assert len(result) == 1, (
+            "DELETE failed: expected 1 row after delete, got %d" % len(result)
+        )
+
+    sensors.drop(engine)
+    print("  DML tests passed")
+
+
+def _test_reflection(engine):
+    """Check Inspector reflection for a table created with raw table-model DDL.
+
+    Covers schema and table listing, column names, and the empty primary-key,
+    foreign-key and index results expected for the table.
+    """
+    with engine.connect() as conn:
+        # The CREATE TABLE below is not schema-qualified, so select TEST_DB.
+        conn.execute(text("USE %s" % TEST_DB))
+        conn.execute(
+            text(
+                "CREATE TABLE reflect_test ("
+                "region STRING TAG, "
+                "device STRING TAG, "
+                "model STRING ATTRIBUTE, "
+                "temperature FLOAT FIELD, "
+                "humidity DOUBLE FIELD"
+                ")"
             )
-
-        insp = inspect(eng)
-        # test get_schema_names
-        schema_names = insp.get_schema_names()
-        if not operator.ge(
-            schema_names, ["root.__audit", "root.cursor", "root.cursor_s1"]
-        ):
-            test_fail()
-            print_message("Actual result " + str(schema_names))
-            print_message("test get_schema_names failed!")
-        # test get_table_names
-        table_names = insp.get_table_names("root.cursor")
-        if not operator.eq(table_names, ["device1", "device2"]):
-            test_fail()
-            print_message("Actual result " + str(table_names))
-            print_message("test get_table_names failed!")
-        # test get_columns
-        columns = insp.get_columns(table_name="device1", schema="root.cursor")
-        if len(columns) != 3:
-            test_fail()
-            print_message("Actual result " + str(columns))
-            print_message("test get_columns failed!")
-
-        with Session(eng) as session:
-            session.execute(text("delete database root.cursor"))
-            session.execute(text("delete database root.cursor_s1"))
-
-        # close engine
-        eng.dispose()
-
-
-if final_flag:
-    print("All executions done!!")
-else:
-    print("Some test failed, please have a check")
-    print("failed count: ", failed_count)
-    exit(1)
+        )
+
+    insp = inspect(engine)
+
+    schemas = insp.get_schema_names()
+    assert TEST_DB in schemas, "%s not in schemas: %s" % (TEST_DB, schemas)
+
+    tables = insp.get_table_names(schema=TEST_DB)
+    assert "reflect_test" in tables, "reflect_test not in tables: %s" % tables
+
+    columns = insp.get_columns(table_name="reflect_test", schema=TEST_DB)
+    col_names = [c["name"] for c in columns]
+    assert "region" in col_names, "region not in columns: %s" % col_names
+    assert "temperature" in col_names, "temperature not in columns: %s" % col_names
+
+    pk = insp.get_pk_constraint(table_name="reflect_test", schema=TEST_DB)
+    assert pk["constrained_columns"] == [], "Expected empty PK constraint"
+
+    fks = insp.get_foreign_keys(table_name="reflect_test", schema=TEST_DB)
+    assert fks == [], "Expected empty FK list"
+
+    indexes = insp.get_indexes(table_name="reflect_test", schema=TEST_DB)
+    assert indexes == [], "Expected empty index list"
+
+    with engine.connect() as conn:
+        conn.execute(text("USE %s" % TEST_DB))
+        conn.execute(text("DROP TABLE reflect_test"))
+
+    print("  Reflection tests passed")
+
+
+def _test_time_column(engine):
+    """Check inserts into tables with an implicit and an explicit TIME column.
+
+    ``time_implicit`` declares no TIME-category column and its insert supplies
+    no time value; ``time_explicit`` declares ``ts`` with
+    ``iotdb_category="TIME"`` and sets it on insert. Each table must return
+    exactly one row after its single insert.
+    """
+    metadata = MetaData()
+
+    table_implicit = Table(
+        "time_implicit",
+        metadata,
+        Column("device", String, iotdb_category="TAG"),
+        Column("value", Float, iotdb_category="FIELD"),
+        schema=TEST_DB,
+    )
+    metadata.create_all(engine)
+
+    with engine.connect() as conn:
+        conn.execute(
+            table_implicit.insert().values(device="d001", value=42.0)
+        )
+        result = conn.execute(table_implicit.select()).fetchall()
+        assert len(result) == 1, (
+            "Implicit TIME insert failed: expected 1 row, got %d" % len(result)
+        )
+
+    table_implicit.drop(engine)
+
+    metadata2 = MetaData()
+    table_explicit = Table(
+        "time_explicit",
+        metadata2,
+        # NOTE(review): confirm IoTDB accepts a BigInteger (INT64) TIME column;
+        # the table model may require TIMESTAMP for the time column.
+        Column("ts", BigInteger, iotdb_category="TIME"),
+        Column("device", String, iotdb_category="TAG"),
+        Column("value", Float, iotdb_category="FIELD"),
+        schema=TEST_DB,
+    )
+    metadata2.create_all(engine)
+
+    with engine.connect() as conn:
+        conn.execute(
+            table_explicit.insert().values(ts=1000000, device="d001", value=42.0)
+        )
+        result = conn.execute(table_explicit.select()).fetchall()
+        assert len(result) == 1, (
+            "Explicit TIME insert failed: expected 1 row, got %d" % len(result)
+        )
+
+    table_explicit.drop(engine)
+    print("  TIME column tests passed")
+
+
+if __name__ == "__main__":
+    # Allow running this integration test module directly as a script.
+    test_table_model_dialect()


Reply via email to