This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6dc53524a5 Create a generic callbacks class for KubernetesPodOperator 
(#35714)
6dc53524a5 is described below

commit 6dc53524a50d38c9b74e83562cab6f2c63818de5
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Jan 20 18:23:28 2024 +0100

    Create a generic callbacks class for KubernetesPodOperator (#35714)
    
    * Create a generic callbacks class for KubernetesPodOperator
    
    * Trigger tests with old-style union
    
    * Fix GCP K8S test
    
    * Fix callback param type and cleanup calls
    
    * Some fixes and add unit tests
    
    * Replace type by Type
    
    * Reset mock_callbacks in pod manager tests
    
    * Fix static checks
    
    * Add a doc paragraph for the new callbacks
    
    * Add a check for cncf-kubernetes version in google provider
    
    * Switch to None default value and bump min cncf-k8s provider in google 
provider
    
    * Fix tests
    
    * Reduce check intervals to avoid killing the asyncio task
    
    * Revert async callbacks
    
    * fix breeze tests
---
 airflow/providers/cncf/kubernetes/callbacks.py     | 111 +++++++++++++++++++
 airflow/providers/cncf/kubernetes/operators/pod.py |  48 +++++++-
 .../providers/cncf/kubernetes/utils/pod_manager.py |  22 +++-
 .../operators.rst                                  |  75 +++++++++++++
 .../cncf/kubernetes/operators/test_pod.py          | 121 ++++++++++++++++++++-
 tests/providers/cncf/kubernetes/test_callbacks.py  |  65 +++++++++++
 .../cncf/kubernetes/utils/test_pod_manager.py      |  31 +++++-
 7 files changed, 462 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/callbacks.py 
b/airflow/providers/cncf/kubernetes/callbacks.py
new file mode 100644
index 0000000000..4baef440de
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/callbacks.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from typing import Union
+
+import kubernetes.client as k8s
+import kubernetes_asyncio.client as async_k8s
+
+client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]  # sync or async CoreV1Api handed to the callbacks
+
+
+class ExecutionMode(str, Enum):
+    """Execution mode a callback is invoked in: ``sync`` or ``async``."""
+
+    SYNC = "sync"
+    ASYNC = "async"
+
+
+class KubernetesPodOperatorCallback:
+    """`KubernetesPodOperator` callback methods.
+
+    Currently, the callback methods are not called in async mode; support for it will be added
+    in the future.
+    """
+
+    @staticmethod
+    def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
+        """Callback method called after creating the sync client.
+
+        :param client: the created `kubernetes.client.CoreV1Api` client.
+        """
+        pass
+
+    @staticmethod
+    def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
+        """Callback method called after creating the pod.
+
+        :param pod: the created pod.
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
+
+    @staticmethod
+    def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
+        """Callback method called when the pod starts.
+
+        :param pod: the started pod.
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
+
+    @staticmethod
+    def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
+        """Callback method called when the pod completes.
+
+        :param pod: the completed pod.
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
+
+    @staticmethod
+    def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
+        """Callback method called after cleaning/deleting the pod.
+
+        :param pod: the cleaned-up pod (may be the pod request object if no pod was created).
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
+
+    @staticmethod
+    def on_operator_resuming(
+        *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs
+    ) -> None:
+        """Callback method called when resuming the `KubernetesPodOperator` from deferred state.
+
+        :param pod: the current state of the pod.
+        :param event: the returned event from the Trigger.
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
+
+    @staticmethod
+    def progress_callback(*, line: str, client: client_type, mode: str, **kwargs) -> None:
+        """Callback method to process pod container logs.
+
+        :param line: the read line of log.
+        :param client: the Kubernetes client that can be used in the callback.
+        :param mode: the current execution mode, it's one of (`sync`, `async`).
+        """
+        pass
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py 
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 70f8bc2252..0a06ea6ec6 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -49,6 +49,7 @@ from 
airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters im
     convert_volume,
     convert_volume_mount,
 )
+from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, 
KubernetesPodOperatorCallback
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
     POD_NAME_MAX_LENGTH,
@@ -198,7 +199,10 @@ class KubernetesPodOperator(BaseOperator):
         Default value is "File"
     :param active_deadline_seconds: The active_deadline_seconds which 
translates to active_deadline_seconds
         in V1PodSpec.
+    :param callbacks: KubernetesPodOperatorCallback instance contains the 
callbacks methods on different step
+        of KubernetesPodOperator.
     :param progress_callback: Callback function for receiving k8s container 
logs.
+        `progress_callback` is deprecated, please use :param `callbacks` 
instead.
     """
 
     # !!! Changes in KubernetesPodOperator's arguments should be also 
reflected in !!!
@@ -290,6 +294,7 @@ class KubernetesPodOperator(BaseOperator):
         is_delete_operator_pod: None | bool = None,
         termination_message_policy: str = "File",
         active_deadline_seconds: int | None = None,
+        callbacks: type[KubernetesPodOperatorCallback] | None = None,
         progress_callback: Callable[[str], None] | None = None,
         **kwargs,
     ) -> None:
