This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d6dee49db22 Rework StackdriverTaskHandler for the structlog era #65191 (#65198)
d6dee49db22 is described below
commit d6dee49db228ca14f829e12df066b1cac994ceee
Author: Haseeb Malik <[email protected]>
AuthorDate: Tue May 26 17:16:46 2026 -0400
Rework StackdriverTaskHandler for the structlog era #65191 (#65198)
---
.../google/cloud/log/stackdriver_task_handler.py | 396 +++++++++++++--------
.../cloud/log/test_stackdriver_task_handler.py | 250 ++++++++++++-
2 files changed, 484 insertions(+), 162 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
index b11a1cd6a47..145c2564539 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -18,13 +18,21 @@
from __future__ import annotations
+import contextlib
+import copy
import logging
+import os
+import shutil
import warnings
from collections.abc import Collection
+from datetime import datetime
from functools import cached_property
+from logging import getLogRecordFactory
+from pathlib import Path
from typing import TYPE_CHECKING
from urllib.parse import urlencode
+import attrs
from google.cloud import logging as gcp_logging
from google.cloud.logging import Resource
from google.cloud.logging.handlers.transports import
BackgroundThreadTransport, Transport
@@ -35,6 +43,7 @@ from airflow.exceptions import
AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.utils.credentials_provider import
get_credentials_and_project_id
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.utils.log.logging_mixin import LoggingMixin
try:
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -45,9 +54,12 @@ if not AIRFLOW_V_3_0_PLUS:
from airflow.utils.log.trigger_handler import ctx_indiv_trigger
if TYPE_CHECKING:
+ import structlog.typing
from google.auth.credentials import Credentials
from airflow.models import TaskInstance
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+ from airflow.utils.log.file_task_handler import LogResponse
DEFAULT_LOGGER_NAME = "airflow"
_GLOBAL_RESOURCE = Resource(type="global", labels={})
@@ -56,6 +68,209 @@ _DEFAULT_SCOPESS = frozenset(
["https://www.googleapis.com/auth/logging.read",
"https://www.googleapis.com/auth/logging.write"]
)
+# Label keys attached to Stackdriver log entries and used when building read filters.
+# The logical-date key keeps its older name ("execution_date") before Airflow 3.
+LABEL_TASK_ID = "task_id"
+LABEL_DAG_ID = "dag_id"
+LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
+LABEL_TRY_NUMBER = "try_number"
+
+
+@attrs.define(kw_only=True)
+class StackdriverRemoteLogIO(LoggingMixin):
+    """
+    Remote log IO that streams logs to and reads from Google Cloud Stackdriver Logging.
+
+    Task log events are sent to Stackdriver by the structlog processor returned from
+    :attr:`processors` as they are emitted, so :meth:`upload` only flushes the transport and
+    optionally removes the local copy. Every entry sent by that processor is tagged with the
+    ``LABEL_LOG_PATH`` label (the task's relative log path), which :meth:`read` filters on.
+    """
+
+    # Label carrying the relative task log path; written by ``processors`` and matched by ``read``.
+    LABEL_LOG_PATH = "log_path"
+
+    base_log_folder: Path = attrs.field(converter=Path)
+    delete_local_copy: bool = True
+
+    gcp_key_path: str | None = None
+    scopes: Collection[str] | None = _DEFAULT_SCOPESS
+    gcp_log_name: str = DEFAULT_LOGGER_NAME
+    transport_type: type[Transport] = BackgroundThreadTransport
+    resource: Resource = _GLOBAL_RESOURCE
+    labels: dict[str, str] | None = None
+
+    @cached_property
+    def credentials_and_project(self) -> tuple[Credentials, str]:
+        """Resolve (once) the GCP credentials and project id shared by both clients."""
+        credentials, project = get_credentials_and_project_id(
+            key_path=self.gcp_key_path, scopes=self.scopes, disable_logging=True
+        )
+        return credentials, project
+
+    @cached_property
+    def _client(self) -> gcp_logging.Client:
+        """The Cloud Library API client."""
+        credentials, project = self.credentials_and_project
+        return gcp_logging.Client(
+            credentials=credentials,
+            project=project,
+            client_info=CLIENT_INFO,
+        )
+
+    @cached_property
+    def _logging_service_client(self) -> LoggingServiceV2Client:
+        """The Cloud logging service v2 client."""
+        credentials, _ = self.credentials_and_project
+        return LoggingServiceV2Client(
+            credentials=credentials,
+            client_info=CLIENT_INFO,
+        )
+
+    @cached_property
+    def transport(self) -> Transport:
+        """Object responsible for sending data to Stackdriver."""
+        # Transport subclasses take (client, name) as constructor arguments.
+        return self.transport_type(self._client, self.gcp_log_name)
+
+    @cached_property
+    def processors(self) -> tuple[structlog.typing.Processor, ...]:
+        """
+        Return the structlog processors that forward task log events to Stackdriver.
+
+        The single processor only ships events whose logger has a task log path (as reported by
+        ``relative_path_from_logger``); every event is returned unchanged for the next processor.
+        """
+        import structlog.stdlib
+
+        from airflow.sdk.log import relative_path_from_logger
+
+        log_record_factory = getLogRecordFactory()
+        # Resolve the transport once, up front, so the closure below reuses a single instance.
+        _transport = self.transport
+
+        def proc(
+            logger: structlog.typing.WrappedLogger,
+            method_name: str,
+            event: structlog.typing.EventDict,
+        ):
+            if not logger:
+                return event
+            relative_path = relative_path_from_logger(logger)
+            if not relative_path:
+                # Not a task-log logger: nothing to ship.
+                return event
+
+            name = event.get("logger_name") or event.get("logger", "")
+            level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), logging.INFO)
+            msg = copy.copy(event)
+            created = None
+            if ts := msg.pop("timestamp", None):
+                with contextlib.suppress(Exception):
+                    created = datetime.fromisoformat(ts)
+            record = log_record_factory(
+                name,
+                level,
+                pathname="",
+                lineno=0,
+                msg=msg,
+                args=(),
+                exc_info=None,
+                func=None,
+                sinfo=None,
+            )
+            if created is not None:
+                ct = created.timestamp()
+                record.created = ct
+                record.msecs = int((ct - int(ct)) * 1000) + 0.0
+
+            labels: dict[str, str] = {}
+            if self.labels:
+                labels.update(self.labels)
+            # A record built from standard arguments normally has no ``task_instance``
+            # attribute, so the task identity is carried by the relative log path instead.
+            # ``read`` filters on this same label.
+            labels[self.LABEL_LOG_PATH] = Path(relative_path).as_posix()
+            if ti := getattr(record, "task_instance", None):
+                labels.update(_task_instance_to_labels(ti))
+            _transport.send(record, str(msg.get("event", "")), resource=self.resource, labels=labels)
+            return event
+
+        return (proc,)
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI) -> None:
+        """
+        Flush buffered entries to Stackdriver and optionally remove the local log copy.
+
+        Entries are already sent by the processor, so "uploading" only flushes the transport.
+        When ``delete_local_copy`` is set, the directory holding the log file is removed only if
+        it lies strictly inside ``base_log_folder``. A log file sitting directly in the base
+        folder is unlinked on its own, so the base folder itself is never removed.
+
+        :param path: Log file path, absolute or relative to ``base_log_folder``.
+        :param ti: Task instance the log belongs to (not used here).
+        """
+        self.transport.flush()
+        if not self.delete_local_copy:
+            return
+        base = self.base_log_folder.resolve()
+        raw = Path(path)
+        local_path = (raw if raw.is_absolute() else base / raw).resolve()
+        parent = local_path.parent
+        try:
+            local_path.relative_to(base)
+            # Also reject ``path == base``: its parent would be outside base_log_folder.
+            parent.relative_to(base)
+        except ValueError:
+            self.log.warning(
+                "Skipping deletion: path %s is outside base_log_folder %s",
+                local_path,
+                base,
+            )
+            return
+        if parent == base:
+            # Removing the parent here would wipe the whole base_log_folder; delete just the file.
+            with contextlib.suppress(OSError):
+                local_path.unlink()
+        elif parent.exists():
+            shutil.rmtree(parent, ignore_errors=True)
+
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        """
+        Read a task's logs from Stackdriver Logging.
+
+        Entries are matched on the ``LABEL_LOG_PATH`` label written by :attr:`processors`, i.e.
+        on the relative log path of the task attempt.
+
+        :param relative_path: Log path relative to the base log folder.
+        :param ti: Task instance the log belongs to (not used for matching).
+        :return: A one-line header message and a list holding the log text (empty if none).
+        """
+        log_filter = self.prepare_log_filter({self.LABEL_LOG_PATH: Path(relative_path).as_posix()})
+        messages, _, _ = self.read_logs(log_filter, next_page_token=None, all_pages=True)
+        header = [f"Reading remote log from Stackdriver for {relative_path}"]
+        return header, ([messages] if messages else [])
+
+    def prepare_log_filter(self, ti_labels: dict[str, str]) -> str:
+        """
+        Prepare the filter that chooses which log entries to fetch.
+
+        More information:
+        https://cloud.google.com/logging/docs/reference/v2/rest/v2/entries/list#body.request_body.FIELDS.filter
+        https://cloud.google.com/logging/docs/view/advanced-queries
+
+        :param ti_labels: Labels the entries must carry; each becomes a ``labels.<key>="<value>"`` clause.
+        :return: Filter clauses joined by newlines.
+        """
+
+        def escape_label_key(key: str) -> str:
+            return f'"{key}"' if "." in key else key
+
+        def escape_label_value(value: str) -> str:
+            escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
+            return f'"{escaped_value}"'
+
+        _, project = self.credentials_and_project
+        log_filters = [
+            f"resource.type={escape_label_value(self.resource.type)}",
+            f'logName="projects/{project}/logs/{self.gcp_log_name}"',
+        ]
+
+        for key, value in self.resource.labels.items():
+            log_filters.append(f"resource.labels.{escape_label_key(key)}={escape_label_value(value)}")
+
+        for key, value in ti_labels.items():
+            log_filters.append(f"labels.{escape_label_key(key)}={escape_label_value(value)}")
+        return "\n".join(log_filters)
+
+    def read_logs(
+        self, log_filter: str, next_page_token: str | None, all_pages: bool
+    ) -> tuple[str, bool, str | None]:
+        """
+        Send requests to the Stackdriver service and download logs.
+
+        :param log_filter: Filter specifying the logs to be downloaded.
+        :param next_page_token: Token of the page to start from; ``None`` starts at the first page.
+        :param all_pages: If True, follow page tokens until exhausted; otherwise fetch one page.
+        :return: Tuple of (log text, whether the end of the log was reached, next page token).
+        """
+        messages = []
+        new_messages, next_page_token = self._read_single_logs_page(
+            log_filter=log_filter,
+            page_token=next_page_token,
+        )
+        messages.append(new_messages)
+        if all_pages:
+            while next_page_token:
+                new_messages, next_page_token = self._read_single_logs_page(
+                    log_filter=log_filter, page_token=next_page_token
+                )
+                messages.append(new_messages)
+
+            end_of_log = True
+            next_page_token = None
+        else:
+            end_of_log = not bool(next_page_token)
+        return "\n".join(messages), end_of_log, next_page_token
+
+    def _read_single_logs_page(self, log_filter: str, page_token: str | None = None) -> tuple[str, str]:
+        """
+        Send a request to the Stackdriver service and download a single page of logs.
+
+        :param log_filter: Filter specifying the logs to be downloaded.
+        :param page_token: Token of the page to download; ``None`` downloads the first page.
+        :return: Downloaded log text and the next page token.
+        """
+        _, project = self.credentials_and_project
+        request = ListLogEntriesRequest(
+            resource_names=[f"projects/{project}"],
+            filter=log_filter,
+            page_token=page_token,
+            order_by="timestamp asc",
+            page_size=1000,
+        )
+        response = self._logging_service_client.list_log_entries(request=request)
+        page: ListLogEntriesResponse = next(response.pages)
+        messages: list[str] = []
+        for entry in page.entries:
+            if "message" in (entry.json_payload or {}):
+                messages.append(entry.json_payload["message"])  # type: ignore
+            elif entry.text_payload:
+                messages.append(entry.text_payload)
+        return "\n".join(messages), page.next_page_token
+
+
+def _task_instance_to_labels(ti) -> dict[str, str]:
+    """
+    Convert a task instance to the Stackdriver labels that identify its log entries.
+
+    :param ti: Task instance exposing ``task_id``, ``dag_id`` and ``try_number``, plus
+        ``logical_date`` on Airflow 3+ or ``execution_date`` on older versions.
+    :return: Mapping of label name to string value. The date label is omitted when the date
+        is ``None`` instead of failing on ``None.isoformat()``.
+    """
+    date = ti.logical_date if AIRFLOW_V_3_0_PLUS else ti.execution_date
+    labels = {
+        LABEL_TASK_ID: ti.task_id,
+        LABEL_DAG_ID: ti.dag_id,
+    }
+    if date is not None:
+        labels[LABEL_LOGICAL_DATE] = date.isoformat()
+    # Added last to keep the original label order (and therefore the filter clause order).
+    labels[LABEL_TRY_NUMBER] = str(ti.try_number)
+    return labels
+
class StackdriverTaskHandler(logging.Handler):
"""
@@ -88,10 +303,11 @@ class StackdriverTaskHandler(logging.Handler):
:param labels: (Optional) Mapping of labels for the entry.
"""
- LABEL_TASK_ID = "task_id"
- LABEL_DAG_ID = "dag_id"
- LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else
"execution_date"
- LABEL_TRY_NUMBER = "try_number"
+    # Re-exported module-level constants, kept for back-compat with external code that reads
+    # them off the class. Each right-hand name resolves to the module global because the class
+    # attribute is not bound yet when the assignment runs.
+ LABEL_TASK_ID = LABEL_TASK_ID
+ LABEL_DAG_ID = LABEL_DAG_ID
+ LABEL_LOGICAL_DATE = LABEL_LOGICAL_DATE
+ LABEL_TRY_NUMBER = LABEL_TRY_NUMBER
LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer"
LOG_NAME = "Google Stackdriver"
@@ -120,53 +336,23 @@ class StackdriverTaskHandler(logging.Handler):
gcp_log_name = str(name)
super().__init__()
- self.gcp_key_path: str | None = gcp_key_path
- self.scopes: Collection[str] | None = scopes
- self.gcp_log_name: str = gcp_log_name
- self.transport_type: type[Transport] = transport
- self.resource: Resource = resource
+        # Credentials, clients, transport and log reads are delegated to a StackdriverRemoteLogIO.
+        # NOTE(review): the handler no longer exposes _client/_transport/_credentials_and_project or
+        # transport_type; external code using them must go through self.io. base_log_folder looks
+        # like a placeholder (the handler does not appear to call io.upload) - confirm.
+ self.io = StackdriverRemoteLogIO(
+ base_log_folder=Path("."),
+ gcp_key_path=gcp_key_path,
+ scopes=scopes,
+ gcp_log_name=gcp_log_name,
+ transport_type=transport,
+ resource=resource,
+ labels=labels,
+ )
self.labels: dict[str, str] | None = labels
+        # Kept on the handler as well: emit() passes self.resource to the transport.
+ self.resource: Resource = resource
self.task_instance_labels: dict[str, str] | None = {}
self.task_instance_hostname = "default-hostname"
- @cached_property
- def _credentials_and_project(self) -> tuple[Credentials, str]:
- credentials, project = get_credentials_and_project_id(
- key_path=self.gcp_key_path, scopes=self.scopes,
disable_logging=True
- )
- return credentials, project
-
- @property
- def _client(self) -> gcp_logging.Client:
- """The Cloud Library API client."""
- credentials, project = self._credentials_and_project
- client = gcp_logging.Client(
- credentials=credentials,
- project=project,
- client_info=CLIENT_INFO,
- )
- return client
-
- @property
- def _logging_service_client(self) -> LoggingServiceV2Client:
- """The Cloud logging service v2 client."""
- credentials, _ = self._credentials_and_project
- client = LoggingServiceV2Client(
- credentials=credentials,
- client_info=CLIENT_INFO,
- )
- return client
-
- @cached_property
- def _transport(self) -> Transport:
- """Object responsible for sending data to Stackdriver."""
- # The Transport object is badly defined (no init) but in the docs
client/name as constructor
- # arguments are a requirement for any class that derives from
Transport class, hence ignore:
- return self.transport_type(self._client, self.gcp_log_name)
-
def _get_labels(self, task_instance=None):
if task_instance:
- ti_labels = self._task_instance_to_labels(task_instance)
+ ti_labels = _task_instance_to_labels(task_instance)
else:
ti_labels = self.task_instance_labels
labels: dict[str, str] | None
@@ -193,7 +379,7 @@ class StackdriverTaskHandler(logging.Handler):
if not AIRFLOW_V_3_0_PLUS and getattr(record, ctx_indiv_trigger.name,
None):
ti = getattr(record, "task_instance", None) # trigger context
labels = self._get_labels(ti)
- self._transport.send(record, message, resource=self.resource,
labels=labels)
+        # Send through the IO object's cached transport, the same instance close() flushes.
+ self.io.transport.send(record, message, resource=self.resource,
labels=labels)
def set_context(self, task_instance: TaskInstance) -> None:
"""
@@ -201,7 +387,7 @@ class StackdriverTaskHandler(logging.Handler):
:param task_instance: Currently executed task
"""
- self.task_instance_labels =
self._task_instance_to_labels(task_instance)
+        # Module-level helper; the classmethod of the same name is still defined on the class.
+ self.task_instance_labels = _task_instance_to_labels(task_instance)
self.task_instance_hostname = task_instance.hostname or
"default-hostname"
def read(
@@ -225,18 +411,18 @@ class StackdriverTaskHandler(logging.Handler):
if not metadata:
metadata = {}
- ti_labels = self._task_instance_to_labels(task_instance)
+        # NOTE(review): label keys now come from module constants, so a subclass overriding the
+        # LABEL_* class attributes is no longer honored here.
+ ti_labels = _task_instance_to_labels(task_instance)
if try_number is not None:
- ti_labels[self.LABEL_TRY_NUMBER] = str(try_number)
+ ti_labels[LABEL_TRY_NUMBER] = str(try_number)
else:
- del ti_labels[self.LABEL_TRY_NUMBER]
+ del ti_labels[LABEL_TRY_NUMBER]
- log_filter = self._prepare_log_filter(ti_labels)
+        # Filter building and paging are delegated to the StackdriverRemoteLogIO instance.
+ log_filter = self.io.prepare_log_filter(ti_labels)
next_page_token = metadata.get("next_page_token", None)
all_pages = "download_logs" in metadata and metadata["download_logs"]
- messages, end_of_log, next_page_token = self._read_logs(log_filter,
next_page_token, all_pages)
+ messages, end_of_log, next_page_token = self.io.read_logs(log_filter,
next_page_token, all_pages)
new_metadata: dict[str, str | bool] = {"end_of_log": end_of_log}
@@ -245,102 +431,6 @@ class StackdriverTaskHandler(logging.Handler):
return [((self.task_instance_hostname, messages),)], [new_metadata]
- def _prepare_log_filter(self, ti_labels: dict[str, str]) -> str:
- """
- Prepare the filter that chooses which log entries to fetch.
-
- More information:
-
https://cloud.google.com/logging/docs/reference/v2/rest/v2/entries/list#body.request_body.FIELDS.filter
- https://cloud.google.com/logging/docs/view/advanced-queries
-
- :param ti_labels: Task Instance's labels that will be used to search
for logs
- :return: logs filter
- """
-
- def escape_label_key(key: str) -> str:
- return f'"{key}"' if "." in key else key
-
- def escale_label_value(value: str) -> str:
- escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
- return f'"{escaped_value}"'
-
- _, project = self._credentials_and_project
- log_filters = [
- f"resource.type={escale_label_value(self.resource.type)}",
- f'logName="projects/{project}/logs/{self.gcp_log_name}"',
- ]
-
- for key, value in self.resource.labels.items():
-
log_filters.append(f"resource.labels.{escape_label_key(key)}={escale_label_value(value)}")
-
- for key, value in ti_labels.items():
-
log_filters.append(f"labels.{escape_label_key(key)}={escale_label_value(value)}")
- return "\n".join(log_filters)
-
- def _read_logs(
- self, log_filter: str, next_page_token: str | None, all_pages: bool
- ) -> tuple[str, bool, str | None]:
- """
- Send requests to the Stackdriver service and downloads logs.
-
- :param log_filter: Filter specifying the logs to be downloaded.
- :param next_page_token: The token of the page from which the log
download will start.
- If None is passed, it will start from the first page.
- :param all_pages: If True is passed, all subpages will be downloaded.
Otherwise, only the first
- page will be downloaded
- :return: A token that contains the following items:
- * string with logs
- * Boolean value describing whether there are more logs,
- * token of the next page
- """
- messages = []
- new_messages, next_page_token = self._read_single_logs_page(
- log_filter=log_filter,
- page_token=next_page_token,
- )
- messages.append(new_messages)
- if all_pages:
- while next_page_token:
- new_messages, next_page_token = self._read_single_logs_page(
- log_filter=log_filter, page_token=next_page_token
- )
- messages.append(new_messages)
- if not messages:
- break
-
- end_of_log = True
- next_page_token = None
- else:
- end_of_log = not bool(next_page_token)
- return "\n".join(messages), end_of_log, next_page_token
-
- def _read_single_logs_page(self, log_filter: str, page_token: str | None =
None) -> tuple[str, str]:
- """
- Send requests to the Stackdriver service and downloads single pages
with logs.
-
- :param log_filter: Filter specifying the logs to be downloaded.
- :param page_token: The token of the page to be downloaded. If None is
passed, the first page will be
- downloaded.
- :return: Downloaded logs and next page token
- """
- _, project = self._credentials_and_project
- request = ListLogEntriesRequest(
- resource_names=[f"projects/{project}"],
- filter=log_filter,
- page_token=page_token,
- order_by="timestamp asc",
- page_size=1000,
- )
- response =
self._logging_service_client.list_log_entries(request=request)
- page: ListLogEntriesResponse = next(response.pages)
- messages: list[str] = []
- for entry in page.entries:
- if "message" in (entry.json_payload or {}):
- messages.append(entry.json_payload["message"]) # type: ignore
- elif entry.text_payload:
- messages.append(entry.text_payload)
- return "\n".join(messages), page.next_page_token
-
+    # NOTE(review): _prepare_log_filter/_read_logs/_read_single_logs_page were removed from the
+    # handler; they now live on StackdriverRemoteLogIO (prepare_log_filter/read_logs/
+    # _read_single_logs_page). Subclasses or tests calling them on the handler will break -
+    # consider thin delegating shims.
@classmethod
def _task_instance_to_labels(cls, ti: TaskInstance) -> dict[str, str]:
return {
@@ -375,12 +465,12 @@ class StackdriverTaskHandler(logging.Handler):
:param try_number: task instance try_number to read logs from
:return: URL to the external log collection service
"""
- _, project_id = self._credentials_and_project
+        # Credentials/project and the filter builder now come from the shared IO object.
+ _, project_id = self.io.credentials_and_project
- ti_labels = self._task_instance_to_labels(task_instance)
- ti_labels[self.LABEL_TRY_NUMBER] = str(try_number)
+ ti_labels = _task_instance_to_labels(task_instance)
+ ti_labels[LABEL_TRY_NUMBER] = str(try_number)
- log_filter = self._prepare_log_filter(ti_labels)
+ log_filter = self.io.prepare_log_filter(ti_labels)
url_query_string = {
"project": project_id,
@@ -393,4 +483,4 @@ class StackdriverTaskHandler(logging.Handler):
return url
def close(self) -> None:
- self._transport.flush()
+        # Flush the transport owned by the shared StackdriverRemoteLogIO instance.
+ self.io.transport.flush()
diff --git a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
index 66c2ebe29fc..a55b62fa4f4 100644
--- a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
@@ -18,14 +18,19 @@ from __future__ import annotations
import logging
from contextlib import nullcontext
+from pathlib import Path
from unittest import mock
+from unittest.mock import PropertyMock
from urllib.parse import parse_qs, urlsplit
import pytest
from google.cloud.logging import Resource
from google.cloud.logging_v2.types import ListLogEntriesRequest,
ListLogEntriesResponse, LogEntry
-from airflow.providers.google.cloud.log.stackdriver_task_handler import
StackdriverTaskHandler
+from airflow.providers.google.cloud.log.stackdriver_task_handler import (
+ StackdriverRemoteLogIO,
+ StackdriverTaskHandler,
+)
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState
@@ -50,6 +55,232 @@ def clean_stackdriver_handlers():
del handler
+class TestStackdriverRemoteLogIO:
+    """Unit tests for the structlog-era StackdriverRemoteLogIO."""
+
+    @pytest.fixture(autouse=True)
+    def _setup(self, tmp_path):
+        self.local_log_location = str(tmp_path / "local/stackdriver/logs")
+        self.io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_key_path="KEY_PATH",
+            gcp_log_name="airflow",
+            delete_local_copy=True,
+        )
+
+    @staticmethod
+    def _make_ti():
+        """Build a task-instance mock with the attributes used to derive Stackdriver labels."""
+        ti = mock.MagicMock()
+        ti.task_id = "test_task"
+        ti.dag_id = "test_dag"
+        ti.try_number = 1
+        if AIRFLOW_V_3_0_PLUS:
+            ti.logical_date = timezone.datetime(2016, 1, 1)
+        else:
+            ti.execution_date = timezone.datetime(2016, 1, 1)
+        return ti
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client")
+    def test_read_logs(self, mock_client, mock_get_creds_and_project_id):
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)]
+        )
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        messages, logs = self.io.read(
+            "dag_id=test_dag/run_id=run1/task_id=test_task/attempt=1.log", self._make_ti()
+        )
+
+        assert len(messages) == 1
+        assert "Stackdriver" in messages[0]
+        assert logs == ["MSG1\nMSG2"]
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client")
+    def test_read_logs_empty(self, mock_client, mock_get_creds_and_project_id):
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock([], None)]
+        )
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        messages, logs = self.io.read("test/path", self._make_ti())
+
+        assert len(messages) == 1
+        assert logs == []
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_credentials(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        _ = self.io._client
+
+        mock_get_creds_and_project_id.assert_called_once_with(
+            disable_logging=True,
+            key_path="KEY_PATH",
+            scopes=frozenset(
+                {
+                    "https://www.googleapis.com/auth/logging.write",
+                    "https://www.googleapis.com/auth/logging.read",
+                }
+            ),
+        )
+        mock_client.assert_called_once_with(credentials="creds", client_info=mock.ANY, project="project_id")
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_transport_init(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        transport_type = mock.MagicMock()
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="test-log",
+            transport_type=transport_type,
+        )
+        _ = io.transport
+        transport_type.assert_called_once_with(mock_client.return_value, "test-log")
+
+    @mock.patch("shutil.rmtree")
+    @mock.patch(
+        "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverRemoteLogIO.transport",
+        new_callable=PropertyMock,
+    )
+    def test_upload_flushes_transport_and_deletes_local(self, mock_transport_prop, mock_rmtree):
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            delete_local_copy=True,
+        )
+        mock_transport = mock.MagicMock()
+        mock_transport_prop.return_value = mock_transport
+
+        base = Path(self.local_log_location)
+        base.mkdir(parents=True, exist_ok=True)
+        log_dir = base / "subdir"
+        log_dir.mkdir(parents=True, exist_ok=True)
+        log_file = log_dir / "test.log"
+        log_file.write_text("log content")
+
+        io.upload(str(log_file), mock.MagicMock())
+
+        mock_transport.flush.assert_called_once()
+        mock_rmtree.assert_called_once_with(log_dir.resolve(), ignore_errors=True)
+
+    @mock.patch("shutil.rmtree")
+    @mock.patch(
+        "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverRemoteLogIO.transport",
+        new_callable=PropertyMock,
+    )
+    def test_upload_no_delete(self, mock_transport_prop, mock_rmtree):
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            delete_local_copy=False,
+        )
+        mock_transport = mock.MagicMock()
+        mock_transport_prop.return_value = mock_transport
+
+        io.upload("some/path.log", mock.MagicMock())
+
+        mock_transport.flush.assert_called_once()
+        # With delete_local_copy=False nothing on disk may be removed.
+        mock_rmtree.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    def test_prepare_log_filter(self, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        ti_labels = {
+            "task_id": "test_task",
+            "dag_id": "test_dag",
+            "try_number": "1",
+        }
+        log_filter = self.io.prepare_log_filter(ti_labels)
+
+        assert 'resource.type="global"' in log_filter
+        assert 'logName="projects/project_id/logs/airflow"' in log_filter
+        assert 'labels.task_id="test_task"' in log_filter
+        assert 'labels.dag_id="test_dag"' in log_filter
+
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    def test_prepare_log_filter_with_custom_resource(self, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            resource=Resource(
+                type="cloud_composer_environment",
+                labels={
+                    "environment.name": "test-instance",
+                    "location": "europe-west-3",
+                },
+            ),
+        )
+        log_filter = io.prepare_log_filter({"task_id": "test"})
+
+        assert 'resource.type="cloud_composer_environment"' in log_filter
+        assert 'resource.labels."environment.name"="test-instance"' in log_filter
+        assert 'resource.labels.location="europe-west-3"' in log_filter
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="airflow.sdk.log only exists in Airflow 3+")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_processors_sends_to_transport(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        mock_transport_type = mock.MagicMock()
+        with mock.patch("airflow.sdk.log.relative_path_from_logger", return_value="dag/task/1.log"):
+            io = StackdriverRemoteLogIO(
+                base_log_folder=self.local_log_location,
+                gcp_log_name="airflow",
+                labels={"env": "test"},
+                transport_type=mock_transport_type,
+            )
+            processors = io.processors
+        assert len(processors) == 1
+
+        proc = processors[0]
+        event = {
+            "event": "hello world",
+            "logger_name": "airflow.task",
+            "timestamp": "2026-01-15T10:30:00+00:00",
+        }
+        result = proc(mock.MagicMock(), "info", event)
+
+        assert result is event
+        mock_transport = mock_transport_type.return_value
+        mock_transport.send.assert_called_once()
+        record = mock_transport.send.call_args[0][0]
+        assert record.levelno == logging.INFO
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="airflow.sdk.log only exists in Airflow 3+")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_processors_skips_non_task_logger(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        mock_transport_type = mock.MagicMock()
+        with mock.patch("airflow.sdk.log.relative_path_from_logger", return_value=None):
+            io = StackdriverRemoteLogIO(
+                base_log_folder=self.local_log_location,
+                gcp_log_name="airflow",
+                transport_type=mock_transport_type,
+            )
+            proc = io.processors[0]
+
+        event = {"event": "should not be sent"}
+        result = proc(mock.MagicMock(), "info", event)
+
+        assert result is event
+        mock_transport_type.return_value.send.assert_not_called()
+
+
@pytest.mark.usefixtures("clean_stackdriver_handlers")
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
@@ -77,7 +308,6 @@ def test_should_pass_message_to_client(mock_client,
mock_get_creds_and_project_i
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
def test_should_use_configured_log_name(mock_client,
mock_get_creds_and_project_id):
import importlib
- import logging
from airflow import settings
from airflow.config_templates import airflow_local_settings
@@ -85,9 +315,6 @@ def test_should_use_configured_log_name(mock_client,
mock_get_creds_and_project_
mock_get_creds_and_project_id.return_value = ("creds", "project_id")
try:
- # this is needed for Airflow 2.8 and below where default settings are
triggering warning on
- # extra "name" in the configuration of stackdriver handler. As of
Airflow 2.9 this warning is not
- # emitted.
context_manager = nullcontext()
with context_manager:
with conf_vars(
@@ -99,12 +326,17 @@ def test_should_use_configured_log_name(mock_client,
mock_get_creds_and_project_
importlib.reload(airflow_local_settings)
settings.configure_logging()
+ task_log = getattr(airflow_local_settings, "REMOTE_TASK_LOG",
None)
+ if task_log is not None:
+ # Airflow 3+ uses REMOTE_TASK_LOG instead of handler-based
config
+ assert isinstance(task_log, StackdriverRemoteLogIO)
+ assert task_log.gcp_log_name == "path"
+ return
+
+ # Older Airflow: stackdriver is wired as a logging handler
logger = logging.getLogger("airflow.task")
handler = logger.handlers[0]
assert isinstance(handler, StackdriverTaskHandler)
- with mock.patch.object(handler, "transport_type") as
transport_type_mock:
- logger.error("foo")
-
transport_type_mock.assert_called_once_with(mock_client.return_value, "path")
+                # The handler forwards its configured log name to the underlying IO object;
+                # this replaces the removed transport_type assertion.
+                assert handler.io.gcp_log_name == "path"
finally:
importlib.reload(airflow_local_settings)
settings.configure_logging()
@@ -398,7 +630,7 @@ class TestStackdriverLoggingHandlerTask:
mock_get_creds_and_project_id.return_value = ("creds", "project_id")
stackdriver_task_handler =
StackdriverTaskHandler(gcp_key_path="KEY_PATH")
- client = stackdriver_task_handler._client
+ client = stackdriver_task_handler.io._client
mock_get_creds_and_project_id.assert_called_once_with(
disable_logging=True,