This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 541f600c5e2 Fix `EksPodOperator` in deferrable mode (#51255)
541f600c5e2 is described below

commit 541f600c5e2467848e38ca0164ea5b1aa6ad99fb
Author: Vincent <[email protected]>
AuthorDate: Tue Jun 3 09:30:57 2025 -0400

    Fix `EksPodOperator` in deferrable mode (#51255)
---
 .../airflow/providers/amazon/aws/operators/eks.py  | 13 ++++++++++++
 .../tests/unit/amazon/aws/operators/test_eks.py    | 24 ++++++++++++++++++++++
 .../providers/cncf/kubernetes/operators/pod.py     |  5 +++++
 .../providers/cncf/kubernetes/triggers/pod.py      | 11 ++++++++++
 .../unit/cncf/kubernetes/triggers/test_pod.py      |  1 +
 5 files changed, 54 insertions(+)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
index a6c42d8a83b..3c03aa3f501 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
@@ -1056,6 +1056,7 @@ class EksPodOperator(KubernetesPodOperator):
             in_cluster=self.in_cluster,
             namespace=self.namespace,
             name=self.pod_name,
+            trigger_kwargs={"eks_cluster_name": cluster_name},
             **kwargs,
         )
         # There is no need to manage the kube_config file, as it will be 
generated automatically.
@@ -1072,3 +1073,15 @@ class EksPodOperator(KubernetesPodOperator):
             eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
         ) as self.config_file:
             return super().execute(context)
+
+    def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
+        eks_hook = EksHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region,
+        )
+        eks_cluster_name = event["eks_cluster_name"]
+        pod_namespace = event["namespace"]
+        with eks_hook.generate_config_file(
+            eks_cluster_name=eks_cluster_name, pod_namespace=pod_namespace
+        ) as self.config_file:
+            return super().trigger_reentry(context, event)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
index 826ed9b9971..66cda6292cc 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
@@ -823,3 +823,27 @@ class TestEksPodOperator:
         )
 
         validate_template_fields(op)
+
+    
@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file")
+    def test_trigger_reentry(self, mock_generate_config_file, 
mock_k8s_pod_operator_trigger_reentry):
+        ti_context = mock.MagicMock(name="ti_context")
+        event = {"eks_cluster_name": "eks_cluster_name", "namespace": 
"namespace"}
+
+        op = EksPodOperator(
+            task_id="run_pod",
+            pod_name="run_pod",
+            cluster_name=CLUSTER_NAME,
+            image="amazon/aws-cli:latest",
+            cmds=["sh", "-c", "ls"],
+            labels={"demo": "hello_world"},
+            get_logs=True,
+            # Delete the pod when it reaches its final state, or the execution 
is interrupted.
+            on_finish_action="delete_pod",
+        )
+        op.trigger_reentry(ti_context, event)
+        
mock_k8s_pod_operator_trigger_reentry.assert_called_once_with(ti_context, event)
+        mock_generate_config_file.assert_called_once_with(
+            eks_cluster_name="eks_cluster_name", pod_namespace="namespace"
+        )
+        assert mock_generate_config_file.return_value.__enter__.return_value 
== op.config_file
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
index 8706e444c4b..a8ea2fcd4f1 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -233,6 +233,7 @@ class KubernetesPodOperator(BaseOperator):
     :param logging_interval: max time in seconds that task should be in 
deferred state before
         resuming to fetch the latest logs. If ``None``, then the task will 
remain in deferred state until pod
         is done, and no logs will be visible until that time.
