This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 46aa4294e4 Use `dialect.name` in custom SA types (#33503)
46aa4294e4 is described below
commit 46aa4294e453d800ef6d327addf72a004be3765f
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Aug 18 23:40:52 2023 +0400
Use `dialect.name` in custom SA types (#33503)
* Use `dialect.name` in custom SA types
* Fix removed import
---
airflow/utils/sqlalchemy.py | 39 +++++++++++++++------------------------
1 file changed, 15 insertions(+), 24 deletions(-)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 38716d4eb5..b4b726ca32 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -30,23 +30,22 @@ from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.sql import ColumnElement, Select
-from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding
+from airflow.utils.timezone import make_naive
if TYPE_CHECKING:
from kubernetes.client.models.v1_pod import V1Pod
from sqlalchemy.orm import Query, Session
+ from sqlalchemy.sql.expression import ColumnOperators
log = logging.getLogger(__name__)
utc = pendulum.tz.timezone("UTC")
-using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")
-
class UtcDateTime(TypeDecorator):
"""
@@ -67,22 +66,18 @@ class UtcDateTime(TypeDecorator):
cache_ok = True
def process_bind_param(self, value, dialect):
- if value is not None:
- if not isinstance(value, datetime.datetime):
- raise TypeError("expected datetime.datetime, not " + repr(value))
- elif value.tzinfo is None:
- raise ValueError("naive datetime is disallowed")
+ if not isinstance(value, datetime.datetime):
+ if value is None:
+ return None
+ raise TypeError("expected datetime.datetime, not " + repr(value))
+ elif value.tzinfo is None:
+ raise ValueError("naive datetime is disallowed")
+ elif dialect.name == "mysql":
# For mysql we should store timestamps as naive values
- # Timestamp in MYSQL is not timezone aware. In MySQL 5.6
- # timezone added at the end is ignored but in MySQL 5.7
- # inserting timezone value fails with 'invalid-date'
+ # In MySQL 5.7 inserting timezone value fails with 'invalid-date'
# See https://issues.apache.org/jira/browse/AIRFLOW-7001
- if using_mysql:
- from airflow.utils.timezone import make_naive
-
- return make_naive(value, timezone=utc)
- return value.astimezone(utc)
- return None
+ return make_naive(value, timezone=utc)
+ return value.astimezone(utc)
def process_result_value(self, value, dialect):
"""
@@ -119,12 +114,8 @@ class ExtendedJSON(TypeDecorator):
cache_ok = True
- def db_supports_json(self):
- """Check if the database supports JSON (i.e. is NOT MSSQL)."""
- return not conf.get("database", "sql_alchemy_conn").startswith("mssql")
-
def load_dialect_impl(self, dialect) -> TypeEngine:
- if self.db_supports_json():
+ if dialect.name != "mssql":
return dialect.type_descriptor(JSON)
return dialect.type_descriptor(UnicodeText)
@@ -138,7 +129,7 @@ class ExtendedJSON(TypeDecorator):
value = BaseSerialization.serialize(value)
# Then, if the database does not have native JSON support, encode it again as a string
- if not self.db_supports_json():
+ if dialect.name == "mssql":
value = json.dumps(value)
return value
@@ -150,7 +141,7 @@ class ExtendedJSON(TypeDecorator):
return None
# Deserialize from a string first if needed
- if not self.db_supports_json():
+ if dialect.name == "mssql":
value = json.loads(value)
return BaseSerialization.deserialize(value)