This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4f47c3aa86b Adapt GCP CloudSQL trigger to run in private cloud (#66917)
4f47c3aa86b is described below
commit 4f47c3aa86b37587dbc97e862ee561bef6cef4ce
Author: olegkachur-e <[email protected]>
AuthorDate: Fri May 15 17:17:16 2026 +0000
Adapt GCP CloudSQL trigger to run in private cloud (#66917)
---
.../providers/google/cloud/hooks/cloud_sql.py | 22 ++++++--
.../providers/google/cloud/operators/cloud_sql.py | 1 +
.../providers/google/cloud/triggers/cloud_sql.py | 19 ++++++-
.../unit/google/cloud/hooks/test_cloud_sql.py | 50 +++++++++++++++++
.../unit/google/cloud/triggers/test_cloud_sql.py | 64 ++++++++++++++++++----
5 files changed, 136 insertions(+), 20 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
index eda0e341615..51868610223 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -420,6 +420,15 @@ class CloudSQLHook(GoogleBaseHook):
self._wait_for_operation_to_complete(project_id=project_id,
operation_name=operation_name)
return response
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_operation(self, project_id: str, operation_name: str) -> dict:
+ return (
+ self.get_conn()
+ .operations()
+ .get(project=project_id, operation=operation_name)
+ .execute(num_retries=self.num_retries)
+ )
+
@GoogleBaseHook.fallback_to_default_project_id
def _wait_for_operation_to_complete(
self, project_id: str, operation_name: str, time_to_sleep: int =
TIME_TO_SLEEP_IN_SECONDS
@@ -432,13 +441,8 @@ class CloudSQLHook(GoogleBaseHook):
:param time_to_sleep: Time to sleep between active checks of the
operation results.
:return: None
"""
- service = self.get_conn()
while True:
- operation_response = (
- service.operations()
- .get(project=project_id, operation=operation_name)
- .execute(num_retries=self.num_retries)
- )
+ operation_response = self.get_operation(project_id=project_id,
operation_name=operation_name)
if operation_response.get("status") ==
CloudSqlOperationStatus.DONE:
error = operation_response.get("error")
if error:
@@ -474,6 +478,12 @@ class CloudSQLAsyncHook(GoogleBaseAsyncHook):
}
return await session_aio.get(url=url, headers=headers)
+ async def get_sync_hook(self, api_version: str = "v1beta4"):
+ if not self._sync_hook:
+ self._hook_kwargs["api_version"] = api_version
+ return await super().get_sync_hook()
+ return self._sync_hook
+
async def get_operation_name(self, project_id: str, operation_name: str,
session):
url =
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}"
return await self._get_conn(url=str(url), session=session)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
index b7f4de4fe9a..fb9de5b82bf 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -1023,6 +1023,7 @@ class
CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
poke_interval=self.poke_interval,
+ api_version=self.api_version,
),
method_name="execute_complete",
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
index e7381c0463f..dd8a482d92d 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -22,6 +22,8 @@ from __future__ import annotations
import asyncio
from collections.abc import Sequence
+from asgiref.sync import sync_to_async
+
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook,
CloudSqlOperationStatus
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -41,6 +43,7 @@ class CloudSQLExportTrigger(BaseTrigger):
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poke_interval: int = 20,
+ api_version: str = "v1beta4",
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
@@ -48,6 +51,7 @@ class CloudSQLExportTrigger(BaseTrigger):
self.operation_name = operation_name
self.project_id = project_id
self.poke_interval = poke_interval
+ self.api_version = api_version
self.hook = CloudSQLAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
@@ -62,15 +66,24 @@ class CloudSQLExportTrigger(BaseTrigger):
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poke_interval": self.poke_interval,
+ "api_version": self.api_version,
},
)
async def run(self):
try:
+ sync_hook = await self.hook.get_sync_hook(api_version=self.api_version)
+ operation_kwargs = {
+ "project_id": self.project_id,
+ "operation_name": self.operation_name,
+ }
+
while True:
- operation = await self.hook.get_operation(
- project_id=self.project_id, operation_name=self.operation_name
- )
+ if sync_hook.is_default_universe():
+ operation = await self.hook.get_operation(**operation_kwargs)
+ else:
+ operation = await sync_to_async(sync_hook.get_operation)(**operation_kwargs)
+
if operation["status"] == CloudSqlOperationStatus.DONE:
if "error" in operation:
yield TriggerEvent(
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
index ccaa3c010f2..b2a82c2a781 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
@@ -527,6 +527,42 @@ class TestGcpSqlHookDefaultProjectId:
operation_name="operation_id", project_id="example-project"
)
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn")
+ def test_get_operation(self, mock_get_conn):
+ operations_method = mock_get_conn.return_value.operations
+ get_method = operations_method.return_value.get
+ execute_method = get_method.return_value.execute
+
+ mock_response = {"name": "operation_id", "status": "DONE"}
+ execute_method.return_value = mock_response
+
+ result = self.cloudsql_hook.get_operation(project_id="gcp-project",
operation_name="operation_id")
+
+ assert result == mock_response
+ operations_method.assert_called_once()
+ get_method.assert_called_once_with(project="gcp-project",
operation="operation_id")
+ execute_method.assert_called_once_with(num_retries=self.cloudsql_hook.num_retries)
+
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.build")
+ def test_get_conn_obj_caching(self, mock_build):
+ self.cloudsql_hook._authorize = mock.MagicMock()
+ self.cloudsql_hook.get_client_options = mock.MagicMock()
+ mock_service_obj = mock.MagicMock()
+ mock_build.return_value = mock_service_obj
+
+ conn1 = self.cloudsql_hook.get_conn()
+ conn2 = self.cloudsql_hook.get_conn()
+
+ assert conn1 is conn2
+ assert mock_build.call_count == 1
+ mock_build.assert_called_once_with(
+ "sqladmin",
+ self.cloudsql_hook.api_version,
+ http=self.cloudsql_hook._authorize.return_value,
+ cache_discovery=False,
+ client_options=self.cloudsql_hook.get_client_options.return_value,
+ )
+
class TestGcpSqlHookNoDefaultProjectID:
def setup_method(self):
@@ -1968,3 +2004,17 @@ class TestCloudSQLAsyncHook:
)
with pytest.raises(HttpError):
await hook_async.get_operation(operation_name=OPERATION_NAME,
project_id=PROJECT_ID)
+
+ @pytest.mark.asyncio
+ async def test_get_sync_hook_override_sets_custom_api_version(self):
+ hook = CloudSQLAsyncHook(gcp_conn_id="test_conn")
+ with mock.patch(
+ "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.get_sync_hook",
+ new_callable=mock.AsyncMock,
+ ) as mock_super_get_sync:
+ mock_super_get_sync.return_value = mock.MagicMock()
+
+ await hook.get_sync_hook(api_version="api_v42")
+
+ assert hook._hook_kwargs["api_version"] == "api_v42"
+ mock_super_get_sync.assert_called_once()
diff --git
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
index c7cbb2046ce..8787a5cd068 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_sql.py
@@ -18,10 +18,11 @@ from __future__ import annotations
import asyncio
import logging
-from unittest import mock as async_mock
+from unittest import mock
import pytest
+from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
from airflow.providers.google.cloud.triggers.cloud_sql import
CloudSQLExportTrigger
from airflow.triggers.base import TriggerEvent
@@ -35,6 +36,7 @@ OPERATION_NAME = "test_operation_name"
OPERATION_URL = (
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
)
+API_VERSION = "v1test"
@pytest.fixture
@@ -45,11 +47,22 @@ def trigger():
impersonation_chain=None,
gcp_conn_id=TEST_GCP_CONN_ID,
poke_interval=TEST_POLL_INTERVAL,
+ api_version=API_VERSION,
)
+@pytest.fixture
+def sync_hook_mock():
+ mock_obj = mock.MagicMock(spec=CloudSQLHook)
+ with mock.patch(
+ HOOK_STR.format("CloudSQLAsyncHook.get_sync_hook"),
new_callable=mock.AsyncMock
+ ) as patched_get_sync_hook:
+ patched_get_sync_hook.return_value = mock_obj
+ yield mock_obj
+
+
class TestCloudSQLExportTrigger:
- def test_async_export_trigger_serialization_should_execute_successfully(self, trigger):
+ def test_async_export_trigger_serialization_should_execute_successfully(self, trigger, sync_hook_mock):
"""
Asserts that the CloudSQLExportTrigger correctly serializes its
arguments
and classpath.
@@ -62,12 +75,13 @@ class TestCloudSQLExportTrigger:
"impersonation_chain": None,
"gcp_conn_id": TEST_GCP_CONN_ID,
"poke_interval": TEST_POLL_INTERVAL,
+ "api_version": API_VERSION,
}
@pytest.mark.asyncio
- @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
async def test_async_export_trigger_on_success_should_execute_successfully(
- self, mock_get_operation, trigger
+ self, mock_get_operation, trigger, sync_hook_mock
):
"""
Tests the CloudSQLExportTrigger only fires once the job execution
reaches a successful state.
@@ -89,14 +103,15 @@ class TestCloudSQLExportTrigger:
)
@pytest.mark.asyncio
- @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
async def test_async_export_trigger_running_should_execute_successfully(
- self, mock_get_operation, trigger, caplog
+ self, mock_get_operation, trigger, sync_hook_mock, caplog
):
"""
Test that CloudSQLExportTrigger does not fire while a job is still
running.
"""
-
+ # Ensure execution for default universe
+ sync_hook_mock.is_default_universe.return_value = True
mock_get_operation.return_value = {
"status": "RUNNING",
"name": OPERATION_NAME,
@@ -114,8 +129,10 @@ class TestCloudSQLExportTrigger:
asyncio.get_event_loop().stop()
@pytest.mark.asyncio
- @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
- async def
test_async_export_trigger_error_should_execute_successfully(self,
mock_get_operation, trigger):
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def test_async_export_trigger_error_should_execute_successfully(
+ self, mock_get_operation, trigger, sync_hook_mock
+ ):
"""
Test that CloudSQLExportTrigger fires the correct event in case of an
error.
"""
@@ -136,9 +153,9 @@ class TestCloudSQLExportTrigger:
assert TriggerEvent(expected_event) == actual
@pytest.mark.asyncio
- @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
async def test_async_export_trigger_exception_should_execute_successfully(
- self, mock_get_operation, trigger
+ self, mock_get_operation, trigger, sync_hook_mock
):
"""
Test that CloudSQLExportTrigger fires the correct event in case of an
error.
@@ -148,3 +165,28 @@ class TestCloudSQLExportTrigger:
generator = trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "failed", "message": "Test exception"})
== actual
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def
test_async_export_trigger_executes_successfully_in_custom_universe(
+ self, mock_async_get_op, trigger, sync_hook_mock
+ ):
+ """
+ Test non-default universe trigger correct execution.
+ """
+ sync_hook_mock.is_default_universe.return_value = False
+ sync_hook_mock.get_operation.return_value = {
+ "status": "RUNNING",
+ "name": OPERATION_NAME,
+ }
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.1)
+
+ sync_hook_mock.is_default_universe.assert_called_once()
+ sync_hook_mock.get_operation.assert_called_once_with(
+ project_id=trigger.project_id,
operation_name=trigger.operation_name
+ )
+ # Verify the default universe branch not being called
+ mock_async_get_op.assert_not_called()
+ task.cancel()