This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d9cfdd8131d797735b406554674282ea0f8dcbae
Author: Pankaj Koti <pankajkoti...@gmail.com>
AuthorDate: Tue Nov 21 11:55:03 2023 +0530

    Extend task context logging support for remote logging using Elasticsearch (#32977)
    
    * Extend task context logging support for remote logging using Elasticsearch
    
    With the addition of the task context logging feature in PR #32646,
    this PR extends the feature to Elasticsearch when it is set as the
    remote logging store. Backward compatibility is ensured for older
    versions of Airflow that do not have the feature included in
    Airflow Core.
    
    * update ensure_ti
    
    ---------
    
    Co-authored-by: Daniel Standish <15932138+dstand...@users.noreply.github.com>
---
 .../providers/elasticsearch/log/es_task_handler.py | 46 +++++++++++++++++++---
 1 file changed, 41 insertions(+), 5 deletions(-)
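
For context on the backward-compatibility note in the commit message: the change below feature-detects the newer FileTaskHandler API before forwarding the keyword-only `identifier` argument, rather than requiring a newer Airflow core. A minimal standalone sketch of that pattern (the class name is illustrative, and it assumes the usual airflow.utils.log.file_task_handler import path):

    from __future__ import annotations

    from airflow.utils.log.file_task_handler import FileTaskHandler


    class CompatAwareHandler(FileTaskHandler):
        """Illustrative subclass; not part of the commit itself."""

        def set_context(self, ti, *, identifier: str | None = None) -> None:
            # Cores that ship task context logging are expected to expose
            # supports_task_context_logging and to accept the keyword-only
            # `identifier`; older cores do neither, so the keyword must not
            # be forwarded to them.
            if getattr(self, "supports_task_context_logging", False):
                super().set_context(ti, identifier=identifier)
            else:
                super().set_context(ti)

Using getattr with a default keeps the provider importable and functional on older cores, at the cost of silently dropping the identifier there.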

diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 79f9ad0b41..1e8c75b7e3 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -34,7 +34,7 @@ import pendulum
 from elasticsearch.exceptions import NotFoundError
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models.dagrun import DagRun
 from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter
 from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit
@@ -46,7 +46,8 @@ from airflow.utils.session import create_session
 if TYPE_CHECKING:
     from datetime import datetime
 
-    from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
 
 LOG_LINE_DEFAULTS = {"exc_text": "", "stack_info": ""}
 # Elasticsearch hosted log type
@@ -84,6 +85,32 @@ def get_es_kwargs_from_config() -> dict[str, Any]:
     return kwargs_dict
 
 
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+    """Given TI | TIKey, return a TI object.
+
+    Will raise exception if no TI is found in the database.
+    """
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+    if not isinstance(ti, TaskInstanceKey):
+        return ti
+    val = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+        .one_or_none()
+    )
+    if isinstance(val, TaskInstance):
+        val._try_number = ti.try_number
+        return val
+    else:
+        raise AirflowException(f"Could not find TaskInstance for {ti}")
+
+
 class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin):
     """
     ElasticsearchTaskHandler is a python log handler that reads logs from Elasticsearch.
@@ -182,8 +209,12 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
 
         return host
 
-    def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
+    def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str:
+        from airflow.models.taskinstance import TaskInstanceKey
+
         with create_session() as session:
+            if isinstance(ti, TaskInstanceKey):
+                ti = _ensure_ti(ti, session)
             dag_run = ti.get_dagrun(session=session)
             if USE_PER_RUN_LOG_ID:
                 log_id_template = dag_run.get_log_template(session=session).elasticsearch_id
@@ -377,11 +408,13 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
             setattr(record, self.offset_field, int(time.time() * (10**9)))
             self.handler.emit(record)
 
-    def set_context(self, ti: TaskInstance, **kwargs) -> None:
+    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
         """
         Provide task_instance context to airflow task handler.
 
         :param ti: task instance object
+        :param identifier: if set, identifies the Airflow component which is relaying logs from
+            exceptional scenarios related to the task instance
         """
         is_trigger_log_context = getattr(ti, "is_trigger_log_context", None)
         is_ti_raw = getattr(ti, "raw", None)
@@ -410,7 +443,10 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
             self.handler.setLevel(self.level)
             self.handler.setFormatter(self.formatter)
         else:
-            super().set_context(ti)
+            if getattr(self, "supports_task_context_logging", False):
+                super().set_context(ti, identifier=identifier)
+            else:
+                super().set_context(ti)
         self.context_set = True
 
     def close(self) -> None:
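
For illustration, a hedged usage sketch of the new code path (the handler is assumed to be already configured; the DAG, task, run and component names are made up, and the real caller in Airflow core is the task context logging machinery introduced in #32646):

    import logging

    from airflow.models.taskinstance import TaskInstanceKey
    from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchTaskHandler


    def relay_exception_message(handler: ElasticsearchTaskHandler, msg: str) -> None:
        # The caller only holds a TaskInstanceKey; _render_log_id resolves it
        # to a TaskInstance via _ensure_ti when building the Elasticsearch log_id.
        ti_key = TaskInstanceKey(
            dag_id="example_dag",
            task_id="example_task",
            run_id="manual__2023-11-21T00:00:00+00:00",
            try_number=1,
        )
        # `identifier` names the Airflow component relaying logs for this task instance.
        handler.set_context(ti_key, identifier="example_component")
        record = logging.LogRecord(
            name="airflow.task",
            level=logging.INFO,
            pathname=__file__,
            lineno=0,
            msg=msg,
            args=None,
            exc_info=None,
        )
        handler.emit(record)
        handler.close()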
