This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27dc7e80df Optimize `SnowflakeSqlApiOperator` execution in deferrable mode (#36850)
27dc7e80df is described below

commit 27dc7e80df3ecf5aa61718334f32a1d128b0125c
Author: vatsrahul1001 <43964496+vatsrahul1...@users.noreply.github.com>
AuthorDate: Thu Jan 18 19:33:04 2024 +0530

    Optimize `SnowflakeSqlApiOperator` execution in deferrable mode (#36850)
---
 airflow/providers/snowflake/operators/snowflake.py | 15 +++++
 .../snowflake/operators/test_snowflake.py          | 67 +++++++++++++++++++++-
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 9e0bf3d1cf..f7890b87e1 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -514,6 +514,21 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         if self.do_xcom_push:
             context["ti"].xcom_push(key="query_ids", value=self.query_ids)
 
+        succeeded_query_ids = []
+        for query_id in self.query_ids:
+            self.log.info("Retrieving status for query id %s", query_id)
+            statement_status = self._hook.get_sql_api_query_status(query_id)
+            if statement_status.get("status") == "running":
+                break
+            elif statement_status.get("status") == "success":
+                succeeded_query_ids.append(query_id)
+            else:
+                raise AirflowException(f"{statement_status.get('status')}: {statement_status.get('message')}")
+
+        if len(self.query_ids) == len(succeeded_query_ids):
+            self.log.info("%s completed successfully.", self.task_id)
+            return
+
         if self.deferrable:
             self.defer(
                 timeout=self.execution_timeout,
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index 07df5fb147..7f429277b9 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -253,7 +253,9 @@ class TestSnowflakeSqlApiOperator:
 
     @pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)])
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query")
-    def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count):
+    def test_snowflake_sql_api_execute_operator_async(
+        self, mock_execute_query, mock_sql, statement_count, mock_get_sql_api_query_status
+    ):
         """
         Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired
         when the SnowflakeSqlApiOperator is executed.
@@ -266,6 +268,9 @@ class TestSnowflakeSqlApiOperator:
             deferrable=True,
         )
 
+        mock_execute_query.return_value = ["uuid1"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "running"}]
+
         with pytest.raises(TaskDeferred) as exc:
             operator.execute(create_context(operator))
 
@@ -311,3 +316,63 @@ class TestSnowflakeSqlApiOperator:
         with mock.patch.object(operator.log, "info") as mock_log_info:
             operator.execute_complete(context=None, event=mock_event)
         mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
+
+    @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+    def test_snowflake_sql_api_execute_operator_failed_before_defer(
+        self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+    ):
+        """Asserts that a task is not deferred when its failed"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id="snowflake_default",
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        mock_execute_query.return_value = ["uuid1"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "error"}]
+        with pytest.raises(AirflowException):
+            operator.execute(create_context(operator))
+        assert not mock_defer.called
+
+    @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+    def test_snowflake_sql_api_execute_operator_succeeded_before_defer(
+        self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+    ):
+        """Asserts that a task is not deferred when its succeeded"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id="snowflake_default",
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        mock_execute_query.return_value = ["uuid1"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "success"}]
+        operator.execute(create_context(operator))
+
+        assert not mock_defer.called
+
+    @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+    def test_snowflake_sql_api_execute_operator_running_before_defer(
+        self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+    ):
+        """Asserts that a task is deferred when its running"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id="snowflake_default",
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            do_xcom_push=False,
+            deferrable=True,
+        )
+        mock_execute_query.return_value = ["uuid1"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "running"}]
+        operator.execute(create_context(operator))
+
+        assert mock_defer.called

Reply via email to