This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 944e1311416 Migrate BigQueryInsertJobTrigger to on_kill() for user-initiated kills (#66704)
944e1311416 is described below

commit 944e1311416f523a912abaf602c0ef47f7dd5845
Author: Yunhui Chae <[email protected]>
AuthorDate: Tue May 12 23:54:13 2026 +0900

    Migrate BigQueryInsertJobTrigger to on_kill() for user-initiated kills (#66704)
---
 .../providers/google/cloud/triggers/bigquery.py    | 121 ++++++++++++---------
 .../unit/google/cloud/triggers/test_bigquery.py    | 117 +++++++++++---------
 2 files changed, 132 insertions(+), 106 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 4a6d1a7780c..c9f8acde0a5 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -26,14 +26,14 @@ from asgiref.sync import sync_to_async
 
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
-if not AIRFLOW_V_3_0_PLUS:
+if not AIRFLOW_V_3_3_PLUS:
     from sqlalchemy import select
 
     from airflow.models.taskinstance import TaskInstance
@@ -103,7 +103,20 @@ class BigQueryInsertJobTrigger(BaseTrigger):
             },
         )
 
-    if not AIRFLOW_V_3_0_PLUS:
+    async def on_kill(self) -> None:
+        """Cancel the BigQuery job when the task is killed by a user action."""
+        if self.job_id and self.cancel_on_kill:
+            self.log.info(
+                "Cancelling BigQuery job. Project ID: %s, Location: %s, Job ID: %s",
+                self.project_id,
+                self.location,
+                self.job_id,
+            )
+            hook = self._get_async_hook()
+            await hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
+            self.log.info("BigQuery job %s cancelled successfully.", self.job_id)
+
+    if not AIRFLOW_V_3_3_PLUS:
 
         @provide_session
         def get_task_instance(self, session: Session) -> TaskInstance:
@@ -125,41 +138,41 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                 )
             return task_instance
 
-    async def get_task_state(self):
-        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+        async def get_task_state(self):
+            from airflow.sdk.execution_time.task_runner import 
RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+            task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.task_instance.dag_id,
+                task_ids=[self.task_instance.task_id],
+                run_ids=[self.task_instance.run_id],
+                map_index=self.task_instance.map_index,
             )
-        return task_state
+            try:
+                task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+            except Exception:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_state
 
-    async def safe_to_cancel(self) -> bool:
-        """
-        Whether it is safe to cancel the external job which is being executed 
by this trigger.
+        async def safe_to_cancel(self) -> bool:
+            """
+            Whether it is safe to cancel the external job which is being 
executed by this trigger.
 
-        This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
-        Because in those cases, we should NOT cancel the external job.
-        """
-        if AIRFLOW_V_3_0_PLUS:
-            task_state = await self.get_task_state()
-        else:
-            # Database query is needed to get the latest state of the task 
instance.
-            task_instance = self.get_task_instance()  # type: ignore[call-arg]
-            task_state = task_instance.state
-        return task_state != TaskInstanceState.DEFERRED
+            This is to avoid the case that `asyncio.CancelledError` is called 
because the trigger itself is stopped.
+            Because in those cases, we should NOT cancel the external job.
+            """
+            if AIRFLOW_V_3_0_PLUS:
+                task_state = await self.get_task_state()
+            else:
+                # Database query is needed to get the latest state of the task 
instance.
+                task_instance = self.get_task_instance()  # type: 
ignore[call-arg]
+                task_state = task_instance.state
+            return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
@@ -196,25 +209,27 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                     )
                     await asyncio.sleep(self.poll_interval)
         except asyncio.CancelledError:
