This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f47c3aa86b Adapt GCP CloudSQL trigger to run in private cloud (#66917)
4f47c3aa86b is described below

commit 4f47c3aa86b37587dbc97e862ee561bef6cef4ce
Author: olegkachur-e <[email protected]>
AuthorDate: Fri May 15 17:17:16 2026 +0000

    Adapt GCP CloudSQL trigger to run in private cloud (#66917)
---
 .../providers/google/cloud/hooks/cloud_sql.py      | 22 ++++++--
 .../providers/google/cloud/operators/cloud_sql.py  |  1 +
 .../providers/google/cloud/triggers/cloud_sql.py   | 19 ++++++-
 .../unit/google/cloud/hooks/test_cloud_sql.py      | 50 +++++++++++++++++
 .../unit/google/cloud/triggers/test_cloud_sql.py   | 64 ++++++++++++++++++----
 5 files changed, 136 insertions(+), 20 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
index eda0e341615..51868610223 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -420,6 +420,15 @@ class CloudSQLHook(GoogleBaseHook):
         self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
         return response
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_operation(self, project_id: str, operation_name: str) -> dict:
+        return (
+            self.get_conn()
+            .operations()
+            .get(project=project_id, operation=operation_name)
+            .execute(num_retries=self.num_retries)
+        )
+
     @GoogleBaseHook.fallback_to_default_project_id
     def _wait_for_operation_to_complete(
         self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
@@ -432,13 +441,8 @@ class CloudSQLHook(GoogleBaseHook):
         :param time_to_sleep: Time to sleep between active checks of the operation results.
         :return: None
         """
-        service = self.get_conn()
         while True:
-            operation_response = (
-                service.operations()
-                .get(project=project_id, operation=operation_name)
-                .execute(num_retries=self.num_retries)
-            )
+            operation_response = self.get_operation(project_id=project_id, operation_name=operation_name)
             if operation_response.get("status") == CloudSqlOperationStatus.DONE:
                 error = operation_response.get("error")
                 if error:
@@ -474,6 +478,12 @@ class CloudSQLAsyncHook(GoogleBaseAsyncHook):
             }
         return await session_aio.get(url=url, headers=headers)
 
+    async def get_sync_hook(self, api_version: str = "v1beta4"):
+        if not self._sync_hook:
+            self._hook_kwargs["api_version"] = api_version
+            return await super().get_sync_hook()
+        return self._sync_hook
+
     async def get_operation_name(self, project_id: str, operation_name: str, session):
         url = f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}"
         return await self._get_conn(url=str(url), session=session)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
index b7f4de4fe9a..fb9de5b82bf 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -1023,6 +1023,7 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
                 poke_interval=self.poke_interval,
+                api_version=self.api_version,
             ),
             method_name="execute_complete",
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
index e7381c0463f..dd8a482d92d 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -22,6 +22,8 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Sequence
 
+from asgiref.sync import sync_to_async
+
 from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook, CloudSqlOperationStatus
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -41,6 +43,7 @@ class CloudSQLExportTrigger(BaseTrigger):
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poke_interval: int = 20,
+        api_version: str = "v1beta4",
     ):
         super().__init__()
         self.gcp_conn_id = gcp_conn_id
@@ -48,6 +51,7 @@ class CloudSQLExportTrigger(BaseTrigger):
         self.operation_name = operation_name
         self.project_id = project_id
         self.poke_interval = poke_interval
+        self.api_version = api_version
         self.hook = CloudSQLAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -62,15 +66,24 @@ class CloudSQLExportTrigger(BaseTrigger):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "poke_interval": self.poke_interval,
+                "api_version": self.api_version,
             },
         )
 
     async def run(self):
         try:
+            sync_hook = await self.hook.get_sync_hook(api_version=self.api_version)
+            operation_kwargs = {
+                "project_id": self.project_id,
+                "operation_name": self.operation_name,
+            }
+
             while True:
-                operation = await self.hook.get_operation(
-                    project_id=self.project_id, operation_name=self.operation_name
-                )
+                if sync_hook.is_default_universe():
+                    operation = await self.hook.get_operation(**operation_kwargs)
+                else:
+                    operation = await sync_to_async(sync_hook.get_operation)(**operation_kwargs)
+
                 if operation["status"] == CloudSqlOperationStatus.DONE:
                     if "error" in operation:
                         yield TriggerEvent(
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
index ccaa3c010f2..b2a82c2a781 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
@@ -527,6 +527,42 @@ class TestGcpSqlHookDefaultProjectId:
             operation_name="operation_id", project_id="example-project"
         )
 
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn")
+    def test_get_operation(self, mock_get_conn):
+        operations_method = mock_get_conn.return_value.operations
+        get_method = operations_method.return_value.get
+        execute_method = get_method.return_value.execute
+
+        mock_response = {"name": "operation_id", "status": "DONE"}
+        execute_method.return_value = mock_response
+
+        result = self.cloudsql_hook.get_operation(project_id="gcp-project", operation_name="operation_id")
+
+        assert result == mock_response
+        operations_method.assert_called_once()
+        get_method.assert_called_once_with(project="gcp-project", operation="operation_id")
+        execute_method.assert_called_once_with(num_retries=self.cloudsql_hook.num_retries)
+
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.build")
+    def test_get_conn_obj_caching(self, mock_build):
+        self.cloudsql_hook._authorize = mock.MagicMock()
+        self.cloudsql_hook.get_client_options = mock.MagicMock()
+        mock_service_obj = mock.MagicMock()
+        mock_build.return_value = mock_service_obj
+
+        conn1 = self.cloudsql_hook.get_conn()
+        conn2 = self.cloudsql_hook.get_conn()
+
+        assert conn1 is conn2
+        assert mock_build.call_count == 1
+        mock_build.assert_called_once_with(
+            "sqladmin",
+            self.cloudsql_hook.api_version,
+            http=self.cloudsql_hook._authorize.return_value,
+            cache_discovery=False,
+            client_options=self.cloudsql_hook.get_client_options.return_value,
+        )
+
 
 class TestGcpSqlHookNoDefaultProjectID:
     def setup_method(self):
@@ -1968,3 +2004,17 @@ class TestCloudSQLAsyncHook:
         )
         with pytest.raises(HttpError):
             await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
+
+    @pytest.mark.asyncio
+    async def test_get_sync_hook_override_sets_custom_api_version(self):
+        hook = CloudSQLAsyncHook(gcp_conn_id="test_conn")
+        with mock.patch(
+            "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.get_sync_hook",
+            new_callable=mock.AsyncMock,
+        ) as mock_super_get_sync:
+            mock_super_get_sync.return_value = mock.MagicMock()
+
+            await hook.get_sync_hook(api_version="api_v42")
+
+            assert hook._hook_kwargs["api_version"] == "api_v42"
+            mock_super_get_sync.assert_called_once()
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
index c7cbb2046ce..8787a5cd068 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
@@ -18,10 +18,11 @@ from __future__ import annotations
 
 import asyncio
 import logging
-from unittest import mock as async_mock
+from unittest import mock
 
 import pytest
 
+from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
 from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
 from airflow.triggers.base import TriggerEvent
 
@@ -35,6 +36,7 @@ OPERATION_NAME = "test_operation_name"
 OPERATION_URL = (
     f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
 )
+API_VERSION = "v1test"
 
 
 @pytest.fixture
@@ -45,11 +47,22 @@ def trigger():
         impersonation_chain=None,
         gcp_conn_id=TEST_GCP_CONN_ID,
         poke_interval=TEST_POLL_INTERVAL,
+        api_version=API_VERSION,
     )
 
 
+@pytest.fixture
+def sync_hook_mock():
+    mock_obj = mock.MagicMock(spec=CloudSQLHook)
+    with mock.patch(
+        HOOK_STR.format("CloudSQLAsyncHook.get_sync_hook"), new_callable=mock.AsyncMock
+    ) as patched_get_sync_hook:
+        patched_get_sync_hook.return_value = mock_obj
+        yield mock_obj
+
+
 class TestCloudSQLExportTrigger:
-    def test_async_export_trigger_serialization_should_execute_successfully(self, trigger):
+    def test_async_export_trigger_serialization_should_execute_successfully(self, trigger, sync_hook_mock):
         """
         Asserts that the CloudSQLExportTrigger correctly serializes its arguments
         and classpath.
@@ -62,12 +75,13 @@ class TestCloudSQLExportTrigger:
             "impersonation_chain": None,
             "gcp_conn_id": TEST_GCP_CONN_ID,
             "poke_interval": TEST_POLL_INTERVAL,
+            "api_version": API_VERSION,
         }
 
     @pytest.mark.asyncio
-    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
     async def test_async_export_trigger_on_success_should_execute_successfully(
-        self, mock_get_operation, trigger
+        self, mock_get_operation, trigger, sync_hook_mock
     ):
         """
         Tests the CloudSQLExportTrigger only fires once the job execution reaches a successful state.
@@ -89,14 +103,15 @@ class TestCloudSQLExportTrigger:
         )
 
     @pytest.mark.asyncio
-    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
     async def test_async_export_trigger_running_should_execute_successfully(
-        self, mock_get_operation, trigger, caplog
+        self, mock_get_operation, trigger, sync_hook_mock, caplog
     ):
         """
         Test that CloudSQLExportTrigger does not fire while a job is still running.
         """
-
+        # Ensure execution for default universe
+        sync_hook_mock.is_default_universe.return_value = True
         mock_get_operation.return_value = {
             "status": "RUNNING",
             "name": OPERATION_NAME,
@@ -114,8 +129,10 @@ class TestCloudSQLExportTrigger:
         asyncio.get_event_loop().stop()
 
     @pytest.mark.asyncio
-    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
-    async def test_async_export_trigger_error_should_execute_successfully(self, mock_get_operation, trigger):
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_error_should_execute_successfully(
+        self, mock_get_operation, trigger, sync_hook_mock
+    ):
         """
         Test that CloudSQLExportTrigger fires the correct event in case of an error.
         """
@@ -136,9 +153,9 @@ class TestCloudSQLExportTrigger:
         assert TriggerEvent(expected_event) == actual
 
     @pytest.mark.asyncio
-    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
     async def test_async_export_trigger_exception_should_execute_successfully(
-        self, mock_get_operation, trigger
+        self, mock_get_operation, trigger, sync_hook_mock
     ):
         """
         Test that CloudSQLExportTrigger fires the correct event in case of an error.
@@ -148,3 +165,28 @@ class TestCloudSQLExportTrigger:
         generator = trigger.run()
         actual = await generator.asend(None)
         assert TriggerEvent({"status": "failed", "message": "Test exception"}) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_executes_successfully_in_custom_universe(
+        self, mock_async_get_op, trigger, sync_hook_mock
+    ):
+        """
+        Test non-default universe trigger correct execution.
+        """
+        sync_hook_mock.is_default_universe.return_value = False
+        sync_hook_mock.get_operation.return_value = {
+            "status": "RUNNING",
+            "name": OPERATION_NAME,
+        }
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.1)
+
+        sync_hook_mock.is_default_universe.assert_called_once()
+        sync_hook_mock.get_operation.assert_called_once_with(
+            project_id=trigger.project_id, operation_name=trigger.operation_name
+        )
+        # Verify the default universe branch not being called
+        mock_async_get_op.assert_not_called()
+        task.cancel()

Reply via email to