This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new d3dc88f084 Avoid to use `functools.lru_cache` in class methods in `google` provider (#38652) d3dc88f084 is described below commit d3dc88f0844bcb377a9e52312e1a99b5ca6e617e Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Mon Apr 1 18:04:30 2024 +0400 Avoid to use `functools.lru_cache` in class methods in `google` provider (#38652) --- .../providers/google/cloud/hooks/compute_ssh.py | 2 +- .../providers/google/common/hooks/base_google.py | 2 +- .../google/cloud/hooks/test_compute_ssh.py | 56 +++++++++++++--------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/compute_ssh.py b/airflow/providers/google/cloud/hooks/compute_ssh.py index 2bc5dcf514..97df5c5525 100644 --- a/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -334,7 +334,7 @@ class ComputeEngineSSHHook(SSHHook): ) def _authorize_os_login(self, pubkey): - username = self._oslogin_hook._get_credentials_email() + username = self._oslogin_hook._get_credentials_email self.log.info("Importing SSH public key using OSLogin: user=%s", username) expiration = int((time.time() + self.expire_time) * 1000000) ssh_public_key = {"key": pubkey, "expiration_time_usec": expiration} diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 13543243cb..ca08f86e78 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -317,7 +317,7 @@ class GoogleBaseHook(BaseHook): credentials.refresh(auth_req) return credentials.token - @functools.lru_cache(maxsize=None) + @functools.cached_property def _get_credentials_email(self) -> str: """ Return the email address associated with the currently logged in account. diff --git a/tests/providers/google/cloud/hooks/test_compute_ssh.py b/tests/providers/google/cloud/hooks/test_compute_ssh.py index 27cfe4fc1b..dfcd0d719c 100644 --- a/tests/providers/google/cloud/hooks/test_compute_ssh.py +++ b/tests/providers/google/cloud/hooks/test_compute_ssh.py @@ -28,6 +28,7 @@ from paramiko.ssh_exception import SSHException from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook +from airflow.providers.google.cloud.hooks.os_login import OSLoginHook pytestmark = pytest.mark.db_test @@ -48,22 +49,35 @@ class TestComputeEngineHookWithPassedProjectId: with pytest.raises(RuntimeError): ComputeEngineSSHHook(gcp_conn_id="gcpssh", delegate_to="delegate_to") + def test_os_login_hook(self, mocker): + mock_os_login_hook = mocker.patch.object(OSLoginHook, "__init__", return_value=None, spec=OSLoginHook) + + # Default values + assert ComputeEngineSSHHook()._oslogin_hook + mock_os_login_hook.assert_called_with(gcp_conn_id="google_cloud_default") + + # Custom conn_id + assert ComputeEngineSSHHook(gcp_conn_id="gcpssh")._oslogin_hook + mock_os_login_hook.assert_called_with(gcp_conn_id="gcpssh") + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") - @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") - def test_get_conn_default_configuration( - self, mock_ssh_client, mock_paramiko, mock_os_login_hook, mock_compute_hook - ): - mock_paramiko.SSHException = Exception + def test_get_conn_default_configuration(self, mock_ssh_client, mock_paramiko, mock_compute_hook, mocker): + mock_paramiko.SSHException = RuntimeError mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" mock_compute_hook.return_value.project_id = TEST_PROJECT_ID mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP - mock_os_login_hook.return_value._get_credentials_email.return_value = "test-exam...@example.org" - mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [ + mock_os_login_hook = mocker.patch.object( + ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook" + ) + type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock( + return_value="test-exam...@example.org" + ) + mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [ mock.MagicMock(username="test-username") ] @@ -83,16 +97,10 @@ class TestComputeEngineHookWithPassedProjectId: ), ] ) - mock_os_login_hook.assert_has_calls( - [ - mock.call(gcp_conn_id="google_cloud_default"), - mock.call()._get_credentials_email(), - mock.call().import_ssh_public_key( - ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY}, - project_id="test-project-id", - user=mock_os_login_hook.return_value._get_credentials_email.return_value, - ), - ] + mock_os_login_hook.import_ssh_public_key.assert_called_once_with( + ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY}, + project_id="test-project-id", + user="test-exam...@example.org", ) mock_ssh_client.assert_has_calls( [ @@ -113,7 +121,6 @@ class TestComputeEngineHookWithPassedProjectId: [(SSHException, r"Error occurred when establishing SSH connection using Paramiko")], ) @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") - @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance") @@ -122,21 +129,26 @@ class TestComputeEngineHookWithPassedProjectId: mock_connect, mock_ssh_client, mock_paramiko, - mock_os_login_hook, mock_compute_hook, exception_type, error_message, caplog, + mocker, ): - mock_paramiko.SSHException = Exception + mock_paramiko.SSHException = RuntimeError mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" mock_compute_hook.return_value.project_id = TEST_PROJECT_ID mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP - mock_os_login_hook.return_value._get_credentials_email.return_value = "test-exam...@example.org" - mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [ + mock_os_login_hook = mocker.patch.object( + ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook" + ) + type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock( + return_value="test-exam...@example.org" + ) + mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [ mock.MagicMock(username="test-username") ]