This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 103df61bde6 Fix the way to get STS endpoint in EKS hook (#45520) 103df61bde6 is described below commit 103df61bde6e98de2466f42e76b4ac3bcc4ab8b5 Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Thu Jan 9 13:20:12 2025 -0500 Fix the way to get STS endpoint in EKS hook (#45520) --- providers/src/airflow/providers/amazon/aws/hooks/eks.py | 4 ++-- providers/tests/amazon/aws/hooks/test_eks.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/eks.py b/providers/src/airflow/providers/amazon/aws/hooks/eks.py index d48c103505a..4e8c0ca7ad6 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/eks.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/eks.py @@ -32,6 +32,7 @@ from botocore.exceptions import ClientError from botocore.signers import RequestSigner from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.sts import StsHook from airflow.utils import yaml from airflow.utils.json import AirflowJsonEncoder @@ -612,8 +613,7 @@ class EksHook(AwsBaseHook): def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str: session = self.get_session() service_id = self.conn.meta.service_model.service_id - sts_client = session.client("sts") - sts_url = f"{sts_client.meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15" + sts_url = f"{StsHook().conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15" signer = RequestSigner( service_id=service_id, diff --git a/providers/tests/amazon/aws/hooks/test_eks.py b/providers/tests/amazon/aws/hooks/test_eks.py index 06cc7ddab53..10a93790ac6 100644 --- a/providers/tests/amazon/aws/hooks/test_eks.py +++ b/providers/tests/amazon/aws/hooks/test_eks.py @@ -22,7 +22,6 @@ from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING from unittest import mock -from unittest.mock import Mock from urllib.parse import urlsplit import pytest @@ -1284,14 +1283,13 @@ class TestEksHook: } @mock.patch("airflow.providers.amazon.aws.hooks.eks.RequestSigner") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.StsHook") @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn") @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_session") - def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn, mock_signer): + def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn, mock_sts_hook, mock_signer): mock_signer.return_value.generate_presigned_url.return_value = "http://example.com" mock_get_session.return_value.region_name = "us-east-1" - client = Mock() - client.meta.endpoint_url = "https://sts.us-east-1.amazonaws.com" - mock_get_session.return_value.client.return_value = client + mock_sts_hook.return_value.conn_client_meta.endpoint_url = "https://sts.us-east-1.amazonaws.com" hook = EksHook() token = hook.fetch_access_token_for_cluster(eks_cluster_name="test-cluster") mock_signer.assert_called_once_with(