This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f69b5f007 Add task context logging feature to allow forwarding messages to task logs (#32646)
2f69b5f007 is described below

commit 2f69b5f007b544f992432a3c681f393317e16c16
Author: Pankaj Koti <pankajkoti...@gmail.com>
AuthorDate: Fri Nov 17 12:49:49 2023 +0530

    Add task context logging feature to allow forwarding messages to task logs (#32646)
    
    The PR adds a `TaskContextLogger` class that can forward messages
    from Airflow components such as the Scheduler and Executor to the
    task logs. This is helpful in exceptional scenarios where a task is
    marked as failed by one of these components, e.g. when a task times
    out because it remains stuck in the queue, or when a zombie task is
    killed. Currently, no task logs are available in the UI in such
    scenarios, so the user has no indication of why the task failed.
    Forwarding these event messages to the task logs gives the user a
    clear picture.
    
    The PR also adds a config param to disable this feature, in case it
    causes problems in the Airflow components.
    
    ---------
    
    Co-authored-by: Niko Oliveira <oniko...@amazon.com>
    Co-authored-by: Daniel Standish <15932138+dstand...@users.noreply.github.com>
---
 airflow/config_templates/config.yml                |  12 ++
 airflow/jobs/scheduler_job_runner.py               |   7 +-
 .../providers/elasticsearch/log/es_task_handler.py |   2 +-
 .../microsoft/azure/log/wasb_task_handler.py       |   4 +-
 airflow/providers/redis/log/redis_task_handler.py  |   2 +-
 airflow/utils/log/file_task_handler.py             |  50 +++++-
 airflow/utils/log/task_context_logger.py           | 182 +++++++++++++++++++++
 tests/utils/log/test_task_context_logger.py        | 107 ++++++++++++
 8 files changed, 355 insertions(+), 11 deletions(-)
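
For illustration, a minimal sketch of how a component can use the new
logger, mirroring the scheduler wiring in this commit (the `ti` variable
here is an assumed, already-fetched TaskInstance):

    import logging

    from airflow.utils.log.task_context_logger import TaskContextLogger

    # component_name identifies the message source and also becomes a
    # suffix on the task log file name
    task_context_logger = TaskContextLogger(
        component_name="scheduler",
        call_site_logger=logging.getLogger(__name__),
    )

    # Relays the message to the task's own log (and to the call-site
    # logger); the task instance must be passed as the `ti` keyword.
    task_context_logger.error("Task %s was killed externally", ti.task_id, ti=ti)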

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index ee246e1c01..224674b360 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -928,6 +928,18 @@ logging:
       type: boolean
       example: ~
       default: "False"
+    enable_task_context_logger:
+      description: |
+        If enabled, Airflow may ship messages to task logs from outside the task run context, e.g. from
+        the scheduler, executor, or callback execution context. This can help in circumstances such as
+        when there's something blocking the execution of the task and ordinarily there may be no task
+        logs at all.
+        This is set to True by default. If you encounter issues with this feature
+        (e.g. scheduler performance issues) it can be disabled.
+      version_added: 2.8.0
+      type: boolean
+      example: ~
+      default: "True"
 metrics:
   description: |
     StatsD (https://github.com/etsy/statsd) integration settings.
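
Under the standard config-to-cfg mapping, the new option can be toggled
from `airflow.cfg`; a sketch of disabling it (section and key taken from
the hunk above):

    [logging]
    enable_task_context_logger = False
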
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 4eebf9967c..334790817b 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -62,6 +62,7 @@ from airflow.timetables.simple import DatasetTriggeredTimetable
 from airflow.utils import timezone
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.log.task_context_logger import TaskContextLogger
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import (
@@ -233,6 +234,10 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         self.processor_agent: DagFileProcessorAgent | None = None
 
         self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
+        self._task_context_logger: TaskContextLogger = TaskContextLogger(
+            component_name=self.job_type,
+            call_site_logger=self.log,
+        )
 
     @provide_session
     def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
@@ -773,7 +778,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     "Executor reports task instance %s finished (%s) although 
the "
                     "task says it's %s. (Info: %s) Was the task killed 
externally?"
                 )
-                self.log.error(msg, ti, state, ti.state, info)
+                self._task_context_logger.error(msg, ti, state, ti.state, info, ti=ti)
 
                 # Get task from the Serialized DAG
                 try:
diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index f961f374f2..79f9ad0b41 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -377,7 +377,7 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
             setattr(record, self.offset_field, int(time.time() * (10**9)))
             self.handler.emit(record)
 
-    def set_context(self, ti: TaskInstance) -> None:
+    def set_context(self, ti: TaskInstance, **kwargs) -> None:
         """
         Provide task_instance context to airflow task handler.
 
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index ac45fb6c42..6a719724b9 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -33,6 +33,8 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 if TYPE_CHECKING:
     import logging
 
+    from airflow.models.taskinstance import TaskInstance
+
 
 def get_default_delete_local_copy():
     """Load delete_local_logs conf if Airflow version > 2.6 and return False 
if not.
@@ -93,7 +95,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
             )
             return None
 
-    def set_context(self, ti) -> None:
+    def set_context(self, ti: TaskInstance, **kwargs) -> None:
         super().set_context(ti)
         # Local location and remote location is needed to open and
         # upload local log file to Wasb remote storage.
diff --git a/airflow/providers/redis/log/redis_task_handler.py b/airflow/providers/redis/log/redis_task_handler.py
index 2582faf4ed..7107b2ab3a 100644
--- a/airflow/providers/redis/log/redis_task_handler.py
+++ b/airflow/providers/redis/log/redis_task_handler.py
@@ -81,7 +81,7 @@ class RedisTaskHandler(FileTaskHandler, LoggingMixin):
         ).decode()
         return log_str, {"end_of_log": True}
 
-    def set_context(self, ti: TaskInstance):
+    def set_context(self, ti: TaskInstance, **kwargs) -> None:
         super().set_context(ti)
         self.handler = _RedisHandler(
             self.conn,
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 25b664074e..3e2561ba75 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -18,6 +18,7 @@
 """File logging handler for tasks."""
 from __future__ import annotations
 
+import inspect
 import logging
 import os
 import warnings
@@ -31,7 +32,7 @@ from urllib.parse import urljoin
 import pendulum
 
 from airflow.configuration import conf
-from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
@@ -41,7 +42,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 
 logger = logging.getLogger(__name__)
 
@@ -131,6 +132,32 @@ def _interleave_logs(*logs):
         last = v
 
 
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+    """Given TI | TIKey, return a TI object.
+
+    Will raise exception if no TI is found in the database.
+    """
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+    if not isinstance(ti, TaskInstanceKey):
+        return ti
+    val = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+        .one_or_none()
+    )
+    if isinstance(val, TaskInstance):
+        val._try_number = ti.try_number
+        return val
+    else:
+        raise AirflowException(f"Could not find TaskInstance for {ti}")
+
+
 class FileTaskHandler(logging.Handler):
     """
    FileTaskHandler is a python log handler that handles and reads task instance logs.
@@ -170,7 +197,7 @@ class FileTaskHandler(logging.Handler):
         Some handlers emit "end of log" markers, and may not wish to do so 
when task defers.
         """
 
-    def set_context(self, ti: TaskInstance) -> None | SetContextPropagate:
+    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None | SetContextPropagate:
         """
         Provide task_instance context to airflow task handler.
 
@@ -181,14 +208,20 @@ class FileTaskHandler(logging.Handler):
         functionality is only used in unit testing.
 
         :param ti: task instance object
+        :param identifier: if set, adds suffix to log file. For use when relaying exceptional messages
+            to task logs from a context other than task or trigger run
         """
-        local_loc = self._init_file(ti)
+        local_loc = self._init_file(ti, identifier=identifier)
         self.handler = NonCachingFileHandler(local_loc, encoding="utf-8")
         if self.formatter:
             self.handler.setFormatter(self.formatter)
         self.handler.setLevel(self.level)
        return SetContextPropagate.MAINTAIN_PROPAGATE if self.maintain_propagate else None
 
+    @cached_property
+    def supports_task_context_logging(self) -> bool:
+        return "identifier" in inspect.signature(self.set_context).parameters
+
     @staticmethod
     def add_triggerer_suffix(full_path, job_id=None):
         """
@@ -217,9 +250,10 @@ class FileTaskHandler(logging.Handler):
         if self.handler:
             self.handler.close()
 
-    def _render_filename(self, ti: TaskInstance, try_number: int) -> str:
+    def _render_filename(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str:
         """Return the worker log filename."""
         with create_session() as session:
+            ti = _ensure_ti(ti, session)
             dag_run = ti.get_dagrun(session=session)
             template = dag_run.get_log_template(session=session).filename
             str_tpl, jinja_tpl = parse_template_string(template)
@@ -458,7 +492,7 @@ class FileTaskHandler(logging.Handler):
                 print(f"Failed to change {directory} permission to 
{new_folder_permissions}: {e}")
                 pass
 
-    def _init_file(self, ti):
+    def _init_file(self, ti, *, identifier: str | None = None):
         """
         Create log directory and give it permissions that are configured.
 
@@ -472,7 +506,9 @@ class FileTaskHandler(logging.Handler):
         )
         local_relative_path = self._render_filename(ti, ti.try_number)
         full_path = os.path.join(self.local_base, local_relative_path)
-        if ti.is_trigger_log_context is True:
+        if identifier:
+            full_path += f".{identifier}.log"
+        elif ti.is_trigger_log_context is True:
             # if this is true, we're invoked via set_context in the context of
             # setting up individual trigger logging. return trigger log path.
            full_path = self.add_triggerer_suffix(full_path=full_path, job_id=ti.triggerer_job.id)
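
A note on the opt-in mechanism above: `supports_task_context_logging`
only inspects the signature of `set_context`, so a custom handler
advertises support by naming the `identifier` keyword explicitly (a bare
`**kwargs`, as in the provider handlers touched by this commit, appears
to merely tolerate the argument without signalling support). A minimal
sketch; the handler class here is hypothetical:

    from airflow.utils.log.file_task_handler import FileTaskHandler

    class MyTaskHandler(FileTaskHandler):
        def set_context(self, ti, *, identifier: str | None = None):
            # Naming `identifier` in the signature makes
            # supports_task_context_logging evaluate to True.
            return super().set_context(ti, identifier=identifier)

When `identifier` is set, `_init_file` appends `.{identifier}.log` to
the rendered path, so forwarded messages land in a sibling file next to
the regular attempt log rather than interleaving with it.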
diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py
new file mode 100644
index 0000000000..0661789f5b
--- /dev/null
+++ b/airflow/utils/log/task_context_logger.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from contextlib import suppress
+from copy import copy
+from logging import Logger
+from typing import TYPE_CHECKING
+
+from airflow.configuration import conf
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.log.file_task_handler import FileTaskHandler
+
+logger = logging.getLogger(__name__)
+
+
+class TaskContextLogger:
+    """
+    Class for sending messages to task instance logs from outside task execution context.
+
+    This is intended to be used mainly in exceptional circumstances, to give visibility into
+    events related to task execution when otherwise there would be none.
+
+    :meta private:
+    """
+
+    def __init__(self, component_name: str, call_site_logger: Logger | None = None):
+        """
+        Initialize the task context logger with the component name.
+
+        :param component_name: the name of the component that will be used to identify the log messages
+        :param call_site_logger: if provided, message will also be emitted through this logger
+        """
+        self.component_name = component_name
+        self.task_handler = self._get_task_handler()
+        self.enabled = self._should_enable()
+        self.call_site_logger = call_site_logger
+
+    def _should_enable(self) -> bool:
+        if not conf.getboolean("logging", "enable_task_context_logger"):
+            return False
+        if not getattr(self.task_handler, "supports_task_context_logging", False):
+            logger.warning("Task handler does not support task context logging")
+            return False
+        logger.info("Task context logging is enabled")
+        return True
+
+    @staticmethod
+    def _get_task_handler() -> FileTaskHandler | None:
+        """Returns the task handler that supports task context logging."""
+        handlers = [
+            handler
+            for handler in logging.getLogger("airflow.task").handlers
+            if getattr(handler, "supports_task_context_logging", False)
+        ]
+        if not handlers:
+            return None
+        h = handlers[0]
+        if TYPE_CHECKING:
+            assert isinstance(h, FileTaskHandler)
+        return h
+
+    def _log(self, level: int, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message to the task instance logs.
+
+        :param level: the log level
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        if self.call_site_logger and self.call_site_logger.isEnabledFor(level=level):
+            with suppress(Exception):
+                self.call_site_logger.log(level, msg, *args)
+
+        if not self.enabled:
+            return
+
+        if not self.task_handler:
+            return
+
+        task_handler = copy(self.task_handler)
+        try:
+            if hasattr(task_handler, "mark_end_on_close"):
+                task_handler.mark_end_on_close = False
+            task_handler.set_context(ti, identifier=self.component_name)
+            filename, lineno, func, stackinfo = logger.findCaller()
+            record = logging.LogRecord(
+                self.component_name, level, filename, lineno, msg, args, None, func=func
+            )
+            task_handler.emit(record)
+        finally:
+            task_handler.close()
+
+    def critical(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level CRITICAL to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.CRITICAL, msg, *args, ti=ti)
+
+    def fatal(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level FATAL to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.FATAL, msg, *args, ti=ti)
+
+    def error(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level ERROR to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.ERROR, msg, *args, ti=ti)
+
+    def warn(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level WARN to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.WARN, msg, *args, ti=ti)
+
+    def warning(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level WARNING to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.WARNING, msg, *args, ti=ti)
+
+    def info(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level INFO to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.INFO, msg, *args, ti=ti)
+
+    def debug(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level DEBUG to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.DEBUG, msg, *args, ti=ti)
+
+    def notset(self, msg: str, *args, ti: TaskInstance):
+        """
+        Emit a log message with level NOTSET to the task instance logs.
+
+        :param msg: the message to relay to task context log
+        :param ti: the task instance
+        """
+        self._log(logging.NOTSET, msg, *args, ti=ti)
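
One behavioral detail worth noting from `_log` above: the call-site
logger is emitted to first and independently, so disabling the feature
(or lacking a supporting handler) silences only the task-log copy. A
short sketch (the `ti` is an assumed TaskInstance):

    import logging

    from airflow.utils.log.task_context_logger import TaskContextLogger

    tcl = TaskContextLogger(
        component_name="executor",
        call_site_logger=logging.getLogger(__name__),
    )

    # Goes to the call-site logger whenever WARNING is enabled for it;
    # additionally appended to the task's log file only when
    # [logging] enable_task_context_logger is True and the configured
    # task handler supports task context logging.
    tcl.warning("task %s appears stuck in queued", ti.task_id, ti=ti)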
diff --git a/tests/utils/log/test_task_context_logger.py b/tests/utils/log/test_task_context_logger.py
new file mode 100644
index 0000000000..a8754f4d0b
--- /dev/null
+++ b/tests/utils/log/test_task_context_logger.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from unittest.mock import Mock
+
+import pytest
+
+from airflow.utils.log.task_context_logger import TaskContextLogger
+from tests.test_utils.config import conf_vars
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def mock_handler():
+    logger = logging.getLogger("airflow.task")
+    old = logger.handlers[:]
+    h = Mock()
+    logger.handlers[:] = [h]
+    yield h
+    logger.handlers[:] = old
+
+
+@pytest.fixture
+def ti(dag_maker):
+    with dag_maker() as dag:
+
+        @dag.task()
+        def nothing():
+            return None
+
+        nothing()
+
+    dr = dag.create_dagrun("running", run_id="abc")
+    ti = dr.get_task_instances()[0]
+    yield ti
+
+
+def test_task_context_logger_enabled_by_default():
+    t = TaskContextLogger(component_name="test_component")
+    assert t.enabled is True
+
+
+@pytest.mark.parametrize("supported", [True, False])
+def test_task_handler_not_supports_task_context_logging(mock_handler, supported):
+    mock_handler.supports_task_context_logging = supported
+    t = TaskContextLogger(component_name="test_component")
+    assert t.enabled is supported
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize("supported", [True, False])
+def test_task_context_log_with_correct_arguments(ti, mock_handler, supported):
+    mock_handler.supports_task_context_logging = supported
+    t = TaskContextLogger(component_name="test_component")
+    t.info("test message with args %s, %s", "a", "b", ti=ti)
+    if supported:
+        mock_handler.set_context.assert_called_once_with(ti, identifier="test_component")
+        mock_handler.emit.assert_called_once()
+    else:
+        mock_handler.set_context.assert_not_called()
+        mock_handler.emit.assert_not_called()
+
+
+@pytest.mark.db_test
+def test_task_context_log_closes_task_handler(ti, mock_handler):
+    t = TaskContextLogger("blah")
+    t.info("test message", ti=ti)
+    mock_handler.close.assert_called_once()
+
+
+@pytest.mark.db_test
+def test_task_context_log_also_emits_to_call_site_logger(ti):
+    logger = logging.getLogger("abc123567")
+    logger.setLevel(logging.INFO)
+    logger.log = Mock()
+    t = TaskContextLogger("blah", call_site_logger=logger)
+    t.info("test message", ti=ti)
+    logger.log.assert_called_once_with(logging.INFO, "test message")
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize("val, expected", [("true", True), ("false", False)])
+def test_task_context_logger_config_works(ti, mock_handler, val, expected):
+    with conf_vars({("logging", "enable_task_context_logger"): val}):
+        t = TaskContextLogger("abc")
+        t.info("test message", ti=ti)
+        if expected:
+            mock_handler.emit.assert_called()
+        else:
+            mock_handler.emit.assert_not_called()
