This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 46aa4294e4 Use `dialect.name` in custom SA types (#33503)
46aa4294e4 is described below

commit 46aa4294e453d800ef6d327addf72a004be3765f
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Aug 18 23:40:52 2023 +0400

    Use `dialect.name` in custom SA types (#33503)
    
    * Use `dialect.name` in custom SA types
    
    * Fix removed import
---
 airflow/utils/sqlalchemy.py | 39 +++++++++++++++------------------------
 1 file changed, 15 insertions(+), 24 deletions(-)

diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 38716d4eb5..b4b726ca32 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -30,23 +30,22 @@ from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or
 from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.sql import ColumnElement, Select
-from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
 
 from airflow import settings
 from airflow.configuration import conf
 from airflow.serialization.enums import Encoding
+from airflow.utils.timezone import make_naive
 
 if TYPE_CHECKING:
     from kubernetes.client.models.v1_pod import V1Pod
     from sqlalchemy.orm import Query, Session
+    from sqlalchemy.sql.expression import ColumnOperators
 
 log = logging.getLogger(__name__)
 
 utc = pendulum.tz.timezone("UTC")
 
-using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")
-
 
 class UtcDateTime(TypeDecorator):
     """
@@ -67,22 +66,18 @@ class UtcDateTime(TypeDecorator):
     cache_ok = True
 
     def process_bind_param(self, value, dialect):
-        if value is not None:
-            if not isinstance(value, datetime.datetime):
-                raise TypeError("expected datetime.datetime, not " + repr(value))
-            elif value.tzinfo is None:
-                raise ValueError("naive datetime is disallowed")
+        if not isinstance(value, datetime.datetime):
+            if value is None:
+                return None
+            raise TypeError("expected datetime.datetime, not " + repr(value))
+        elif value.tzinfo is None:
+            raise ValueError("naive datetime is disallowed")
+        elif dialect.name == "mysql":
             # For mysql we should store timestamps as naive values
-            # Timestamp in MYSQL is not timezone aware. In MySQL 5.6
-            # timezone added at the end is ignored but in MySQL 5.7
-            # inserting timezone value fails with 'invalid-date'
+            # In MySQL 5.7 inserting timezone value fails with 'invalid-date'
             # See https://issues.apache.org/jira/browse/AIRFLOW-7001
-            if using_mysql:
-                from airflow.utils.timezone import make_naive
-
-                return make_naive(value, timezone=utc)
-            return value.astimezone(utc)
-        return None
+            return make_naive(value, timezone=utc)
+        return value.astimezone(utc)
 
     def process_result_value(self, value, dialect):
         """
@@ -119,12 +114,8 @@ class ExtendedJSON(TypeDecorator):
 
     cache_ok = True
 
-    def db_supports_json(self):
-        """Check if the database supports JSON (i.e. is NOT MSSQL)."""
-        return not conf.get("database", "sql_alchemy_conn").startswith("mssql")
-
     def load_dialect_impl(self, dialect) -> TypeEngine:
-        if self.db_supports_json():
+        if dialect.name != "mssql":
             return dialect.type_descriptor(JSON)
         return dialect.type_descriptor(UnicodeText)
 
@@ -138,7 +129,7 @@ class ExtendedJSON(TypeDecorator):
         value = BaseSerialization.serialize(value)
 
         # Then, if the database does not have native JSON support, encode it again as a string
-        if not self.db_supports_json():
+        if dialect.name == "mssql":
             value = json.dumps(value)
 
         return value
@@ -150,7 +141,7 @@ class ExtendedJSON(TypeDecorator):
             return None
 
         # Deserialize from a string first if needed
-        if not self.db_supports_json():
+        if dialect.name == "mssql":
             value = json.loads(value)
 
         return BaseSerialization.deserialize(value)

Reply via email to