This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e8b4a532d27 Fix hardcoded waiter logic in EmrCreateJobFlowOperator
(#61195)
e8b4a532d27 is described below
commit e8b4a532d2757271c45e10675ec9462b44d9eecf
Author: Henry Chen <[email protected]>
AuthorDate: Wed Feb 4 02:23:02 2026 +0800
Fix hardcoded waiter logic in EmrCreateJobFlowOperator (#61195)
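The operator ignored the `wait_policy` argument and always used the
WAIT_FOR_COMPLETION (cluster) waiter.
This change stores `wait_policy` on the operator and uses it to select the
appropriate waiter (e.g., for step completion), including passing the chosen
waiter name to the deferrable trigger. `wait_policy` is no longer deprecated:
it may be combined with `wait_for_completion`, and setting it implies
`wait_for_completion=True`.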
---
.../airflow/providers/amazon/aws/operators/emr.py | 58 ++++++++++++++--------
.../airflow/providers/amazon/aws/triggers/emr.py | 3 +-
.../aws/operators/test_emr_create_job_flow.py | 36 +++++++++++---
3 files changed, 66 insertions(+), 31 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 48e8773c16a..1dc3d54f4d3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -24,7 +24,6 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Any
from uuid import uuid4
-from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook,
EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
EmrClusterLink,
@@ -657,7 +656,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
:param wait_for_completion: Whether to finish task immediately after
creation (False) or wait for jobflow
completion (True)
(default: None)
- :param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether
to finish the task immediately after creation (None) or:
+ :param wait_policy: Whether to finish the task immediately after creation (None) or which waiter to use (setting it implies wait_for_completion=True):
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
- wait for the jobflow completion and cluster to terminate
(WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
(default: None)
@@ -697,29 +696,35 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
super().__init__(**kwargs)
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
- self.wait_for_completion = wait_for_completion
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable
-
- if wait_policy is not None:
- warnings.warn(
- "`wait_policy` parameter is deprecated and will be removed in
a future release; "
- "please use `wait_for_completion` (bool) instead.",
- AirflowProviderDeprecationWarning,
- stacklevel=2,
- )
-
- if wait_for_completion is not None:
- raise ValueError(
- "Cannot specify both `wait_for_completion` and deprecated
`wait_policy`. "
- "Please use `wait_for_completion` (bool)."
+ self.wait_policy = wait_policy
+
+ # Backwards-compatible default: if the user requested waiting for
+ # completion (wait_for_completion=True) but did not provide an
+ # explicit wait_policy, default the wait_policy to
+ # WaitPolicy.WAIT_FOR_COMPLETION
+ if self.wait_policy is None and wait_for_completion:
+ self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
+
+ # Reconcile wait_policy with wait_for_completion. A set wait_policy means
+ # some form of waiting was requested, so wait_for_completion is forced to
+ # True (with a warning if the caller explicitly passed False); otherwise
+ # wait_for_completion is used as given, defaulting to False when None.
+ if self.wait_policy is not None:
+ if wait_for_completion is False:
+ warnings.warn(
+ "Setting wait_policy while wait_for_completion is False is
deprecated. "
+ "In future, you must set wait_for_completion=True to
wait.",
+ UserWarning,
+ stacklevel=2,
)
-
- self.wait_for_completion = wait_policy in (
- WaitPolicy.WAIT_FOR_COMPLETION,
- WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
- )
+ self.wait_for_completion = True
+ elif wait_for_completion is not None:
+ self.wait_for_completion = wait_for_completion
+ else:
+ self.wait_for_completion = False
@property
def _hook_parameters(self):
@@ -758,15 +763,24 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
log_uri=get_log_uri(emr_client=self.hook.conn,
job_flow_id=self._job_flow_id),
)
if self.wait_for_completion:
- waiter_name =
WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
+ # Select the waiter from wait_policy. __init__ normally sets wait_policy
+ # whenever wait_for_completion is True, so the else branch is a fallback.
+ wp = self.wait_policy
+ if wp is not None:
+ waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
+ else:
+ waiter_name =
WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
if self.deferrable:
+ # Pass the selected waiter_name so deferrable mode waits per the requested
+ # policy. NOTE(review): it only survives deferral if the trigger serializes it.
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
+ waiter_name=waiter_name,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the
timeout does not restart
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
index 3e24fc2b6d1..356a7354d78 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
@@ -82,10 +82,12 @@ class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
aws_conn_id: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
+ waiter_name: str = "job_flow_waiting",
):
super().__init__(
- serialized_fields={"job_flow_id": job_flow_id},
- waiter_name="job_flow_waiting",
+ # waiter_name must be serialized so the triggerer rebuilds the trigger with the same waiter.
+ serialized_fields={"job_flow_id": job_flow_id, "waiter_name": waiter_name},
+ waiter_name=waiter_name,
waiter_args={"ClusterId": job_flow_id},
failure_message="JobFlow creation failed",
status_message="JobFlow creation in progress",
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index e49f7ba7758..8389979ec95 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -26,7 +26,6 @@ import pytest
from botocore.waiter import Waiter
from jinja2 import StrictUndefined
-from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
@@ -254,10 +253,31 @@ class TestEmrCreateJobFlowOperator:
def test_template_fields(self):
validate_template_fields(self.operator)
- def test_wait_policy_deprecation_warning(self):
- """Test that using wait_policy raises a deprecation warning."""
- with pytest.warns(AirflowProviderDeprecationWarning,
match="`wait_policy` parameter is deprecated"):
- EmrCreateJobFlowOperator(
- task_id=TASK_ID,
- wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
- )
+ def test_wait_policy_behavior(self):
+ """wait_for_completion=True without wait_policy defaults wait_policy to WAIT_FOR_COMPLETION."""
+ op = EmrCreateJobFlowOperator(
+ task_id=TASK_ID,
+ wait_for_completion=True,
+ )
+ # wait_policy should be the default WAIT_FOR_COMPLETION
+ assert op.wait_policy == WaitPolicy.WAIT_FOR_COMPLETION
+ assert op.wait_for_completion is True
+
+ def test_specify_both_wait_for_completion_and_wait_policy(self):
+ """An explicit wait_policy is kept when wait_for_completion=True is also passed."""
+ op = EmrCreateJobFlowOperator(
+ task_id=TASK_ID,
+ wait_for_completion=True,
+ wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
+ )
+ assert op.wait_policy == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
+ assert op.wait_for_completion is True
+
+ def test_specify_only_wait_policy(self):
+ """Passing only wait_policy keeps it and implies wait_for_completion=True."""
+ op = EmrCreateJobFlowOperator(
+ task_id=TASK_ID,
+ wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
+ )
+ assert op.wait_policy == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
+ assert op.wait_for_completion is True