This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e8b4a532d27 Fix hardcoded waiter logic in EmrCreateJobFlowOperator (#61195)
e8b4a532d27 is described below

commit e8b4a532d2757271c45e10675ec9462b44d9eecf
Author: Henry Chen <[email protected]>
AuthorDate: Wed Feb 4 02:23:02 2026 +0800

    Fix hardcoded waiter logic in EmrCreateJobFlowOperator (#61195)
    
    The operator was incorrectly ignoring the `wait_policy` argument and always defaulting to waiting for cluster completion.
    
    This change ensures the `wait_policy` is correctly persisted and used to select the appropriate waiter (e.g., for step completion), fixing the hardcoded behavior.
---
 .../airflow/providers/amazon/aws/operators/emr.py  | 58 ++++++++++++++--------
 .../airflow/providers/amazon/aws/triggers/emr.py   |  3 +-
 .../aws/operators/test_emr_create_job_flow.py      | 36 +++++++++++---
 3 files changed, 66 insertions(+), 31 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 48e8773c16a..1dc3d54f4d3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -24,7 +24,6 @@ from datetime import timedelta
 from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import (
     EmrClusterLink,
@@ -657,7 +656,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
     :param wait_for_completion: Whether to finish task immediately after 
creation (False) or wait for jobflow
         completion (True)
         (default: None)
-    :param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether 
to finish the task immediately after creation (None) or:
+    :param wait_policy: Whether to finish the task immediately after creation 
(None) or:
         - wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
         - wait for the jobflow completion and cluster to terminate 
(WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
         (default: None)
@@ -697,29 +696,35 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
         super().__init__(**kwargs)
         self.emr_conn_id = emr_conn_id
         self.job_flow_overrides = job_flow_overrides or {}
-        self.wait_for_completion = wait_for_completion
         self.waiter_max_attempts = waiter_max_attempts or 60
         self.waiter_delay = waiter_delay or 60
         self.deferrable = deferrable
-
-        if wait_policy is not None:
-            warnings.warn(
-                "`wait_policy` parameter is deprecated and will be removed in 
a future release; "
-                "please use `wait_for_completion` (bool) instead.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-
-            if wait_for_completion is not None:
-                raise ValueError(
-                    "Cannot specify both `wait_for_completion` and deprecated 
`wait_policy`. "
-                    "Please use `wait_for_completion` (bool)."
+        self.wait_policy = wait_policy
+
+        # Backwards-compatible default: if the user requested waiting for
+        # completion (wait_for_completion=True) but did not provide an
+        # explicit wait_policy, default the wait_policy to
+        # WaitPolicy.WAIT_FOR_COMPLETION
+        if self.wait_policy is None and wait_for_completion:
+            self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
+
+        # Handle deprecated wait_for_completion parameter. If wait_policy is set,
+        # we always override wait_for_completion to True (since some form of waiting is
+        # requested). If wait_policy is not set, we use the value of wait_for_completion
+        # (defaulting to False if not provided).
+        if self.wait_policy is not None:
+            if wait_for_completion is False:
+                warnings.warn(
+                    "Setting wait_policy while wait_for_completion is False is 
deprecated. "
+                    "In future, you must set wait_for_completion=True to 
wait.",
+                    UserWarning,
+                    stacklevel=2,
                 )
-
-            self.wait_for_completion = wait_policy in (
-                WaitPolicy.WAIT_FOR_COMPLETION,
-                WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
-            )
+            self.wait_for_completion = True
+        elif wait_for_completion is not None:
+            self.wait_for_completion = wait_for_completion
+        else:
+            self.wait_for_completion = False
 
     @property
     def _hook_parameters(self):
@@ -758,15 +763,24 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
                 log_uri=get_log_uri(emr_client=self.hook.conn, 
job_flow_id=self._job_flow_id),
             )
         if self.wait_for_completion:
-            waiter_name = 
WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
+            # Determine which waiter to use. Prefer explicit wait_policy when provided,
+            # otherwise default to WAIT_FOR_COMPLETION.
+            wp = self.wait_policy
+            if wp is not None:
+                waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
+            else:
+                waiter_name = 
WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
 
             if self.deferrable:
+                # Pass the selected waiter_name to the trigger so deferrable mode waits
+                # according to the requested policy as well.
                 self.defer(
                     trigger=EmrCreateJobFlowTrigger(
                         job_flow_id=self._job_flow_id,
                         aws_conn_id=self.aws_conn_id,
                         waiter_delay=self.waiter_delay,
                         waiter_max_attempts=self.waiter_max_attempts,
+                        waiter_name=waiter_name,
                     ),
                     method_name="execute_complete",
                     # timeout is set to ensure that if a trigger dies, the 
timeout does not restart
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
index 3e24fc2b6d1..356a7354d78 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
@@ -82,10 +82,11 @@ class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
         aws_conn_id: str | None = None,
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
+        waiter_name: str = "job_flow_waiting",
     ):
         super().__init__(
             serialized_fields={"job_flow_id": job_flow_id},
-            waiter_name="job_flow_waiting",
+            waiter_name=waiter_name,
             waiter_args={"ClusterId": job_flow_id},
             failure_message="JobFlow creation failed",
             status_message="JobFlow creation in progress",
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index e49f7ba7758..8389979ec95 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -26,7 +26,6 @@ import pytest
 from botocore.waiter import Waiter
 from jinja2 import StrictUndefined
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
 from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
@@ -254,10 +253,31 @@ class TestEmrCreateJobFlowOperator:
     def test_template_fields(self):
         validate_template_fields(self.operator)
 
-    def test_wait_policy_deprecation_warning(self):
-        """Test that using wait_policy raises a deprecation warning."""
-        with pytest.warns(AirflowProviderDeprecationWarning, 
match="`wait_policy` parameter is deprecated"):
-            EmrCreateJobFlowOperator(
-                task_id=TASK_ID,
-                wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
-            )
+    def test_wait_policy_behavior(self):
+        """Test that using wait_for_completion but not pass wait_policy."""
+        op = EmrCreateJobFlowOperator(
+            task_id=TASK_ID,
+            wait_for_completion=True,
+        )
+        # wait_policy should be the default WAIT_FOR_COMPLETION
+        assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_COMPLETION
+        assert op.wait_for_completion is True
+
+    def test_specify_both_wait_for_completion_and_wait_policy(self):
+        """Passing both wait_for_completion and wait_policy."""
+        op = EmrCreateJobFlowOperator(
+            task_id=TASK_ID,
+            wait_for_completion=True,
+            wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
+        )
+        assert getattr(op, "wait_policy") == 
WaitPolicy.WAIT_FOR_STEPS_COMPLETION
+        assert op.wait_for_completion is True
+
+    def test_specify_only_wait_policy(self):
+        """Passing only wait_policy."""
+        op = EmrCreateJobFlowOperator(
+            task_id=TASK_ID,
+            wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
+        )
+        assert getattr(op, "wait_policy") == 
WaitPolicy.WAIT_FOR_STEPS_COMPLETION
+        assert op.wait_for_completion is True

Reply via email to