+    :param trigger_kwargs: additional keyword parameters passed to the trigger
     """
 
     # !!! Changes in KubernetesPodOperator's arguments should be also 
reflected in !!!
@@ -266,6 +267,7 @@ class KubernetesPodOperator(BaseOperator):
         "node_selector",
         "kubernetes_conn_id",
         "base_container_name",
+        "trigger_kwargs",
     )
     template_fields_renderers = {"env_vars": "py"}
 
@@ -339,6 +341,7 @@ class KubernetesPodOperator(BaseOperator):
         ) = None,
         progress_callback: Callable[[str], None] | None = None,
         logging_interval: int | None = None,
+        trigger_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -428,6 +431,7 @@ class KubernetesPodOperator(BaseOperator):
         self.termination_message_policy = termination_message_policy
         self.active_deadline_seconds = active_deadline_seconds
         self.logging_interval = logging_interval
+        self.trigger_kwargs = trigger_kwargs
 
         self._config_dict: dict | None = None  # TODO: remove it when removing 
convert_config_file_to_dict
         self._progress_callback = progress_callback
@@ -812,6 +816,7 @@ class KubernetesPodOperator(BaseOperator):
                 on_finish_action=self.on_finish_action.value,
                 last_log_time=last_log_time,
                 logging_interval=self.logging_interval,
+                trigger_kwargs=self.trigger_kwargs,
             ),
             method_name="trigger_reentry",
         )
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
index 6ec0c99932b..719e1f6f344 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -75,6 +75,7 @@ class KubernetesPodTrigger(BaseTrigger):
     :param logging_interval: number of seconds to wait before kicking it back 
to
         the operator to print latest logs. If ``None`` will wait until 
container done.
     :param last_log_time: where to resume logs from
+    :param trigger_kwargs: additional keyword parameters to send in the event
     """
 
     def __init__(
@@ -94,6 +95,7 @@ class KubernetesPodTrigger(BaseTrigger):
         on_finish_action: str = "delete_pod",
         last_log_time: DateTime | None = None,
         logging_interval: int | None = None,
+        trigger_kwargs: dict | None = None,
     ):
         super().__init__()
         self.pod_name = pod_name
@@ -111,6 +113,7 @@ class KubernetesPodTrigger(BaseTrigger):
         self.last_log_time = last_log_time
         self.logging_interval = logging_interval
         self.on_finish_action = OnFinishAction(on_finish_action)
+        self.trigger_kwargs = trigger_kwargs or {}
 
         self._since_time = None
 
@@ -134,6 +137,7 @@ class KubernetesPodTrigger(BaseTrigger):
                 "on_finish_action": self.on_finish_action.value,
                 "last_log_time": self.last_log_time,
                 "logging_interval": self.logging_interval,
+                "trigger_kwargs": self.trigger_kwargs,
             },
         )
 
@@ -149,6 +153,7 @@ class KubernetesPodTrigger(BaseTrigger):
                         "namespace": self.pod_namespace,
                         "name": self.pod_name,
                         "message": "All containers inside pod have started 
successfully.",
+                        **self.trigger_kwargs,
                     }
                 )
             elif state == ContainerState.FAILED:
@@ -158,6 +163,7 @@ class KubernetesPodTrigger(BaseTrigger):
                         "namespace": self.pod_namespace,
                         "name": self.pod_name,
                         "message": "pod failed",
+                        **self.trigger_kwargs,
                     }
                 )
             else:
@@ -172,6 +178,7 @@ class KubernetesPodTrigger(BaseTrigger):
                     "namespace": self.pod_namespace,
                     "status": "timeout",
                     "message": message,
+                    **self.trigger_kwargs,
                 }
             )
             return
@@ -183,6 +190,7 @@ class KubernetesPodTrigger(BaseTrigger):
                     "status": "error",
                     "message": str(e),
                     "stack_trace": traceback.format_exc(),
+                    **self.trigger_kwargs,
                 }
             )
             return
@@ -234,6 +242,7 @@ class KubernetesPodTrigger(BaseTrigger):
                         "namespace": self.pod_namespace,
                         "name": self.pod_name,
                         "last_log_time": self.last_log_time,
+                        **self.trigger_kwargs,
                     }
                 )
             if container_state == ContainerState.FAILED:
@@ -244,6 +253,7 @@ class KubernetesPodTrigger(BaseTrigger):
                         "name": self.pod_name,
                         "message": "Container state failed",
                         "last_log_time": self.last_log_time,
+                        **self.trigger_kwargs,
                     }
                 )
             self.log.debug("Container is not completed and still working.")
@@ -254,6 +264,7 @@ class KubernetesPodTrigger(BaseTrigger):
                         "last_log_time": self.last_log_time,
                         "namespace": self.pod_namespace,
                         "name": self.pod_name,
+                        **self.trigger_kwargs,
                     }
                 )
             self.log.debug("Sleeping for %s seconds.", self.poll_interval)
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
index 2afb34aa3df..66fae2524d6 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
@@ -111,6 +111,7 @@ class TestKubernetesPodTrigger:
             "on_finish_action": ON_FINISH_ACTION,
             "last_log_time": None,
             "logging_interval": None,
+            "trigger_kwargs": {},
         }
 
     @pytest.mark.asyncio

Reply via email to