This is an automated email from the ASF dual-hosted git repository. dstandish pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new b4c50dadd3 GCSTaskHandler may use remote log conn id (#29117) b4c50dadd3 is described below commit b4c50dadd36d66e4d222c627a61771653767afd6 Author: Daniel Standish <15932138+dstand...@users.noreply.github.com> AuthorDate: Tue Jan 24 15:25:04 2023 -0800 GCSTaskHandler may use remote log conn id (#29117) --- .../providers/google/cloud/log/gcs_task_handler.py | 30 +++++++++++++++++----- .../google/cloud/log/test_gcs_task_handler.py | 25 +++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py index 5fbba80798..a264821093 100644 --- a/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -24,6 +24,9 @@ from typing import Collection from google.cloud import storage # type: ignore[attr-defined] from airflow.compat.functools import cached_property +from airflow.configuration import conf +from airflow.exceptions import AirflowNotFoundException +from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id from airflow.providers.google.common.consts import CLIENT_INFO from airflow.utils.log.file_task_handler import FileTaskHandler @@ -72,7 +75,6 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin): super().__init__(base_log_folder, filename_template) self.remote_base = gcs_log_folder self.log_relative_path = "" - self._hook = None self.closed = False self.upload_on_close = True self.gcp_key_path = gcp_key_path @@ -80,15 +82,29 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin): self.scopes = gcp_scopes self.project_id = project_id + @cached_property + def hook(self) -> GCSHook | None: + """Returns GCSHook if remote_log_conn_id configured.""" + conn_id = conf.get("logging", "remote_log_conn_id", fallback=None) + if conn_id: + try: + return GCSHook(gcp_conn_id=conn_id) + except AirflowNotFoundException: + pass + return None + @cached_property def client(self) -> storage.Client: """Returns GCS Client.""" - credentials, project_id = get_credentials_and_project_id( - key_path=self.gcp_key_path, - keyfile_dict=self.gcp_keyfile_dict, - scopes=self.scopes, - disable_logging=True, - ) + if self.hook: + credentials, project_id = self.hook.get_credentials_and_project_id() + else: + credentials, project_id = get_credentials_and_project_id( + key_path=self.gcp_key_path, + keyfile_dict=self.gcp_keyfile_dict, + scopes=self.scopes, + disable_logging=True, + ) return storage.Client( credentials=credentials, client_info=CLIENT_INFO, diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py index 049a627336..b801d1fc05 100644 --- a/tests/providers/google/cloud/log/test_gcs_task_handler.py +++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py @@ -21,6 +21,7 @@ import tempfile from unittest import mock import pytest +from pytest import param from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler from airflow.utils.state import TaskInstanceState @@ -59,15 +60,25 @@ class TestGCSTaskHandler: ) yield self.gcs_task_handler - @mock.patch( - "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id", - return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"), - ) + @mock.patch("airflow.providers.google.cloud.log.gcs_task_handler.GCSHook") @mock.patch("google.cloud.storage.Client") - def test_hook(self, mock_client, mock_creds): - return_value = self.gcs_task_handler.client + @mock.patch("airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id") + @pytest.mark.parametrize("conn_id", [param("", id="no-conn"), param("my_gcs_conn", id="with-conn")]) + def test_client_conn_id_behavior(self, mock_get_cred, mock_client, mock_hook, conn_id): + """When remote log conn id configured, hook will be used""" + mock_hook.return_value.get_credentials_and_project_id.return_value = ("test_cred", "test_proj") + mock_get_cred.return_value = ("test_cred", "test_proj") + with conf_vars({("logging", "remote_log_conn_id"): conn_id}): + return_value = self.gcs_task_handler.client + if conn_id: + mock_hook.assert_called_once_with(gcp_conn_id="my_gcs_conn") + mock_get_cred.assert_not_called() + else: + mock_hook.assert_not_called() + mock_get_cred.assert_called() + mock_client.assert_called_once_with( - client_info=mock.ANY, credentials="TEST_CREDENTIALS", project="TEST_PROJECT_ID" + client_info=mock.ANY, credentials="test_cred", project="test_proj" ) assert mock_client.return_value == return_value