@@ -381,6 +386,7 @@ class KubernetesPodOperator(BaseOperator):
 
         self._config_dict: dict | None = None  # TODO: remove it when removing 
convert_config_file_to_dict
         self._progress_callback = progress_callback
+        self.callbacks = callbacks
         self._killed: bool = False
 
     @cached_property
@@ -459,7 +465,9 @@ class KubernetesPodOperator(BaseOperator):
 
     @cached_property
     def pod_manager(self) -> PodManager:
-        return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
+        return PodManager(
+            kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback
+        )
 
     @cached_property
     def hook(self) -> PodOperatorHookProtocol:
@@ -473,7 +481,10 @@ class KubernetesPodOperator(BaseOperator):
 
     @cached_property
     def client(self) -> CoreV1Api:
-        return self.hook.core_v1_client
+        client = self.hook.core_v1_client
+        if self.callbacks:
+            self.callbacks.on_sync_client_creation(client=client)  # once per operator: property is cached
+        return client
 
     def find_pod(self, namespace: str, context: Context, *, exclude_checked: 
bool = True) -> k8s.V1Pod | None:
         """Return an already-running pod for this task instance if one 
exists."""
@@ -552,7 +563,17 @@ class KubernetesPodOperator(BaseOperator):
 
             # get remote pod for use in cleanup methods
             self.remote_pod = self.find_pod(self.pod.metadata.namespace, 
context=context)
+            if self.callbacks:
+                self.callbacks.on_pod_creation(
+                    pod=self.remote_pod, client=self.client, 
mode=ExecutionMode.SYNC
+                )
             self.await_pod_start(pod=self.pod)
+            if self.callbacks:
+                self.callbacks.on_pod_starting(
+                    pod=self.find_pod(self.pod.metadata.namespace, 
context=context),
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                )
 
             if self.get_logs:
                 self.pod_manager.fetch_requested_container_logs(
@@ -566,6 +587,12 @@ class KubernetesPodOperator(BaseOperator):
                 self.pod_manager.await_container_completion(
                     pod=self.pod, container_name=self.base_container_name
                 )
+            if self.callbacks:
+                self.callbacks.on_pod_completion(
+                    pod=self.find_pod(self.pod.metadata.namespace, 
context=context),
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                )
 
             if self.do_xcom_push:
                 
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
@@ -575,10 +602,13 @@ class KubernetesPodOperator(BaseOperator):
                 self.pod, istio_enabled, self.base_container_name
             )
         finally:
+            pod_to_clean = self.pod or self.pod_request_obj
             self.cleanup(
-                pod=self.pod or self.pod_request_obj,
+                pod=pod_to_clean,
                 remote_pod=self.remote_pod,
             )
+            if self.callbacks:
+                self.callbacks.on_pod_cleanup(pod=pod_to_clean, 
client=self.client, mode=ExecutionMode.SYNC)
 
         if self.do_xcom_push:
             return result
@@ -589,6 +619,12 @@ class KubernetesPodOperator(BaseOperator):
             pod_request_obj=self.pod_request_obj,
             context=context,
         )
