This is an automated email from the ASF dual-hosted git repository.
basph pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 98a8bfc242c Add cancel_on_kill support for EMR Serverless deferrable
operator (#60440)
98a8bfc242c is described below
commit 98a8bfc242c018a1955c11eb4aee296e6461d61f
Author: Akshay <[email protected]>
AuthorDate: Thu Feb 5 14:41:42 2026 +0530
Add cancel_on_kill support for EMR Serverless deferrable operator (#60440)
* Add cancel_on_kill support for EMR Serverless deferrable operator
* Add details to safe_to_cancel method.
* Fix formatting in EMR trigger test
Apply prek hook formatting to single-line assertion.
Co-authored-by: Cursor <[email protected]>
* Use modern SQLAlchemy 2.0 style query
Replace the deprecated session.query().filter() call with the
select().where() and session.scalars() pattern.
Co-authored-by: Cursor <[email protected]>
---------
Co-authored-by: Akshay <[email protected]>
Co-authored-by: Akshay <[email protected]>
Co-authored-by: Cursor <[email protected]>
---
.../airflow/providers/amazon/aws/operators/emr.py | 9 +-
.../airflow/providers/amazon/aws/triggers/emr.py | 131 +++++++++++++++++++-
.../tests/unit/amazon/aws/triggers/test_emr.py | 132 +++++++++++++++++++++
3 files changed, 269 insertions(+), 3 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 1dc3d54f4d3..6241436ad74 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -1164,6 +1164,9 @@ class
EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
:param enable_application_ui_links: If True, the operator will generate
one-time links to EMR Serverless
application UIs. The generated links will allow any user with access
to the DAG to see the Spark or
Tez UI or Spark stdout logs. Defaults to False.
+ :param cancel_on_kill: If True, the EMR Serverless job will be cancelled
when the task is killed
+ while in deferrable mode. This ensures that orphan jobs are not left
running in EMR Serverless
+ when an Airflow task is cancelled. Defaults to True.
"""
aws_hook_class = EmrServerlessHook
@@ -1202,6 +1205,7 @@ class
EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
waiter_delay: int | ArgNotSet = NOTSET,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
enable_application_ui_links: bool = False,
+ cancel_on_kill: bool = True,
**kwargs,
):
waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
@@ -1219,6 +1223,7 @@ class
EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
self.job_id: str | None = None
self.deferrable = deferrable
self.enable_application_ui_links = enable_application_ui_links
+ self.cancel_on_kill = cancel_on_kill
super().__init__(**kwargs)
self.client_request_token = client_request_token or str(uuid4())
@@ -1283,6 +1288,7 @@ class
EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
+ cancel_on_kill=self.cancel_on_kill,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
@@ -1334,7 +1340,8 @@ class
EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
"""
Cancel the submitted job run.
- Note: this method will not run in deferrable mode.
+ Note: In deferrable mode, this method will not run. Instead, job
cancellation
+ is handled by the trigger's cancel_on_kill parameter when the task is
killed.
"""
if self.job_id:
self.log.info("Stopping job run with jobId - %s", self.job_id)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
index 356a7354d78..5fb5759a9de 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,15 +16,29 @@
# under the License.
from __future__ import annotations
+import asyncio
import sys
+from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
+from asgiref.sync import sync_to_async
+
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook,
EmrServerlessHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+if not AIRFLOW_V_3_0_PLUS:
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.utils.session import provide_session
+
class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
"""
@@ -332,9 +346,10 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
:param application_id: The ID of the application the job is being run on.
:param job_id: The ID of the job run.
- :waiter_delay: polling period in seconds to check for the status
+ :param waiter_delay: polling period in seconds to check for the status
:param waiter_max_attempts: The maximum number of attempts to be made
:param aws_conn_id: Reference to AWS connection id
+ :param cancel_on_kill: Flag to indicate whether to cancel the job when the
task is killed.
"""
def __init__(
@@ -344,9 +359,14 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
aws_conn_id: str | None = "aws_default",
+ cancel_on_kill: bool = True,
) -> None:
super().__init__(
- serialized_fields={"application_id": application_id, "job_id":
job_id},
+ serialized_fields={
+ "application_id": application_id,
+ "job_id": job_id,
+ "cancel_on_kill": cancel_on_kill,
+ },
waiter_name="serverless_job_completed",
waiter_args={"applicationId": application_id, "jobRunId": job_id},
failure_message="Serverless Job failed",
@@ -358,10 +378,117 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
+ self.application_id = application_id
+ self.job_id = job_id
+ self.cancel_on_kill = cancel_on_kill
def hook(self) -> AwsGenericHook:
return EmrServerlessHook(self.aws_conn_id)
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
+ from sqlalchemy import select
+
+ query = select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ task_instance = session.scalars(query).one_or_none()
+ if task_instance is None:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ """Get the current state of the task instance (Airflow 3.x)."""
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
+ """
+ Whether it is safe to cancel the EMR Serverless job.
+
+ Returns True if task is NOT DEFERRED (user-initiated cancellation).
+ Returns False if task is DEFERRED (triggerer restart - don't cancel
job).
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """
+ Run the trigger and wait for the job to complete.
+
+ If the task is cancelled while waiting, attempt to cancel the EMR
Serverless job
+ if cancel_on_kill is enabled and it's safe to do so.
+ """
+ hook = self.hook()
+ try:
+ async with await hook.get_async_conn() as client:
+ waiter = hook.get_waiter(
+ self.waiter_name,
+ deferrable=True,
+ client=client,
+ config_overrides=self.waiter_config_overrides,
+ )
+ await async_wait(
+ waiter,
+ self.waiter_delay,
+ self.attempts,
+ self.waiter_args,
+ self.failure_message,
+ self.status_message,
+ self.status_queries,
+ )
+ yield TriggerEvent({"status": "success", self.return_key:
self.return_value})
+ except asyncio.CancelledError:
+ if self.job_id and self.cancel_on_kill and await
self.safe_to_cancel():
+ self.log.info(
+ "Task was cancelled. Cancelling EMR Serverless job.
Application ID: %s, Job ID: %s",
+ self.application_id,
+ self.job_id,
+ )
+ hook.conn.cancel_job_run(applicationId=self.application_id,
jobRunId=self.job_id)
+ self.log.info("EMR Serverless job %s cancelled successfully.",
self.job_id)
+ else:
+ self.log.info(
+ "Trigger may have shutdown or cancel_on_kill is disabled. "
+ "Skipping job cancellation. Application ID: %s, Job ID:
%s",
+ self.application_id,
+ self.job_id,
+ )
+ raise
+ except Exception as e:
+ yield TriggerEvent({"status": "failure", "message": str(e)})
+
class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
"""
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
index eb7f1851155..fef643dc927 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
@@ -16,7 +16,11 @@
# under the License.
from __future__ import annotations
+import asyncio
import sys
+from unittest import mock
+
+import pytest
from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
@@ -269,8 +273,136 @@ class TestEmrServerlessStartJobTrigger:
"waiter_max_attempts": 60,
"job_id": "job_id",
"aws_conn_id": "aws_default",
+ "cancel_on_kill": True,
}
+ def test_serialization_cancel_on_kill_false(self):
+ """Test that cancel_on_kill=False is correctly serialized."""
+ trigger = EmrServerlessStartJobTrigger(
+ application_id="test_app",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger"
+ assert kwargs["cancel_on_kill"] is False
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+ async def test_emr_serverless_trigger_cancellation(self,
mock_safe_to_cancel, mock_async_wait):
+ """
+ Test that EmrServerlessStartJobTrigger cancels the job when task is
killed
+ and safe_to_cancel returns True.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrServerlessStartJobTrigger(
+ application_id="test_app",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+ mock_hook.conn.cancel_job_run.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+
mock_hook.conn.cancel_job_run.assert_called_once_with(applicationId="test_app",
jobRunId="test_job")
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+ async def test_emr_serverless_trigger_no_cancellation_when_unsafe(
+ self, mock_safe_to_cancel, mock_async_wait
+ ):
+ """
+ Test that EmrServerlessStartJobTrigger does NOT cancel the job when
+ safe_to_cancel returns False (e.g., triggerer shutdown).
+ """
+ mock_safe_to_cancel.return_value = False
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrServerlessStartJobTrigger(
+ application_id="test_app",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.conn.cancel_job_run.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
+ async def test_emr_serverless_trigger_no_cancellation_when_disabled(
+ self, mock_safe_to_cancel, mock_async_wait
+ ):
+ """
+ Test that EmrServerlessStartJobTrigger does NOT cancel the job when
+ cancel_on_kill=False.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrServerlessStartJobTrigger(
+ application_id="test_app",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False, # Disabled
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.conn.cancel_job_run.assert_not_called()
+
class TestEmrServerlessDeleteApplicationTrigger:
def test_serialization(self):