This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 866a1f0f6d2 Fix CloudComposerAsyncHook to work correctly with Airflow 3. (#54976)
866a1f0f6d2 is described below
commit 866a1f0f6d2dc585dfdc0188e93990cc97e6a327
Author: Nitochkin <[email protected]>
AuthorDate: Sat Sep 6 10:44:22 2025 +0200
Fix CloudComposerAsyncHook to work correctly with Airflow 3. (#54976)
Co-authored-by: Anton Nitochkin <[email protected]>
Co-authored-by: VladaZakharova <[email protected]>
---
.../providers/google/cloud/hooks/cloud_composer.py | 26 ++++++++--------
.../google/cloud/triggers/cloud_composer.py | 36 +++++++++++++---------
.../unit/google/cloud/hooks/test_cloud_composer.py | 2 +-
3 files changed, 36 insertions(+), 28 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
index 9c963e47b0f..daf6a06a927 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -36,7 +36,7 @@ from google.cloud.orchestration.airflow.service_v1 import (
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import
GoogleBaseAsyncHook, GoogleBaseHook
if TYPE_CHECKING:
from google.api_core.operation import Operation
@@ -473,15 +473,18 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
return response.json()
-class CloudComposerAsyncHook(GoogleBaseHook):
+class CloudComposerAsyncHook(GoogleBaseAsyncHook):
"""Hook for Google Cloud Composer async APIs."""
+ sync_hook_class = CloudComposerHook
+
client_options = ClientOptions(api_endpoint="composer.googleapis.com:443")
- def get_environment_client(self) -> EnvironmentsAsyncClient:
+ async def get_environment_client(self) -> EnvironmentsAsyncClient:
"""Retrieve client library object that allow access Environments
service."""
+ sync_hook = await self.get_sync_hook()
return EnvironmentsAsyncClient(
- credentials=self.get_credentials(),
+ credentials=sync_hook.get_credentials(),
client_info=CLIENT_INFO,
client_options=self.client_options,
)
@@ -493,9 +496,8 @@ class CloudComposerAsyncHook(GoogleBaseHook):
return f"projects/{project_id}/locations/{region}"
async def get_operation(self, operation_name):
- return await
self.get_environment_client().transport.operations_client.get_operation(
- name=operation_name
- )
+ client = await self.get_environment_client()
+ return await
client.transport.operations_client.get_operation(name=operation_name)
@GoogleBaseHook.fallback_to_default_project_id
async def create_environment(
@@ -518,7 +520,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request
as metadata.
"""
- client = self.get_environment_client()
+ client = await self.get_environment_client()
return await client.create_environment(
request={"parent": self.get_parent(project_id, region),
"environment": environment},
retry=retry,
@@ -546,7 +548,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request
as metadata.
"""
- client = self.get_environment_client()
+ client = await self.get_environment_client()
name = self.get_environment_name(project_id, region, environment_id)
return await client.delete_environment(
request={"name": name}, retry=retry, timeout=timeout,
metadata=metadata
@@ -582,7 +584,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request
as metadata.
"""
- client = self.get_environment_client()
+ client = await self.get_environment_client()
name = self.get_environment_name(project_id, region, environment_id)
return await client.update_environment(
@@ -620,7 +622,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request
as metadata.
"""
- client = self.get_environment_client()
+ client = await self.get_environment_client()
return await client.execute_airflow_command(
request={
@@ -662,7 +664,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request
as metadata.
"""
- client = self.get_environment_client()
+ client = await self.get_environment_client()
return await client.poll_airflow_command(
request={
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index a2840dcffe2..f6654a39351 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -52,11 +52,6 @@ class CloudComposerExecutionTrigger(BaseTrigger):
self.impersonation_chain = impersonation_chain
self.pooling_period_seconds = pooling_period_seconds
- self.gcp_hook = CloudComposerAsyncHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
-
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExecutionTrigger",
@@ -70,7 +65,14 @@ class CloudComposerExecutionTrigger(BaseTrigger):
},
)
+ def _get_async_hook(self) -> CloudComposerAsyncHook:
+ return CloudComposerAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
async def run(self):
+ self.gcp_hook = self._get_async_hook()
while True:
operation = await
self.gcp_hook.get_operation(operation_name=self.operation_name)
if operation.done:
@@ -108,11 +110,6 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
- self.gcp_hook = CloudComposerAsyncHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
-
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger",
@@ -127,7 +124,14 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
},
)
+ def _get_async_hook(self) -> CloudComposerAsyncHook:
+ return CloudComposerAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
async def run(self):
+ self.gcp_hook = self._get_async_hook()
try:
result = await self.gcp_hook.wait_command_execution_result(
project_id=self.project_id,
@@ -199,11 +203,6 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
self.poll_interval = poll_interval
self.composer_airflow_version = composer_airflow_version
- self.gcp_hook = CloudComposerAsyncHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
-
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
@@ -264,6 +263,12 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
return False
return True
+ def _get_async_hook(self) -> CloudComposerAsyncHook:
+ return CloudComposerAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
for dag_run in dag_runs:
if dag_run["run_id"] == self.composer_dag_run_id and
dag_run["state"] in self.allowed_states:
@@ -271,6 +276,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
return False
async def run(self):
+ self.gcp_hook: CloudComposerAsyncHook = self._get_async_hook()
try:
while True:
if datetime.now(self.end_date.tzinfo).timestamp() >
self.end_date.timestamp():
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
index 4a371794793..cd62056e36e 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
@@ -286,7 +286,7 @@ class TestCloudComposerHook:
class TestCloudComposerAsyncHook:
def setup_method(self, method):
- with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_init):
+ with mock.patch(BASE_STRING.format("GoogleBaseAsyncHook.__init__"),
new=mock_init):
self.hook = CloudComposerAsyncHook(gcp_conn_id="test")
@pytest.mark.asyncio