+        if self.callbacks:
+            self.callbacks.on_pod_creation(
+                pod=self.find_pod(self.pod.metadata.namespace, 
context=context),
+                client=self.client,
+                mode=ExecutionMode.SYNC,
+            )
         ti = context["ti"]
         ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
         ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
@@ -625,6 +661,10 @@ class KubernetesPodOperator(BaseOperator):
                 event["name"],
                 event["namespace"],
             )
+            if self.callbacks:
+                self.callbacks.on_operator_resuming(
+                    pod=pod, event=event, client=self.client, 
mode=ExecutionMode.SYNC
+                )
             if event["status"] in ("error", "failed", "timeout"):
                 # fetch some logs when pod is failed
                 if self.get_logs:
@@ -677,6 +717,8 @@ class KubernetesPodOperator(BaseOperator):
             pod=pod,
             remote_pod=remote_pod,
         )
+        if self.callbacks:
+            self.callbacks.on_pod_cleanup(pod=pod, client=self.client, 
mode=ExecutionMode.SYNC)
 
     def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
         # If a task got marked as failed, "on_kill" method would be called and 
the pod will be cleaned up
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py 
b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index e2d0efac83..0e736daa6a 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -40,6 +40,7 @@ from typing_extensions import Literal
 from urllib3.exceptions import HTTPError, TimeoutError
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
+from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, 
KubernetesPodOperatorCallback
 from airflow.providers.cncf.kubernetes.pod_generator import PodDefaults
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timezone import utcnow
@@ -50,6 +51,7 @@ if TYPE_CHECKING:
     from kubernetes.client.models.v1_pod import V1Pod
     from urllib3.response import HTTPResponse
 
+
 EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__"
 """
 Sentinel for no xcom result.
@@ -287,18 +289,22 @@ class PodManager(LoggingMixin):
     def __init__(
         self,
         kube_client: client.CoreV1Api,
+        callbacks: type[KubernetesPodOperatorCallback] | None = None,
         progress_callback: Callable[[str], None] | None = None,
     ):
         """
         Create the launcher.

         :param kube_client: kubernetes client
+        :param callbacks: KubernetesPodOperatorCallback subclass; its progress_callback gets each log line.
         :param progress_callback: Callback function invoked when fetching container log.
