This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6dc53524a5 Create a generic callbacks class for KubernetesPodOperator
(#35714)
6dc53524a5 is described below
commit 6dc53524a50d38c9b74e83562cab6f2c63818de5
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Jan 20 18:23:28 2024 +0100
Create a generic callbacks class for KubernetesPodOperator (#35714)
* Create a generic callbacks class for KubernetesPodOperator
* Trigger tests with old-style union
* Fix GCP K8S test
* Fix callback param type and cleanup calls
* Some fixes and add unit tests
* Replace type by Type
* Reset mock_callbacks in pod manager tests
* Fix static checks
* Add a doc paragraph for the new callbacks
* Add a check for cncf-kubernetes version in google provider
* Switch to None default value and bump min cncf-k8s provider in google
provider
* Fix tests
* Reduce check intervals to avoid killing the asyncio task
* Revert async callbacks
* fix breeze tests
---
airflow/providers/cncf/kubernetes/callbacks.py | 111 +++++++++++++++++++
airflow/providers/cncf/kubernetes/operators/pod.py | 48 +++++++-
.../providers/cncf/kubernetes/utils/pod_manager.py | 22 +++-
.../operators.rst | 75 +++++++++++++
.../cncf/kubernetes/operators/test_pod.py | 121 ++++++++++++++++++++-
tests/providers/cncf/kubernetes/test_callbacks.py | 65 +++++++++++
.../cncf/kubernetes/utils/test_pod_manager.py | 31 +++++-
7 files changed, 462 insertions(+), 11 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/callbacks.py
b/airflow/providers/cncf/kubernetes/callbacks.py
new file mode 100644
index 0000000000..4baef440de
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/callbacks.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from typing import Union
+
+import kubernetes.client as k8s
+import kubernetes_asyncio.client as async_k8s
+
+client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]
+
+
+class ExecutionMode(str, Enum):
+ """Enum class for execution mode."""
+
+ SYNC = "sync"
+ ASYNC = "async"
+
+
+class KubernetesPodOperatorCallback:
+ """`KubernetesPodOperator` callbacks methods.
+
+ Currently, the callbacks methods are not called in the async mode, this
support will be added
+ in the future.
+ """
+
+ @staticmethod
+ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
+ """Callback method called after creating the sync client.
+
+ :param client: the created `kubernetes.client.CoreV1Api` client.
+ """
+ pass
+
+ @staticmethod
+ def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str,
**kwargs) -> None:
+ """Callback method called after creating the pod.
+
+ :param pod: the created pod.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
+
+ @staticmethod
+ def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str,
**kwargs) -> None:
+ """Callback method called when the pod starts.
+
+ :param pod: the started pod.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
+
+ @staticmethod
+ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str,
**kwargs) -> None:
+ """Callback method called when the pod completes.
+
+ :param pod: the completed pod.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
+
+ @staticmethod
+ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str,
**kwargs):
+ """Callback method called after cleaning/deleting the pod.
+
+ :param pod: the completed pod.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
+
+ @staticmethod
+ def on_operator_resuming(
+ *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str,
**kwargs
+ ) -> None:
+ """Callback method called when resuming the `KubernetesPodOperator`
from deferred state.
+
+ :param pod: the current state of the pod.
+ :param event: the returned event from the Trigger.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
+
+ @staticmethod
+ def progress_callback(*, line: str, client: client_type, mode: str,
**kwargs) -> None:
+ """Callback method to process pod container logs.
+
+ :param line: the read line of log.
+ :param client: the Kubernetes client that can be used in the callback.
+ :param mode: the current execution mode, it's one of (`sync`, `async`).
+ """
+ pass
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 70f8bc2252..0a06ea6ec6 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -49,6 +49,7 @@ from
airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters im
convert_volume,
convert_volume_mount,
)
+from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode,
KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
POD_NAME_MAX_LENGTH,
@@ -198,7 +199,10 @@ class KubernetesPodOperator(BaseOperator):
Default value is "File"
:param active_deadline_seconds: The active_deadline_seconds which
translates to active_deadline_seconds
in V1PodSpec.
+ :param callbacks: KubernetesPodOperatorCallback instance contains the
callbacks methods on different step
+ of KubernetesPodOperator.
:param progress_callback: Callback function for receiving k8s container
logs.
+ `progress_callback` is deprecated, please use :param `callbacks`
instead.
"""
# !!! Changes in KubernetesPodOperator's arguments should be also
reflected in !!!
@@ -290,6 +294,7 @@ class KubernetesPodOperator(BaseOperator):
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
+ callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
**kwargs,
) -> None:
@@ -381,6 +386,7 @@ class KubernetesPodOperator(BaseOperator):
self._config_dict: dict | None = None # TODO: remove it when removing
convert_config_file_to_dict
self._progress_callback = progress_callback
+ self.callbacks = callbacks
self._killed: bool = False
@cached_property
@@ -459,7 +465,9 @@ class KubernetesPodOperator(BaseOperator):
@cached_property
def pod_manager(self) -> PodManager:
- return PodManager(kube_client=self.client,
progress_callback=self._progress_callback)
+ return PodManager(
+ kube_client=self.client, callbacks=self.callbacks,
progress_callback=self._progress_callback
+ )
@cached_property
def hook(self) -> PodOperatorHookProtocol:
@@ -473,7 +481,10 @@ class KubernetesPodOperator(BaseOperator):
@cached_property
def client(self) -> CoreV1Api:
- return self.hook.core_v1_client
+ client = self.hook.core_v1_client
+ if self.callbacks:
+ self.callbacks.on_sync_client_creation(client=client)
+ return client
def find_pod(self, namespace: str, context: Context, *, exclude_checked:
bool = True) -> k8s.V1Pod | None:
"""Return an already-running pod for this task instance if one
exists."""
@@ -552,7 +563,17 @@ class KubernetesPodOperator(BaseOperator):
# get remote pod for use in cleanup methods
self.remote_pod = self.find_pod(self.pod.metadata.namespace,
context=context)
+ if self.callbacks:
+ self.callbacks.on_pod_creation(
+ pod=self.remote_pod, client=self.client,
mode=ExecutionMode.SYNC
+ )
self.await_pod_start(pod=self.pod)
+ if self.callbacks:
+ self.callbacks.on_pod_starting(
+ pod=self.find_pod(self.pod.metadata.namespace,
context=context),
+ client=self.client,
+ mode=ExecutionMode.SYNC,
+ )
if self.get_logs:
self.pod_manager.fetch_requested_container_logs(
@@ -566,6 +587,12 @@ class KubernetesPodOperator(BaseOperator):
self.pod_manager.await_container_completion(
pod=self.pod, container_name=self.base_container_name
)
+ if self.callbacks:
+ self.callbacks.on_pod_completion(
+ pod=self.find_pod(self.pod.metadata.namespace,
context=context),
+ client=self.client,
+ mode=ExecutionMode.SYNC,
+ )
if self.do_xcom_push:
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
@@ -575,10 +602,13 @@ class KubernetesPodOperator(BaseOperator):
self.pod, istio_enabled, self.base_container_name
)
finally:
+ pod_to_clean = self.pod or self.pod_request_obj
self.cleanup(
- pod=self.pod or self.pod_request_obj,
+ pod=pod_to_clean,
remote_pod=self.remote_pod,
)
+ if self.callbacks:
+ self.callbacks.on_pod_cleanup(pod=pod_to_clean,
client=self.client, mode=ExecutionMode.SYNC)
if self.do_xcom_push:
return result
@@ -589,6 +619,12 @@ class KubernetesPodOperator(BaseOperator):
pod_request_obj=self.pod_request_obj,
context=context,
)
+ if self.callbacks:
+ self.callbacks.on_pod_creation(
+ pod=self.find_pod(self.pod.metadata.namespace,
context=context),
+ client=self.client,
+ mode=ExecutionMode.SYNC,
+ )
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
@@ -625,6 +661,10 @@ class KubernetesPodOperator(BaseOperator):
event["name"],
event["namespace"],
)
+ if self.callbacks:
+ self.callbacks.on_operator_resuming(
+ pod=pod, event=event, client=self.client,
mode=ExecutionMode.SYNC
+ )
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when pod is failed
if self.get_logs:
@@ -677,6 +717,8 @@ class KubernetesPodOperator(BaseOperator):
pod=pod,
remote_pod=remote_pod,
)
+ if self.callbacks:
+ self.callbacks.on_pod_cleanup(pod=pod, client=self.client,
mode=ExecutionMode.SYNC)
def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
# If a task got marked as failed, "on_kill" method would be called and
the pod will be cleaned up
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index e2d0efac83..0e736daa6a 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -40,6 +40,7 @@ from typing_extensions import Literal
from urllib3.exceptions import HTTPError, TimeoutError
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode,
KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.pod_generator import PodDefaults
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timezone import utcnow
@@ -50,6 +51,7 @@ if TYPE_CHECKING:
from kubernetes.client.models.v1_pod import V1Pod
from urllib3.response import HTTPResponse
+
EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__"
"""
Sentinel for no xcom result.
@@ -287,18 +289,22 @@ class PodManager(LoggingMixin):
def __init__(
self,
kube_client: client.CoreV1Api,
+ callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
):
"""
Create the launcher.
:param kube_client: kubernetes client
+ :param callbacks:
:param progress_callback: Callback function invoked when fetching
container log.
+ This parameter is deprecated, please use ````
"""
super().__init__()
self._client = kube_client
self._progress_callback = progress_callback
self._watch = watch.Watch()
+ self._callbacks = callbacks
def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Run POD asynchronously."""
@@ -441,9 +447,13 @@ class PodManager(LoggingMixin):
message_timestamp = line_timestamp
progress_callback_lines.append(line)
else: # previous log line is complete
- if self._progress_callback:
- for line in progress_callback_lines:
+ for line in progress_callback_lines:
+ if self._progress_callback:
self._progress_callback(line)
+ if self._callbacks:
+ self._callbacks.progress_callback(
+ line=line, client=self._client,
mode=ExecutionMode.SYNC
+ )
self.log.info("[%s] %s", container_name,
message_to_log)
last_captured_timestamp = message_timestamp
message_to_log = message
@@ -454,9 +464,13 @@ class PodManager(LoggingMixin):
progress_callback_lines.append(line)
finally:
# log the last line and update the last_captured_timestamp
- if self._progress_callback:
- for line in progress_callback_lines:
+ for line in progress_callback_lines:
+ if self._progress_callback:
self._progress_callback(line)
+ if self._callbacks:
+ self._callbacks.progress_callback(
+ line=line, client=self._client,
mode=ExecutionMode.SYNC
+ )
self.log.info("[%s] %s", container_name, message_to_log)
last_captured_timestamp = message_timestamp
except TimeoutError as e:
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
index 690a857ea0..f24fd61602 100644
--- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
+++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst
@@ -195,6 +195,81 @@ included in the exception message if the task fails.
Read more on termination-log `here
<https://kubernetes.io/docs/tasks/debug/debug-application/determine-reason-pod-failure/>`__.
+KubernetesPodOperator callbacks
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The
:class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator`
supports different
+callbacks that can be used to trigger actions during the lifecycle of the pod.
In order to use them, you need to
+create a subclass of
:class:`~airflow.providers.cncf.kubernetes.callbacks.KubernetesPodOperatorCallback`
and override
+the callback methods you want to use. Then you can pass your callback class
to the operator using the ``callbacks``
+parameter.
+
+The following callbacks are supported:
+
+* on_sync_client_creation: called after creating the sync client
+* on_pod_creation: called after creating the pod
+* on_pod_starting: called after the pod starts
+* on_pod_completion: called when the pod completes
+* on_pod_cleanup: called after cleaning/deleting the pod
+* on_operator_resuming: called when resuming the task from deferred state
+* progress_callback: called on each line of container logs
+
+Currently, the callback methods are not called in async mode; support for
this will be added in the future.
+
+Example:
+~~~~~~~~
+.. code-block:: python
+
+ import kubernetes.client as k8s
+ import kubernetes_asyncio.client as async_k8s
+
+ from airflow.providers.cncf.kubernetes.operators.pod import
KubernetesPodOperator
+ from airflow.providers.cncf.kubernetes.callbacks import
KubernetesPodOperatorCallback
+
+
+ class MyCallback(KubernetesPodOperatorCallback):
+ @staticmethod
+ def on_pod_creation(*, pod: k8s.V1Pod, client: k8s.CoreV1Api, mode:
str, **kwargs) -> None:
+ client.create_namespaced_service(
+ namespace=pod.metadata.namespace,
+ body=k8s.V1Service(
+ metadata=k8s.V1ObjectMeta(
+ name=pod.metadata.name,
+ labels=pod.metadata.labels,
+ owner_references=[
+ k8s.V1OwnerReference(
+ api_version=pod.api_version,
+ kind=pod.kind,
+ name=pod.metadata.name,
+ uid=pod.metadata.uid,
+ controller=True,
+ block_owner_deletion=True,
+ )
+ ],
+ ),
+ spec=k8s.V1ServiceSpec(
+ selector=pod.metadata.labels,
+ ports=[
+ k8s.V1ServicePort(
+ name="http",
+ port=80,
+ target_port=80,
+ )
+ ],
+ ),
+ ),
+ )
+
+
+ k = KubernetesPodOperator(
+ task_id="test_callback",
+ image="alpine",
+ cmds=["/bin/sh"],
+ arguments=["-c", "echo hello world; echo Custom error >
/dev/termination-log; exit 1;"],
+ name="test-callback",
+ callbacks=MyCallback,
+ )
+
Reference
^^^^^^^^^
For further information, look at:
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 690dece953..a737e06cbe 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1446,6 +1446,119 @@ class TestKubernetesPodOperator:
# check that we wait for the xcom sidecar to start before extracting
XCom
mock_await_xcom_sidecar.assert_called_once_with(pod=pod)
+ @patch(HOOK_CLASS, new=MagicMock)
+ @patch(KUB_OP_PATH.format("find_pod"))
+ def test_execute_sync_callbacks(self, find_pod_mock):
+ from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode
+
+ from ..test_callbacks import MockKubernetesPodOperatorCallback,
MockWrapper
+
+ MockWrapper.reset()
+ mock_callbacks = MockWrapper.mock_callbacks
+ found_pods = [MagicMock(), MagicMock(), MagicMock()]
+ find_pod_mock.side_effect = [None] + found_pods
+
+ remote_pod_mock = MagicMock()
+ remote_pod_mock.status.phase = "Succeeded"
+ self.await_pod_mock.return_value = remote_pod_mock
+ k = KubernetesPodOperator(
+ namespace="default",
+ image="ubuntu:16.04",
+ cmds=["bash", "-cx"],
+ arguments=["echo 10"],
+ labels={"foo": "bar"},
+ name="test",
+ task_id="task",
+ do_xcom_push=False,
+ callbacks=MockKubernetesPodOperatorCallback,
+ )
+ self.run_pod(k)
+
+ # check on_sync_client_creation callback
+ mock_callbacks.on_sync_client_creation.assert_called_once()
+ assert mock_callbacks.on_sync_client_creation.call_args.kwargs ==
{"client": k.client}
+
+ # check on_pod_creation callback
+ mock_callbacks.on_pod_creation.assert_called_once()
+ assert mock_callbacks.on_pod_creation.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": found_pods[0],
+ }
+
+ # check on_pod_starting callback
+ mock_callbacks.on_pod_starting.assert_called_once()
+ assert mock_callbacks.on_pod_starting.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": found_pods[1],
+ }
+
+ # check on_pod_completion callback
+ mock_callbacks.on_pod_completion.assert_called_once()
+ assert mock_callbacks.on_pod_completion.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": found_pods[2],
+ }
+
+ # check on_pod_cleanup callback
+ mock_callbacks.on_pod_cleanup.assert_called_once()
+ assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": k.pod,
+ }
+
+ @patch(HOOK_CLASS, new=MagicMock)
+ def test_execute_async_callbacks(self):
+ from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode
+
+ from ..test_callbacks import MockKubernetesPodOperatorCallback,
MockWrapper
+
+ MockWrapper.reset()
+ mock_callbacks = MockWrapper.mock_callbacks
+ remote_pod_mock = MagicMock()
+ remote_pod_mock.status.phase = "Succeeded"
+ self.await_pod_mock.return_value = remote_pod_mock
+
+ k = KubernetesPodOperator(
+ namespace="default",
+ image="ubuntu:16.04",
+ cmds=["bash", "-cx"],
+ arguments=["echo 10"],
+ labels={"foo": "bar"},
+ name="test",
+ task_id="task",
+ do_xcom_push=False,
+ callbacks=MockKubernetesPodOperatorCallback,
+ )
+ k.execute_complete(
+ context=create_context(k),
+ event={
+ "status": "success",
+ "message": TEST_SUCCESS_MESSAGE,
+ "name": TEST_NAME,
+ "namespace": TEST_NAMESPACE,
+ },
+ )
+
+ # check on_operator_resuming callback
+ mock_callbacks.on_pod_cleanup.assert_called_once()
+ assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": remote_pod_mock,
+ }
+
+ # check on_pod_cleanup callback
+ mock_callbacks.on_pod_cleanup.assert_called_once()
+ assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
+ "client": k.client,
+ "mode": ExecutionMode.SYNC,
+ "pod": remote_pod_mock,
+ }
+
class TestSuppress:
def test__suppress(self, caplog):
@@ -1554,9 +1667,13 @@ class TestKubernetesPodOperatorAsync:
return remote_pod_mock
@pytest.mark.parametrize("do_xcom_push", [True, False])
+ @patch(KUB_OP_PATH.format("client"))
+ @patch(KUB_OP_PATH.format("find_pod"))
@patch(KUB_OP_PATH.format("build_pod_request_obj"))
@patch(KUB_OP_PATH.format("get_or_create_pod"))
- def test_async_create_pod_should_execute_successfully(self, mocked_pod,
mocked_pod_obj, do_xcom_push):
+ def test_async_create_pod_should_execute_successfully(
+ self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client,
do_xcom_push
+ ):
"""
Asserts that a task is deferred and the KubernetesCreatePodTrigger
will be fired
when the KubernetesPodOperator is executed in deferrable mode when
deferrable=True.
@@ -1584,7 +1701,7 @@ class TestKubernetesPodOperatorAsync:
mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE
context = create_context(k)
- ti_mock = MagicMock()
+ ti_mock = MagicMock(**{"map_index": -1})
context["ti"] = ti_mock
with pytest.raises(TaskDeferred) as exc:
diff --git a/tests/providers/cncf/kubernetes/test_callbacks.py
b/tests/providers/cncf/kubernetes/test_callbacks.py
new file mode 100644
index 0000000000..2757b8296a
--- /dev/null
+++ b/tests/providers/cncf/kubernetes/test_callbacks.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from airflow.providers.cncf.kubernetes.callbacks import
KubernetesPodOperatorCallback
+
+
+class MockWrapper:
+ mock_callbacks = MagicMock()
+
+ @classmethod
+ def reset(cls):
+ cls.mock_callbacks.reset_mock()
+
+
+class MockKubernetesPodOperatorCallback(KubernetesPodOperatorCallback):
+ """`KubernetesPodOperator` callbacks methods."""
+
+ @staticmethod
+ def on_sync_client_creation(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_sync_client_creation(*args, **kwargs)
+
+ @staticmethod
+ def on_async_client_creation(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_async_client_creation(*args, **kwargs)
+
+ @staticmethod
+ def on_pod_creation(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_pod_creation(*args, **kwargs)
+
+ @staticmethod
+ def on_pod_starting(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_pod_starting(*args, **kwargs)
+
+ @staticmethod
+ def on_pod_completion(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_pod_completion(*args, **kwargs)
+
+ @staticmethod
+ def on_pod_cleanup(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_pod_cleanup(*args, **kwargs)
+
+ @staticmethod
+ def on_operator_resuming(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.on_operator_resuming(*args, **kwargs)
+
+ @staticmethod
+ def progress_callback(*args, **kwargs) -> None:
+ MockWrapper.mock_callbacks.progress_callback(*args, **kwargs)
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index 13f8123550..fc09d6bb02 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -41,6 +41,8 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager
import (
)
from airflow.utils.timezone import utc
+from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper
+
if TYPE_CHECKING:
from pendulum import DateTime
@@ -50,7 +52,9 @@ class TestPodManager:
self.mock_progress_callback = mock.Mock()
self.mock_kube_client = mock.Mock()
self.pod_manager = PodManager(
- kube_client=self.mock_kube_client,
progress_callback=self.mock_progress_callback
+ kube_client=self.mock_kube_client,
+ callbacks=MockKubernetesPodOperatorCallback,
+ progress_callback=self.mock_progress_callback,
)
def test_read_pod_logs_successfully_returns_logs(self):
@@ -274,7 +278,7 @@ class TestPodManager:
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
- def test_fetch_container_logs_invoke_progress_callback(
+ def test_fetch_container_logs_invoke_deprecated_progress_callback(
self, mock_read_pod_logs, mock_container_is_running
):
message = "2020-10-08T14:16:17.793417674Z message"
@@ -285,8 +289,30 @@ class TestPodManager:
self.pod_manager.fetch_container_logs(mock.MagicMock(),
mock.MagicMock(), follow=True)
self.mock_progress_callback.assert_has_calls([mock.call(message),
mock.call(no_ts_message)])
+
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
+
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
+ def test_fetch_container_logs_invoke_progress_callback(
+ self, mock_read_pod_logs, mock_container_is_running
+ ):
+ MockWrapper.reset()
+ mock_callbacks = MockWrapper.mock_callbacks
+ message = "2020-10-08T14:16:17.793417674Z message"
+ no_ts_message = "notimestamp"
+ mock_read_pod_logs.return_value = [bytes(message, "utf-8"),
bytes(no_ts_message, "utf-8")]
+ mock_container_is_running.return_value = False
+
+ self.pod_manager.fetch_container_logs(mock.MagicMock(),
mock.MagicMock(), follow=True)
+ mock_callbacks.progress_callback.assert_has_calls(
+ [
+ mock.call(line=message, client=self.pod_manager._client,
mode="sync"),
+ mock.call(line=no_ts_message, client=self.pod_manager._client,
mode="sync"),
+ ]
+ )
+
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
def test_fetch_container_logs_failures(self, mock_container_is_running):
+ MockWrapper.reset()
+ mock_callbacks = MockWrapper.mock_callbacks
last_timestamp_string = "2020-10-08T14:18:17.793417674Z"
messages = [
bytes("2020-10-08T14:16:17.793417674Z message", "utf-8"),
@@ -309,6 +335,7 @@ class TestPodManager:
status = self.pod_manager.fetch_container_logs(mock.MagicMock(),
mock.MagicMock(), follow=True)
assert status.last_log_time == cast("DateTime",
pendulum.parse(last_timestamp_string))
assert self.mock_progress_callback.call_count == expected_call_count
+ assert mock_callbacks.progress_callback.call_count ==
expected_call_count
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")