-            if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
-                self.log.info(
-                    "The job is safe to cancel the as airflow TaskInstance is 
not in deferred state."
-                )
-                self.log.info(
-                    "Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
-                    self.project_id,
-                    self.location,
-                    self.job_id,
-                )
-                await hook.cancel_job(job_id=self.job_id, 
project_id=self.project_id, location=self.location)
-            else:
-                self.log.info(
-                    "Trigger may have shutdown. Skipping to cancel job because 
the airflow "
-                    "task is not cancelled yet: Project ID: %s, Location:%s, 
Job ID:%s",
-                    self.project_id,
-                    self.location,
-                    self.job_id,
-                )
+            # Legacy path for Airflow < 3.3.0
+            # On Airflow 3.3.0+, on_kill() handles user-initiated kills
+            if not AIRFLOW_V_3_3_PLUS:
+                if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
+                    self.log.info(
+                        "Cancelling job (legacy path). Project ID: %s, 
Location: %s, Job ID: %s",
+                        self.project_id,
+                        self.location,
+                        self.job_id,
+                    )
+                    await hook.cancel_job(
+                        job_id=self.job_id, project_id=self.project_id, 
location=self.location
+                    )
+                else:
+                    self.log.info(
+                        "Trigger may have shutdown. Skipping to cancel job 
because the airflow "
+                        "task is not cancelled yet: Project ID: %s, 
Location:%s, Job ID:%s",
+                        self.project_id,
+                        self.location,
+                        self.job_id,
+                    )
             raise
         except Exception as e:
             self.log.exception("Exception occurred while checking for query 
completion")
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
index 720a9a0d806..78448c064a0 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
@@ -40,6 +40,8 @@ from airflow.providers.google.cloud.triggers.bigquery import (
 )
 from airflow.triggers.base import TriggerEvent
 
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
 TEST_CONN_ID = "bq_default"
 TEST_JOB_ID = "1234"
 TEST_GCP_PROJECT_ID = "test-project"
@@ -234,72 +236,81 @@ class TestBigQueryInsertJobTrigger:
         assert TriggerEvent({"status": "error", "message": "Test exception"}) 
== actual
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
-    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    @pytest.mark.skipif(AIRFLOW_V_3_3_PLUS, reason="on_kill() handles 
cancellation for Airflow 3.3.0+")
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
     
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
-    async def test_bigquery_insert_job_trigger_cancellation(
-        self, mock_get_task_instance, mock_get_job_status, mock_cancel_job, 
caplog, insert_job_trigger
+    async def test_insert_job_trigger_run_cancelled(
+        self, mock_safe_to_cancel, mock_get_async_hook, insert_job_trigger, 
is_safe_to_cancel
     ):
-        """
-        Test that BigQueryInsertJobTrigger handles cancellation correctly, 
logs the appropriate message,
-        and conditionally cancels the job based on the `cancel_on_kill` 
attribute.
-        """
-        mock_get_task_instance.return_value = True
+        """Test CancelledError handling for Airflow < 3.3.0."""
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.get_job_status = AsyncMock()
+        mock_hook.get_job_status.side_effect = asyncio.CancelledError
+        mock_hook.cancel_job = AsyncMock()
+
+        async_gen = insert_job_trigger.run()
+        try:
+            await async_gen.asend(None)
+        except (asyncio.CancelledError, StopAsyncIteration):
+            pass
+        except Exception as e:
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+        if insert_job_trigger.cancel_on_kill and is_safe_to_cancel:
+            mock_hook.cancel_job.assert_awaited_once_with(
+                job_id=insert_job_trigger.job_id,
+                project_id=insert_job_trigger.project_id,
+                location=insert_job_trigger.location,
+            )
+        else:
+            mock_hook.cancel_job.assert_not_awaited()
+
+        await async_gen.aclose()
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+    async def test_on_kill_cancels_job(self, mock_get_async_hook, 
insert_job_trigger):
+        """Test that on_kill cancels the BigQuery job."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.cancel_job = AsyncMock()
+        insert_job_trigger.job_id = TEST_JOB_ID
         insert_job_trigger.cancel_on_kill = True
-        insert_job_trigger.job_id = "1234"
 
-        mock_get_job_status.side_effect = [
-            {"status": "running", "message": "Job is still running"},
-            asyncio.CancelledError(),
-        ]
+        await insert_job_trigger.on_kill()
 
-        mock_cancel_job.return_value = asyncio.Future()
-        mock_cancel_job.return_value.set_result(None)
+        mock_hook.cancel_job.assert_awaited_once_with(
+            job_id=TEST_JOB_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+            location=TEST_LOCATION,
+        )
 
-        caplog.set_level(logging.INFO)
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+    async def test_on_kill_respects_cancel_on_kill_false(self, 
mock_get_async_hook, insert_job_trigger):
+        """Test that on_kill does not cancel the job when cancel_on_kill is 
False."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.cancel_job = AsyncMock()
+        insert_job_trigger.job_id = TEST_JOB_ID
+        insert_job_trigger.cancel_on_kill = False
 
-        with pytest.raises(asyncio.CancelledError):
-            async for _ in insert_job_trigger.run():
-                pass
+        await insert_job_trigger.on_kill()
 
-        assert (
-            "Task was killed" in caplog.text
-            or "Bigquery job status is running. Sleeping for 4.0 seconds." in 
caplog.text
-        ), "Expected messages about task status or cancellation not found in 
log."
-        mock_cancel_job.assert_awaited_once()
+        mock_hook.cancel_job.assert_not_called()
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
-    
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
-    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
-    async def 
test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
-        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, 
caplog, insert_job_trigger
-    ):
-        """
-        Test that BigQueryInsertJobTrigger logs the appropriate message and 
does not cancel the job
-        if safe_to_cancel returns False even when the task is cancelled.
-        """
-        mock_safe_to_cancel.return_value = False
+    
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger._get_async_hook")
+    async def test_on_kill_no_job_id_does_not_cancel(self, 
mock_get_async_hook, insert_job_trigger):
+        """Test that on_kill does not attempt to cancel when job_id is not 
set."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.cancel_job = AsyncMock()
+        insert_job_trigger.job_id = None
         insert_job_trigger.cancel_on_kill = True
-        insert_job_trigger.job_id = "1234"
-
-        # Simulate the initial job status as running
-        mock_get_job_status.side_effect = [
-            {"status": "running", "message": "Job is still running"},
-            asyncio.CancelledError(),
-            {"status": "running", "message": "Job is still running after 
cancellation"},
-        ]
 
-        caplog.set_level(logging.INFO)
+        await insert_job_trigger.on_kill()
 
-        with pytest.raises(asyncio.CancelledError):
-            async for _ in insert_job_trigger.run():
-                pass
-
-        assert "Skipping to cancel job" in caplog.text, (
-            "Expected message about skipping cancellation not found in log."
-        )
-        assert mock_get_job_status.call_count == 2, "Job status should be 
checked multiple times"
+        mock_hook.cancel_job.assert_not_called()
 
 
 class TestBigQueryGetDataTrigger:

Reply via email to