+            This parameter is deprecated, please use ``callbacks`` instead.
         """
         super().__init__()
         self._client = kube_client
         self._progress_callback = progress_callback
         self._watch = watch.Watch()
+        self._callbacks = callbacks
 
     def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
         """Run POD asynchronously."""
@@ -441,9 +447,13 @@ class PodManager(LoggingMixin):
                                 message_timestamp = line_timestamp
                                 progress_callback_lines.append(line)
                             else:  # previous log line is complete
-                                if self._progress_callback:
-                                    for line in progress_callback_lines:
+                                for line in progress_callback_lines:
+                                    if self._progress_callback:
                                         self._progress_callback(line)
+                                    if self._callbacks:
+                                        self._callbacks.progress_callback(
+                                            line=line, client=self._client, 
mode=ExecutionMode.SYNC
+                                        )
                                 self.log.info("[%s] %s", container_name, 
message_to_log)
                                 last_captured_timestamp = message_timestamp
                                 message_to_log = message
@@ -454,9 +464,13 @@ class PodManager(LoggingMixin):
                             progress_callback_lines.append(line)
                 finally:
                     # log the last line and update the last_captured_timestamp
-                    if self._progress_callback:
-                        for line in progress_callback_lines:
+                    for line in progress_callback_lines:
+                        if self._progress_callback:
                             self._progress_callback(line)
+                        if self._callbacks:
+                            self._callbacks.progress_callback(
+                                line=line, client=self._client, 
mode=ExecutionMode.SYNC
+                            )
                     self.log.info("[%s] %s", container_name, message_to_log)
                     last_captured_timestamp = message_timestamp
             except TimeoutError as e:
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst 
b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
index 690a857ea0..f24fd61602 100644
--- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
+++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
@@ -195,6 +195,81 @@ included in the exception message if the task fails.
 
 Read more on termination-log `here 
<https://kubernetes.io/docs/tasks/debug/debug-application/determine-reason-pod-failure/>`__.
 
+KubernetesPodOperator callbacks
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator` supports different
+callbacks that can be used to trigger actions during the lifecycle of the pod. In order to use them, you need to
+create a subclass of :class:`~airflow.providers.cncf.kubernetes.callbacks.KubernetesPodOperatorCallback` and override
+the callback methods you want to use. Then pass your callback class (not an instance) to the operator using the
+``callbacks`` parameter.
+
+The following callbacks are supported:
+
+* on_sync_client_creation: called after creating the sync client
+* on_pod_creation: called after creating the pod
+* on_pod_starting: called after the pod starts
+* on_pod_completion: called when the pod completes
+* on_pod_cleanup: called after cleaning/deleting the pod
+* on_operator_resuming: called when resuming the task from deferred state
+* progress_callback: called on each line of container logs
+
+Currently, the callback methods are not called in async mode; this support will be added in the future.
+
+Example:
+~~~~~~~~
+.. code-block:: python
+
+    import kubernetes.client as k8s
+
+    from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback
+    from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
+
+
+    class MyCallback(KubernetesPodOperatorCallback):
+        @staticmethod
+        def on_pod_creation(*, pod: k8s.V1Pod, client: k8s.CoreV1Api, mode: str, **kwargs) -> None:
+            # The owner reference lets Kubernetes garbage-collect the Service together with the pod.
+            client.create_namespaced_service(
+                namespace=pod.metadata.namespace,
+                body=k8s.V1Service(
+                    metadata=k8s.V1ObjectMeta(
+                        name=pod.metadata.name,
+                        labels=pod.metadata.labels,
+                        owner_references=[
+                            k8s.V1OwnerReference(
+                                api_version="v1",
+                                kind="Pod",
+                                name=pod.metadata.name,
+                                uid=pod.metadata.uid,
+                                controller=True,
+                                block_owner_deletion=True,
+                            )
+                        ],
+                    ),
+                    spec=k8s.V1ServiceSpec(
+                        selector=pod.metadata.labels,
+                        ports=[
+                            k8s.V1ServicePort(
+                                name="http",
+                                port=80,
+                                target_port=80,
+                            )
+                        ],
+                    ),
+                ),
+            )
+
+
+    k = KubernetesPodOperator(
+        task_id="test_callback",
+        image="alpine",
+        cmds=["/bin/sh"],
+        arguments=["-c", "echo hello world"],
+        name="test-callback",
+        callbacks=MyCallback,
+    )
+
 Reference
 ^^^^^^^^^
 For further information, look at:
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py 
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 690dece953..a737e06cbe 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1446,6 +1446,119 @@ class TestKubernetesPodOperator:
         # check that we wait for the xcom sidecar to start before extracting 
XCom
         mock_await_xcom_sidecar.assert_called_once_with(pod=pod)
 
+    @patch(HOOK_CLASS, new=MagicMock)
+    @patch(KUB_OP_PATH.format("find_pod"))
+    def test_execute_sync_callbacks(self, find_pod_mock):
+        from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode
+
+        from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper
+
+        MockWrapper.reset()
+        mock_callbacks = MockWrapper.mock_callbacks
+        found_pods = [MagicMock(), MagicMock(), MagicMock()]
+        find_pod_mock.side_effect = [None] + found_pods
+
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = "Succeeded"
+        self.await_pod_mock.return_value = remote_pod_mock
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            do_xcom_push=False,
+            callbacks=MockKubernetesPodOperatorCallback,
+        )
+        self.run_pod(k)
+
+        # check on_sync_client_creation callback
+        mock_callbacks.on_sync_client_creation.assert_called_once()
+        assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client}
+
+        # check on_pod_creation callback
+        mock_callbacks.on_pod_creation.assert_called_once()
+        assert mock_callbacks.on_pod_creation.call_args.kwargs == {
+            "client": k.client,
+            "mode": ExecutionMode.SYNC,
+            "pod": found_pods[0],
+        }
+
+        # check on_pod_starting callback
+        mock_callbacks.on_pod_starting.assert_called_once()
+        assert mock_callbacks.on_pod_starting.call_args.kwargs == {
+            "client": k.client,
+            "mode": ExecutionMode.SYNC,
+            "pod": found_pods[1],
+        }
+
+        # check on_pod_completion callback
+        mock_callbacks.on_pod_completion.assert_called_once()
+        assert mock_callbacks.on_pod_completion.call_args.kwargs == {
+            "client": k.client,
+            "mode": ExecutionMode.SYNC,
+            "pod": found_pods[2],
+        }
+
+        # check on_pod_cleanup callback
+        mock_callbacks.on_pod_cleanup.assert_called_once()
+        assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
+            "client": k.client,
+            "mode": ExecutionMode.SYNC,
+            "pod": k.pod,
+        }
+
+    @patch(HOOK_CLASS, new=MagicMock)
+    def test_execute_async_callbacks(self):
+        from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode
+
+        from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper
+
+        MockWrapper.reset()
+        mock_callbacks = MockWrapper.mock_callbacks
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = "Succeeded"
+        self.await_pod_mock.return_value = remote_pod_mock
+
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            do_xcom_push=False,
+            callbacks=MockKubernetesPodOperatorCallback,
+        )
+        k.execute_complete(
+            context=create_context(k),
+            event={
+                "status": "success",
+                "message": TEST_SUCCESS_MESSAGE,
+                "name": TEST_NAME,
+                "namespace": TEST_NAMESPACE,
+            },
+        )
+
+        # check on_operator_resuming callback
+        mock_callbacks.on_operator_resuming.assert_called_once()
+        resuming_kwargs = mock_callbacks.on_operator_resuming.call_args.kwargs
+        assert resuming_kwargs.keys() == {"pod", "event", "client", "mode"}
+        assert resuming_kwargs["client"] == k.client
+        assert resuming_kwargs["mode"] == ExecutionMode.SYNC
+        assert resuming_kwargs["event"]["status"] == "success"
+
+        # check on_pod_cleanup callback
+        mock_callbacks.on_pod_cleanup.assert_called_once()
+        assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
+            "client": k.client,
+            "mode": ExecutionMode.SYNC,
+            "pod": remote_pod_mock,
+        }
+
 
 class TestSuppress:
     def test__suppress(self, caplog):
@@ -1554,9 +1667,13 @@ class TestKubernetesPodOperatorAsync:
         return remote_pod_mock
 
     @pytest.mark.parametrize("do_xcom_push", [True, False])
+    @patch(KUB_OP_PATH.format("client"))
+    @patch(KUB_OP_PATH.format("find_pod"))
     @patch(KUB_OP_PATH.format("build_pod_request_obj"))
     @patch(KUB_OP_PATH.format("get_or_create_pod"))
-    def test_async_create_pod_should_execute_successfully(self, mocked_pod, 
mocked_pod_obj, do_xcom_push):
+    def test_async_create_pod_should_execute_successfully(
+        self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, 
do_xcom_push
+    ):
         """
         Asserts that a task is deferred and the KubernetesCreatePodTrigger 
