ashb commented on a change in pull request #9973:
URL: https://github.com/apache/airflow/pull/9973#discussion_r629342372



##########
File path: airflow/api_connexion/endpoints/user_endpoint.py
##########
@@ -49,7 +49,7 @@ def get_users(limit, order_by='id', offset=None):
     to_replace = {"user_id": "id"}
     allowed_filter_attrs = [
         "user_id",
-        'id',
+        "id",

Review comment:
    ```suggestion
        'id',
    ```
   
   This file is otherwise unchanged -- let's keep our git history clean.

##########
File path: airflow/models/serialized_dag.py
##########
@@ -313,7 +315,9 @@ def get_dag_dependencies(cls, session: Session = None) -> Dict[str, List['DagDep
         if session.bind.dialect.name in ["sqlite", "mysql"]:
             for row in session.query(cls.dag_id, func.json_extract(cls.data, "$.dag.dag_dependencies")).all():
                 dependencies[row[0]] = [DagDependency(**d) for d in json.loads(row[1])]
-
+        elif session.bind.dialect.name in ["mssql"]:

Review comment:
    ```suggestion
        elif session.bind.dialect.name == "mssql":
    ```

##########
File path: airflow/migrations/versions/83f031fd9f1c_improve_mssql_compatibility.py
##########
@@ -0,0 +1,262 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""improve mssql compatibility
+
+Revision ID: 83f031fd9f1c
+Revises: a13f7613ad25
+Create Date: 2021-04-06 12:22:02.197726
+
+"""
+
+from collections import defaultdict
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import mssql
+
+# revision identifiers, used by Alembic.
+revision = '83f031fd9f1c'
+down_revision = 'a13f7613ad25'
+branch_labels = None
+depends_on = None
+
+
+def is_table_empty(conn, table_name):
+    """
+    This function checks if the mssql table is empty
+    :param conn: sql connection object
+    :param table_name: table name
+    :return: Booelan indicating if the table is present
+    """
+    return conn.execute(f'select TOP 1 * from {table_name}').first() is None
+
+
+def get_table_constraints(conn, table_name):
+    """
+    Return the primary and unique constraints along with their column names.
+    Some tables, like task_instance, are missing an explicit primary key
+    constraint name, and the name is auto-generated by SQL Server, so this
+    function helps to retrieve any primary or unique constraint name.
+
+    :param conn: sql connection object
+    :param table_name: table name
+    :return: a dictionary mapping (constraint name, constraint type) to a list of column names
+    :rtype: defaultdict(list)
+    """
+    query = """SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
+     FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
+     JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON 
ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
+     WHERE tc.TABLE_NAME = '{table_name}' AND
+     (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 
'UNIQUE')
+    """.format(
+        table_name=table_name
+    )
+    result = conn.execute(query).fetchall()
+    constraint_dict = defaultdict(list)
+    for constraint, constraint_type, column in result:
+        constraint_dict[(constraint, constraint_type)].append(column)
+    return constraint_dict
+
+
+def drop_column_constraints(operator, column_name, constraint_dict):
+    """
+    Drop a primary key or unique constraint
+
+    :param operator: batch_alter_table for the table
+    :param column_name: column name to check constraints against
+    :param constraint_dict: a dictionary mapping (constraint name, constraint type) to a list of column names
+    """
+    for constraint, columns in constraint_dict.items():
+        if column_name in columns:
+            if constraint[1].lower().startswith("primary"):
+                operator.drop_constraint(constraint[0], type_='primary')
+            elif constraint[1].lower().startswith("unique"):
+                operator.drop_constraint(constraint[0], type_='unique')
+
+
+def create_constraints(operator, column_name, constraint_dict):
+    """
+    Create a primary key or unique constraint
+
+    :param operator: batch_alter_table for the table
+    :param column_name: column name to check constraints against
+    :param constraint_dict: a dictionary mapping (constraint name, constraint type) to a list of column names
+    """
+    for constraint, columns in constraint_dict.items():
+        if column_name in columns:
+            if constraint[1].lower().startswith("primary"):
+                operator.create_primary_key(constraint_name=constraint[0], columns=columns)
+            elif constraint[1].lower().startswith("unique"):
+                operator.create_unique_constraint(constraint_name=constraint[0], columns=columns)
+
+
+def _use_date_time2(conn):
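+    """Return True when the server version supports DATETIME2, i.e. SQL Server 2008 or newer."""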
+    result = conn.execute(
+        """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY 
('productversion'))
+        like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY 
('productversion'))
+        like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
+    ).fetchone()
+    mssql_version = result[0]
+    return mssql_version not in ("2000", "2005")
+
+
+def _is_timestamp(conn, table_name, column_name):
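+    """Return True if the given column uses the legacy MSSQL ``timestamp`` (rowversion) type."""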
+    query = f"""SELECT
+    TYPE_NAME(C.USER_TYPE_ID) AS DATA_TYPE
+    FROM SYS.COLUMNS C
+    JOIN SYS.TYPES T
+    ON C.USER_TYPE_ID=T.USER_TYPE_ID
+    WHERE C.OBJECT_ID=OBJECT_ID('{table_name}') and C.NAME='{column_name}';
+    """
+    column_type = conn.execute(query).fetchone()[0]
+    return column_type == "timestamp"
+
+
+def recreate_mssql_ts_column(conn, op, table_name, column_name):
+    """
+    Drop the timestamp column and recreate it as
+    datetime or datetime2(6)
+    """
+    if _is_timestamp(conn, table_name, column_name) and is_table_empty(conn, table_name):
+        with op.batch_alter_table(table_name) as batch_op:
+            constraint_dict = get_table_constraints(conn, table_name)
+            drop_column_constraints(batch_op, column_name, constraint_dict)
+            batch_op.drop_column(column_name=column_name)
+            if _use_date_time2(conn):
+                batch_op.add_column(sa.Column(column_name, mssql.DATETIME2(precision=6), nullable=False))
+            else:
+                batch_op.add_column(sa.Column(column_name, mssql.DATETIME, nullable=False))
+            create_constraints(batch_op, column_name, constraint_dict)
+
+
+def alter_mssql_datetime_column(conn, op, table_name, column_name, nullable):
+    """Update the datetime column to datetime2(6)"""
+    if _use_date_time2(conn):
+        op.alter_column(
+            table_name=table_name,
+            column_name=column_name,
+            type_=mssql.DATETIME2(precision=6),
+            nullable=nullable,
+        )
+
+
+def alter_mssql_datetime2_column(conn, op, table_name, column_name, nullable):
+    """Update the datetime2(6) column to datetime"""
+    if _use_date_time2(conn):
+        op.alter_column(
+            table_name=table_name, column_name=column_name, type_=mssql.DATETIME, nullable=nullable
+        )
+
+
+def _get_timestamp(conn):
+    result = conn.execute(
+        """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY 
('productversion'))
+        like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY 
('productversion'))
+        like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
+    ).fetchone()
+    mssql_version = result[0]
+    if mssql_version not in ("2000", "2005"):

Review comment:
       Duplicated logic from `_use_date_time2` fn.
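   
   For illustration, a minimal sketch of reusing the helper (the body of `_get_timestamp` is cut off above, so the exact return values here are an assumption):
   
   ```python
   def _get_timestamp(conn):
       # Hypothetical refactor: delegate the SERVERPROPERTY('productversion')
       # probe to _use_date_time2 instead of repeating the same query here.
       if _use_date_time2(conn):
           return mssql.DATETIME2(precision=6)
       return mssql.DATETIME
   ```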

##########
File path: airflow/models/serialized_dag.py
##########
@@ -114,14 +114,17 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session:
         # If Yes, does nothing
         # If No or the DAG does not exist, updates / writes Serialized DAG to DB
         if min_update_interval is not None:
-            if session.query(
-                exists().where(
+            if (
+                session.query(literal(True))
+                .filter(
                     and_(
                         cls.dag_id == dag.dag_id,
                         (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
                     )
                 )
-            ).scalar():
+                .first()
+                is not None

Review comment:
      Could you add a TODO comment here to change this back once we are on SQLAlchemy 1.4 -- it looks like they fixed the issue?
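   
   A possible shape for that TODO, assuming the `exists()` form from the removed lines is what gets restored:
   
   ```python
   # TODO: revert to exists() once Airflow is on SQLAlchemy 1.4, which
   # reportedly fixes the MSSQL issue with this construct:
   #
   #     if session.query(
   #         exists().where(
   #             and_(
   #                 cls.dag_id == dag.dag_id,
   #                 (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
   #             )
   #         )
   #     ).scalar():
   ```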

##########
File path: airflow/sensors/smart_sensor.py
##########
@@ -391,13 +391,21 @@ def _update_ti_hostname(self, sensor_works, session=None):
         :param session: The sqlalchemy session.
         """
         TI = TaskInstance
-        ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works]
 
         def update_ti_hostname_with_count(count, ti_keys):
             # Using or_ instead of in_ here to prevent from full table scan.
             tis = (
                 session.query(TI)
-                .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key for ti_key in ti_keys))
+                .filter(
+                    or_(
+                        and_(
+                            TI.dag_id == ti_key.dag_id,
+                            TI.task_id == ti_key.task_id,
+                            TI.execution_date == ti_key.execution_date,
+                        )
+                        for ti_key in ti_keys
+                    )
+                )

Review comment:
      Can you make this stay as it was for non-mssql please? Big queries with lots of x OR y OR z are slow to build.
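   
   A hedged sketch of how the gating could look (the exact dialect check is an assumption, not code from this PR):
   
   ```python
   if session.bind.dialect.name == "mssql":
       # MSSQL can't compare row-value tuples, so expand into AND/OR terms.
       ti_filter = or_(
           and_(
               TI.dag_id == ti_key.dag_id,
               TI.task_id == ti_key.task_id,
               TI.execution_date == ti_key.execution_date,
           )
           for ti_key in ti_keys
       )
   else:
       # Keep the original compact tuple comparison elsewhere; building a
       # large chain of ORed AND clauses is slow in SQLAlchemy.
       ti_filter = or_(
           tuple_(TI.dag_id, TI.task_id, TI.execution_date)
           == (ti_key.dag_id, ti_key.task_id, ti_key.execution_date)
           for ti_key in ti_keys
       )
   tis = session.query(TI).filter(ti_filter)
   ```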

##########
File path: airflow/models/dag.py
##########
@@ -896,7 +897,7 @@ def get_num_active_runs(self, external_trigger=None, session=None):
         )
 
         if external_trigger is not None:
-            query = query.filter(DagRun.external_trigger == external_trigger)
+            query = query.filter(DagRun.external_trigger == expression.literal(external_trigger))

Review comment:
       Not sure about this one -- feels a bit risky.
   
    ```suggestion
            query = query.filter(DagRun.external_trigger == (expression.true() if external_trigger else expression.false()))
    ```
   
   This feels less prone to the risk of a SQLi attack.
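   
   For what it's worth, `expression.true()` / `expression.false()` render as dialect-appropriate boolean literals (`1` / `0` on MSSQL), so nothing derived from the caller-supplied value should end up interpolated into the SQL text.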

##########
File path: airflow/www/security.py
##########
@@ -348,19 +348,21 @@ def can_read_dag(self, dag_id, user=None) -> bool:
         if not user:
             user = g.user
         dag_resource_name = permissions.resource_name_for_dag(dag_id)
-        return self._has_view_access(
-            user, permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG
-        ) or self._has_view_access(user, permissions.ACTION_CAN_READ, dag_resource_name)
+        return bool(
+            self._has_view_access(user, permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)
+            or self._has_view_access(user, permissions.ACTION_CAN_READ, dag_resource_name)
+        )

Review comment:
      All the changes in this file should be unnecessary -- can you revert them please?



