Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
BasPH merged PR #60440:
URL: https://github.com/apache/airflow/pull/60440

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
akshaykumarsalunke commented on PR #60440:
URL: https://github.com/apache/airflow/pull/60440#issuecomment-3848226344

> @cruseakshay could you fix the failing test?

Done
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
BasPH commented on PR #60440:
URL: https://github.com/apache/airflow/pull/60440#issuecomment-3847092765

@cruseakshay could you fix the failing test?
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
cruseakshay commented on code in PR #60440:
URL: https://github.com/apache/airflow/pull/60440#discussion_r2700610236
##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,115 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x compatibility)."""
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR Serverless job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Run the trigger and wait for the job to complete.
+
+        If the task is cancelled while waiting, attempt to cancel the EMR Serverless job
+        if cancel_on_kill is enabled and it's safe to do so.
+        """
+        hook = self.hook()
+        try:
+            async with await hook.get_async_conn() as client:
+                waiter = hook.get_waiter(
+                    self.waiter_name,
+                    deferrable=True,
+                    client=client,
+                    config_overrides=self.waiter_config_overrides,
+                )
+                await async_wait(
+                    waiter,
+                    self.waiter_delay,
+                    self.attempts,
+                    self.waiter_args,
+                    self.failure_message,
+                    self.status_message,
+                    self.status_queries,
+                )
+            yield TriggerEvent({"status": "success", self.return_key: self.return_value})
+        except asyncio.CancelledError:
+            if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
+                self.log.info(
+                    "Task was cancelled. Cancelling EMR Serverless job. Application ID: %s, Job ID: %s",
+                    self.application_id,
+                    self.job_id,
+                )
+                hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id)

Review Comment:
The cancellation request is "fire and forget" - once cancel_job_run returns
successfully, AWS handles the state transition. This is consistent with the
approach in Dataproc
(https://github.com/apache/airflow/blob/main/providers/google/src/airflow/providers/google/cloud/trigger
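The guard in the CancelledError handler above can be read as a small pure
predicate. A minimal sketch of that decision (should_cancel_job is a
hypothetical helper written for illustration, not part of the PR; the state
strings are simplified stand-ins for TaskInstanceState values):

```python
def should_cancel_job(job_id, cancel_on_kill, task_state):
    """Sketch of the guard in the trigger's CancelledError handler.

    A DEFERRED task means the CancelledError came from a triggerer
    restart rather than a user killing the task, so the job should
    be left running in that case.
    """
    return bool(job_id) and cancel_on_kill and task_state != "deferred"
```

Only the combination "job exists, cancel_on_kill is enabled, and the task is
not sitting in DEFERRED" results in a cancel_job_run call.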
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
o-nikolas commented on code in PR #60440:
URL: https://github.com/apache/airflow/pull/60440#discussion_r2699895465
##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,115 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x compatibility)."""
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR Serverless job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Run the trigger and wait for the job to complete.
+
+        If the task is cancelled while waiting, attempt to cancel the EMR Serverless job
+        if cancel_on_kill is enabled and it's safe to do so.
+        """
+        hook = self.hook()
+        try:
+            async with await hook.get_async_conn() as client:
+                waiter = hook.get_waiter(
+                    self.waiter_name,
+                    deferrable=True,
+                    client=client,
+                    config_overrides=self.waiter_config_overrides,
+                )
+                await async_wait(
+                    waiter,
+                    self.waiter_delay,
+                    self.attempts,
+                    self.waiter_args,
+                    self.failure_message,
+                    self.status_message,
+                    self.status_queries,
+                )
+            yield TriggerEvent({"status": "success", self.return_key: self.return_value})
+        except asyncio.CancelledError:
+            if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
+                self.log.info(
+                    "Task was cancelled. Cancelling EMR Serverless job. Application ID: %s, Job ID: %s",
+                    self.application_id,
+                    self.job_id,
+                )
+                hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id)

Review Comment:
I suppose we have no way to wait to verify the deletion like we do in the
sync case?
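One way to verify would be to poll the job run until it leaves CANCELLING
instead of returning right after cancel_job_run. A rough sketch of that idea
(wait_for_terminal_state is hypothetical; get_state stands in for a call such
as reading the state out of a get_job_run response):

```python
import time

# Terminal EMR Serverless job-run states; a cancelled job ends in CANCELLED.
TERMINAL_STATES = {"SUCCESS", "FAILED", "CANCELLED"}


def wait_for_terminal_state(get_state, max_attempts=10, delay=1.0):
    """Poll get_state() until the job run reaches a terminal state.

    get_state is any zero-argument callable returning the current
    state string, e.g. a wrapper around the boto3 get_job_run call.
    """
    state = None
    for _ in range(max_attempts):
        state = get_state()
        if state in TERMINAL_STATES:
            return state
        time.sleep(delay)
    raise TimeoutError(f"job run still in state {state!r} after {max_attempts} attempts")
```

Whether the extra polling is worth blocking the trigger's shutdown path is a
separate question; the fire-and-forget approach hands that responsibility to
AWS.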
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
cruseakshay commented on code in PR #60440:
URL: https://github.com/apache/airflow/pull/60440#discussion_r2694452174

##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,112 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:

Review Comment:
Similar approach in DataprocSubmitTrigger, DataprocClusterTrigger and
BigQueryInsertJobTrigger. LMK if we need an alternative.
https://github.com/apache/airflow/blob/f9877a07cd2693e9176814a2b113742d8f788a82/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py#L126
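The pattern referenced here gates a method definition on a module-level flag
inside the class body. Since a class body is ordinary code executed once at
class-creation time, the ORM-based helper simply does not exist on Airflow
3.x. A toy illustration (LEGACY and the Trigger class are stand-ins for this
sketch, not the real AIRFLOW_V_3_0_PLUS check):

```python
LEGACY = True  # stands in for `not AIRFLOW_V_3_0_PLUS`


class Trigger:
    # The class body runs once when the class is created, so this
    # method is only defined if the flag was true at import time.
    if LEGACY:
        def legacy_lookup(self):
            return "ORM session query"

    def modern_lookup(self):
        return "task SDK query"
```

With LEGACY set to False, Trigger would have no legacy_lookup attribute at
all, which is the behavior the quoted code relies on.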
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
cruseakshay commented on code in PR #60440:
URL: https://github.com/apache/airflow/pull/60440#discussion_r2694445636
##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,112 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x compatibility)."""
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR Serverless job.

Review Comment:
Added more details.
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
BasPH commented on code in PR #60440:
URL: https://github.com/apache/airflow/pull/60440#discussion_r2687369896
##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,112 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x compatibility)."""
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR Serverless job.

Review Comment:
Would be nice to define "safe" in this comment

##
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##
@@ -357,10 +377,112 @@ def __init__(
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
+        self.application_id = application_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        return EmrServerlessHook(self.aws_conn_id)

+    if not AIRFLOW_V_3_0_PLUS:

Review Comment:
This condition looks misplaced to me?
Re: [PR] Add cancel_on_kill support for EMR Serverless deferrable operator [airflow]
boring-cyborg[bot] commented on PR #60440:
URL: https://github.com/apache/airflow/pull/60440#issuecomment-3741932287

Congratulations on your first Pull Request and welcome to the Apache Airflow
community! If you have any issues or are unsure about anything, please check
our Contributors' Guide
(https://github.com/apache/airflow/blob/main/contributing-docs/README.rst).

Here are some useful points:

- Pay attention to the quality of your code (ruff, mypy and type annotations).
  Our [prek-hooks](https://github.com/apache/airflow/blob/main/contributing-docs/08_static_code_checks.rst#prerequisites-for-prek-hooks)
  will help you with that.
- In case of a new feature, add useful documentation (in docstrings or in the
  `docs/` directory). Adding a new operator? Check this short
  [guide](https://github.com/apache/airflow/blob/main/airflow-core/docs/howto/custom-operator.rst).
  Consider adding an example DAG that shows how users should use it.
- Consider using the [Breeze environment](https://github.com/apache/airflow/blob/main/dev/breeze/doc/README.rst)
  for testing locally. It's a heavy Docker setup, but it ships with a working
  Airflow and a lot of integrations.
- Be patient and persistent. It might take some time to get a review or the
  final approval from Committers.
- Please follow the [ASF Code of Conduct](https://www.apache.org/foundation/policies/conduct)
  for all communication, including (but not limited to) comments on Pull
  Requests, the mailing list and Slack.
- Be sure to read the [Airflow coding style](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#coding-style-and-best-practices).
- Always keep your Pull Requests rebased, otherwise your build might fail due
  to changes not related to your commits.

Apache Airflow is a community-driven project and together we are making it
better 🚀

In case of doubts, contact the developers at:
Mailing List: [email protected]
Slack: https://s.apache.org/airflow-slack