will be fired
         when the KubernetesPodOperator is executed in deferrable mode when 
deferrable=True.
@@ -1584,7 +1701,7 @@ class TestKubernetesPodOperatorAsync:
         mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE
 
         context = create_context(k)
-        ti_mock = MagicMock()
+        ti_mock = MagicMock(**{"map_index": -1})
         context["ti"] = ti_mock
 
         with pytest.raises(TaskDeferred) as exc:
diff --git a/tests/providers/cncf/kubernetes/test_callbacks.py 
b/tests/providers/cncf/kubernetes/test_callbacks.py
new file mode 100644
index 0000000000..2757b8296a
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/test_callbacks.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from airflow.providers.cncf.kubernetes.callbacks import 
KubernetesPodOperatorCallback
+
+
+class MockWrapper:
+    mock_callbacks = MagicMock()  # class-level, so shared by all tests; reset() clears recorded calls
+
+    @classmethod
+    def reset(cls):
+        cls.mock_callbacks.reset_mock()
+
+
+class MockKubernetesPodOperatorCallback(KubernetesPodOperatorCallback):
+    """Forward every callback to the same-named method of ``MockWrapper.mock_callbacks``."""
+
+    @staticmethod
+    def on_sync_client_creation(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_sync_client_creation(*args, **kwargs)
+
+    @staticmethod  # NOTE(review): no such hook on the base class; likely dead — confirm
+    def on_async_client_creation(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_async_client_creation(*args, **kwargs)
+
+    @staticmethod
+    def on_pod_creation(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_pod_creation(*args, **kwargs)
+
+    @staticmethod
+    def on_pod_starting(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_pod_starting(*args, **kwargs)
+
+    @staticmethod
+    def on_pod_completion(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_pod_completion(*args, **kwargs)
+
+    @staticmethod
+    def on_pod_cleanup(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_pod_cleanup(*args, **kwargs)
+
+    @staticmethod
+    def on_operator_resuming(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.on_operator_resuming(*args, **kwargs)
+
+    @staticmethod
+    def progress_callback(*args, **kwargs) -> None:
+        MockWrapper.mock_callbacks.progress_callback(*args, **kwargs)
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py 
b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index 13f8123550..fc09d6bb02 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -41,6 +41,8 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager 
import (
 )
 from airflow.utils.timezone import utc
 
+from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper
+
 if TYPE_CHECKING:
     from pendulum import DateTime
 
@@ -50,7 +52,9 @@ class TestPodManager:
         self.mock_progress_callback = mock.Mock()
         self.mock_kube_client = mock.Mock()
         self.pod_manager = PodManager(
-            kube_client=self.mock_kube_client, 
progress_callback=self.mock_progress_callback
+            kube_client=self.mock_kube_client,
+            callbacks=MockKubernetesPodOperatorCallback,
+            progress_callback=self.mock_progress_callback,
         )
 
     def test_read_pod_logs_successfully_returns_logs(self):
@@ -274,7 +278,7 @@ class TestPodManager:
 
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
-    def test_fetch_container_logs_invoke_progress_callback(
+    def test_fetch_container_logs_invoke_deprecated_progress_callback(
         self, mock_read_pod_logs, mock_container_is_running
     ):
         message = "2020-10-08T14:16:17.793417674Z message"
@@ -285,8 +289,30 @@ class TestPodManager:
         self.pod_manager.fetch_container_logs(mock.MagicMock(), 
mock.MagicMock(), follow=True)
         self.mock_progress_callback.assert_has_calls([mock.call(message), 
mock.call(no_ts_message)])
 
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
+    def test_fetch_container_logs_invoke_progress_callback(
+        self, mock_read_pod_logs, mock_container_is_running
+    ):
+        MockWrapper.reset()
+        mock_callbacks = MockWrapper.mock_callbacks
+        message = "2020-10-08T14:16:17.793417674Z message"
+        no_ts_message = "notimestamp"
+        mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")]
+        mock_container_is_running.return_value = False
+
+        self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
+        mock_callbacks.progress_callback.assert_has_calls(
+            [
+                mock.call(line=message, client=self.pod_manager._client, mode="sync"),
+                mock.call(line=no_ts_message, client=self.pod_manager._client, mode="sync"),
+            ]
+        )
+
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
     def test_fetch_container_logs_failures(self, mock_container_is_running):
+        MockWrapper.reset()
+        mock_callbacks = MockWrapper.mock_callbacks
         last_timestamp_string = "2020-10-08T14:18:17.793417674Z"
         messages = [
             bytes("2020-10-08T14:16:17.793417674Z message", "utf-8"),
@@ -309,6 +335,7 @@ class TestPodManager:
             status = self.pod_manager.fetch_container_logs(mock.MagicMock(), 
mock.MagicMock(), follow=True)
         assert status.last_log_time == cast("DateTime", 
pendulum.parse(last_timestamp_string))
         assert self.mock_progress_callback.call_count == expected_call_count
+        assert mock_callbacks.progress_callback.call_count == 
expected_call_count
 
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")

Reply via email to