This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d07aec28c90 Switch Cloud Run operators to use regional endpoints
(#61857)
d07aec28c90 is described below
commit d07aec28c90c58ae5fdcfd4eb024e15848573bfd
Author: VladaZakharova <[email protected]>
AuthorDate: Tue Feb 17 01:52:09 2026 +0100
Switch Cloud Run operators to use regional endpoints (#61857)
* Add regional endpoint support
* Add regional endpoint support
---
docs/spelling_wordlist.txt | 3 +
.../google/docs/operators/cloud/cloud_run.rst | 21 ++
.../providers/google/cloud/hooks/cloud_run.py | 229 +++++++++++++++++----
.../providers/google/cloud/operators/cloud_run.py | 82 +++++++-
.../providers/google/cloud/triggers/cloud_run.py | 18 +-
.../google/cloud/cloud_run/example_cloud_run.py | 15 +-
.../cloud/cloud_run/example_cloud_run_service.py | 12 +-
.../unit/google/cloud/hooks/test_cloud_run.py | 153 ++++++++++----
.../unit/google/cloud/operators/test_cloud_run.py | 55 ++++-
.../unit/google/cloud/triggers/test_cloud_run.py | 9 +-
10 files changed, 488 insertions(+), 109 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1a576d6df54..56bfd885844 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -995,6 +995,8 @@ jobId
JobList
jobName
JobRunning
+JobsAsyncClient
+JobsClient
JobSpec
JobStatus
jobtracker
@@ -1651,6 +1653,7 @@ servicebus
ServiceBusReceivedMessage
ServicePrincipalCredentials
ServiceResource
+ServicesClient
SES
sessionmaker
setattr
diff --git a/providers/google/docs/operators/cloud/cloud_run.rst
b/providers/google/docs/operators/cloud/cloud_run.rst
index 94291f25718..10dfcb5339e 100644
--- a/providers/google/docs/operators/cloud/cloud_run.rst
+++ b/providers/google/docs/operators/cloud/cloud_run.rst
@@ -33,6 +33,9 @@ Create a job
Before you create a job in Cloud Run, you need to define it.
For more information about the Job object fields, visit `Google Cloud Run Job
description
<https://cloud.google.com/run/docs/reference/rpc/google.cloud.run.v2#google.cloud.run.v2.Job>`__
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
A simple job configuration can be created with a Job object:
.. exampleinclude::
/../../google/tests/system/google/cloud/cloud_run/example_cloud_run.py
@@ -67,6 +70,9 @@ Create a service
Before you create a service in Cloud Run, you need to define it.
For more information about the Service object fields, visit `Google Cloud Run
Service description
<https://cloud.google.com/run/docs/reference/rpc/google.cloud.run.v2#google.cloud.run.v2.Service>`__
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
A simple service configuration can look as follows:
.. exampleinclude::
/../../google/tests/system/google/cloud/cloud_run/example_cloud_run_service.py
@@ -91,6 +97,9 @@ Note that this operator only creates the service without
executing it. The Servi
Delete a service
---------------------
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
With this configuration we can delete the service:
:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunDeleteServiceOperator`
@@ -106,6 +115,9 @@ Note this operator waits for the service to be deleted, and
the deleted Service'
Execute a job
---------------------
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
To execute a job, you can use:
:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
@@ -167,6 +179,9 @@ You can also specify overrides that allow you to give a new
entrypoint command t
Update a job
------------------
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
To update a job, you can use:
:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunUpdateJobOperator`
@@ -184,6 +199,9 @@ The Job's dictionary representation is pushed to XCom.
List jobs
----------------------
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
To list the jobs, you can use:
:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunListJobsOperator`
@@ -200,6 +218,9 @@ The operator takes two optional parameters: "limit" to
limit the number of tasks
Delete a job
-----------------
+To send requests to a regional endpoint instead of the global one, set ``use_regional_endpoint=True`` on the operator.
+The Cloud Run client is then created with the API endpoint ``<region>-run.googleapis.com``, built from the operator's ``region`` argument.
+
To delete a job you can use:
:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunDeleteJobOperator`
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
index 24fe0f200c8..c609ff737c1 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -22,6 +22,7 @@ import itertools
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Literal
+from google.api_core.client_options import ClientOptions
from google.cloud.run_v2 import (
CreateJobRequest,
CreateServiceRequest,
@@ -55,6 +56,12 @@ if TYPE_CHECKING:
from google.cloud.run_v2.services.jobs import pagers
+class NoLocationSpecifiedException(Exception):
+ """Custom exception to catch error when location is not specified."""
+
+ pass
+
+
class CloudRunHook(GoogleBaseHook):
"""
Hook for the Google Cloud Run service.
@@ -84,9 +91,13 @@ class CloudRunHook(GoogleBaseHook):
self._client: JobsClient | None = None
self.transport = transport
- def get_conn(self):
+ def get_conn(self, location: str | None = None, use_regional_endpoint:
bool | None = False) -> JobsClient:
"""
- Retrieve connection to Cloud Run.
+ Retrieve the connection to Google Cloud Run.
+
+ :param location: The location of the project.
+ :param use_regional_endpoint: If set to True, regional endpoint will
be used while creating Client.
+ If not provided, the default one is global endpoint.
:return: Cloud Run Jobs client object.
"""
@@ -97,20 +108,41 @@ class CloudRunHook(GoogleBaseHook):
}
if self.transport:
client_kwargs["transport"] = self.transport
- self._client = JobsClient(**client_kwargs)
+ if use_regional_endpoint:
+ if not location:
+ raise NoLocationSpecifiedException(
+ "No location was specified while using
use_regional_endpoint parameter"
+ )
+ client_kwargs["client_options"] = ClientOptions(
+ api_endpoint=f"{location}-run.googleapis.com:443"
+ )
+ self._client = JobsClient(**client_kwargs) # type:
ignore[arg-type]
return self._client
@GoogleBaseHook.fallback_to_default_project_id
- def delete_job(self, job_name: str, region: str, project_id: str =
PROVIDE_PROJECT_ID) -> Job:
+ def delete_job(
+ self,
+ job_name: str,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
+ ) -> Job:
delete_request = DeleteJobRequest()
delete_request.name =
f"projects/{project_id}/locations/{region}/jobs/{job_name}"
- operation = self.get_conn().delete_job(delete_request)
+ operation = self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).delete_job(
+ delete_request
+ )
return operation.result()
@GoogleBaseHook.fallback_to_default_project_id
def create_job(
- self, job_name: str, job: Job | dict, region: str, project_id: str =
PROVIDE_PROJECT_ID
+ self,
+ job_name: str,
+ job: Job | dict,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
) -> Job:
if isinstance(job, dict):
job = Job(job)
@@ -120,12 +152,19 @@ class CloudRunHook(GoogleBaseHook):
create_request.job_id = job_name
create_request.parent = f"projects/{project_id}/locations/{region}"
- operation = self.get_conn().create_job(create_request)
+ operation = self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).create_job(
+ create_request
+ )
return operation.result()
@GoogleBaseHook.fallback_to_default_project_id
def update_job(
- self, job_name: str, job: Job | dict, region: str, project_id: str =
PROVIDE_PROJECT_ID
+ self,
+ job_name: str,
+ job: Job | dict,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
) -> Job:
if isinstance(job, dict):
job = Job(job)
@@ -133,7 +172,9 @@ class CloudRunHook(GoogleBaseHook):
update_request = UpdateJobRequest()
job.name = f"projects/{project_id}/locations/{region}/jobs/{job_name}"
update_request.job = job
- operation = self.get_conn().update_job(update_request)
+ operation = self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).update_job(
+ update_request
+ )
return operation.result()
@GoogleBaseHook.fallback_to_default_project_id
@@ -141,24 +182,36 @@ class CloudRunHook(GoogleBaseHook):
self,
job_name: str,
region: str,
+ use_regional_endpoint: bool | None = False,
project_id: str = PROVIDE_PROJECT_ID,
overrides: dict[str, Any] | None = None,
) -> operation.Operation:
run_job_request = RunJobRequest(
name=f"projects/{project_id}/locations/{region}/jobs/{job_name}",
overrides=overrides
)
- operation = self.get_conn().run_job(request=run_job_request)
+ operation = self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).run_job(
+ request=run_job_request
+ )
return operation
@GoogleBaseHook.fallback_to_default_project_id
- def get_job(self, job_name: str, region: str, project_id: str =
PROVIDE_PROJECT_ID):
+ def get_job(
+ self,
+ job_name: str,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
+ ):
get_job_request =
GetJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
- return self.get_conn().get_job(get_job_request)
+ return self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).get_job(
+ get_job_request
+ )
@GoogleBaseHook.fallback_to_default_project_id
def list_jobs(
self,
region: str,
+ use_regional_endpoint: bool | None = False,
project_id: str = PROVIDE_PROJECT_ID,
show_deleted: bool = False,
limit: int | None = None,
@@ -170,7 +223,9 @@ class CloudRunHook(GoogleBaseHook):
parent=f"projects/{project_id}/locations/{region}",
show_deleted=show_deleted
)
- jobs: pagers.ListJobsPager =
self.get_conn().list_jobs(request=list_jobs_request)
+ jobs: pagers.ListJobsPager = self.get_conn(
+ location=region, use_regional_endpoint=use_regional_endpoint
+ ).list_jobs(request=list_jobs_request)
return list(itertools.islice(jobs, limit))
@@ -207,29 +262,52 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
self.transport = transport
super().__init__(gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain, **kwargs)
- async def get_conn(self):
+ async def get_conn(
+ self, location: str | None = None, use_regional_endpoint: bool | None
= False
+ ) -> JobsAsyncClient | JobsClient:
+ """
+ Retrieve the connection to Google Cloud Run.
+
+ :param location: The location of the project.
+ :param use_regional_endpoint: If set to True, regional endpoint will
be used while creating Client.
+ If not provided, the default one is global endpoint.
+
+ :return: Cloud Run Jobs client object.
+ """
if self._client is None:
sync_hook = await self.get_sync_hook()
credentials = sync_hook.get_credentials()
+ common_kwargs = {
+ "credentials": credentials,
+ "client_info": CLIENT_INFO,
+ }
+ if use_regional_endpoint:
+ if not location:
+ raise NoLocationSpecifiedException(
+ "No location was specified while using
use_regional_endpoint parameter"
+ )
+ common_kwargs["client_options"] = ClientOptions(
+ api_endpoint=f"{location}-run.googleapis.com:443"
+ )
if self.transport == "rest":
# REST transport is synchronous-only. Use the sync JobsClient
here;
# get_operation() wraps calls with asyncio.to_thread() for
async compat.
self._client = JobsClient(
- credentials=credentials,
- client_info=CLIENT_INFO,
transport="rest",
+ **common_kwargs,
)
else:
# Default: use JobsAsyncClient which picks grpc_asyncio
transport.
self._client = JobsAsyncClient(
- credentials=credentials,
- client_info=CLIENT_INFO,
+ **common_kwargs,
)
return self._client
- async def get_operation(self, operation_name: str) ->
operations_pb2.Operation:
- conn = await self.get_conn()
+ async def get_operation(
+ self, operation_name: str, location: str | None = None,
use_regional_endpoint: bool | None = False
+ ) -> operations_pb2.Operation:
+ conn = await self.get_conn(location=location,
use_regional_endpoint=use_regional_endpoint)
request = operations_pb2.GetOperationRequest(name=operation_name)
if self.transport == "rest":
# REST client is synchronous — run in a thread to avoid blocking
the event loop.
@@ -261,22 +339,57 @@ class CloudRunServiceHook(GoogleBaseHook):
self._client: ServicesClient | None = None
super().__init__(gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain, **kwargs)
- def get_conn(self):
- if self._client is None:
- self._client = ServicesClient(credentials=self.get_credentials(),
client_info=CLIENT_INFO)
+ def get_conn(
+ self, location: str | None = None, use_regional_endpoint: bool | None
= False
+ ) -> ServicesClient:
+ """
+ Retrieve the connection to Google Cloud Run.
+
+ :param location: The location of the project.
+ :param use_regional_endpoint: If set to True, regional endpoint will
be used while creating Client.
+ If not provided, the default one is global endpoint.
+ :return: Google Cloud Run client object.
+ """
+ if self._client is None:
+ client_kwargs = {
+ "credentials": self.get_credentials(),
+ "client_info": CLIENT_INFO,
+ }
+ if use_regional_endpoint:
+ if not location:
+ raise NoLocationSpecifiedException(
+ "No location was specified while using
use_regional_endpoint parameter"
+ )
+ client_kwargs["client_options"] = ClientOptions(
+ api_endpoint=f"{location}-run.googleapis.com:443"
+ )
+ self._client = ServicesClient(**client_kwargs) # type:
ignore[arg-type]
return self._client
@GoogleBaseHook.fallback_to_default_project_id
- def get_service(self, service_name: str, region: str, project_id: str =
PROVIDE_PROJECT_ID):
+ def get_service(
+ self,
+ service_name: str,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
+ ):
get_service_request = GetServiceRequest(
name=f"projects/{project_id}/locations/{region}/services/{service_name}"
)
- return self.get_conn().get_service(get_service_request)
+ return self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint).get_service(
+ get_service_request
+ )
@GoogleBaseHook.fallback_to_default_project_id
def create_service(
- self, service_name: str, service: Service | dict, region: str,
project_id: str = PROVIDE_PROJECT_ID
+ self,
+ service_name: str,
+ service: Service | dict,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
) -> Service:
if isinstance(service, dict):
service = Service(service)
@@ -287,16 +400,26 @@ class CloudRunServiceHook(GoogleBaseHook):
service_id=service_name,
)
- operation = self.get_conn().create_service(create_request)
+ operation = self.get_conn(
+ location=region, use_regional_endpoint=use_regional_endpoint
+ ).create_service(create_request)
return operation.result()
@GoogleBaseHook.fallback_to_default_project_id
- def delete_service(self, service_name: str, region: str, project_id: str =
PROVIDE_PROJECT_ID) -> Service:
+ def delete_service(
+ self,
+ service_name: str,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
+ ) -> Service:
delete_request = DeleteServiceRequest(
name=f"projects/{project_id}/locations/{region}/services/{service_name}"
)
- operation = self.get_conn().delete_service(delete_request)
+ operation = self.get_conn(
+ location=region, use_regional_endpoint=use_regional_endpoint
+ ).delete_service(delete_request)
return operation.result()
@@ -323,18 +446,44 @@ class CloudRunServiceAsyncHook(GoogleBaseAsyncHook):
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
- self._client: ServicesClient | None = None
+ self._client: ServicesClient | ServicesAsyncClient | None = None
super().__init__(gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain, **kwargs)
- def get_conn(self):
- if self._client is None:
- self._client =
ServicesAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+ async def get_conn(
+ self, location: str | None = None, use_regional_endpoint: bool | None
= False
+ ) -> ServicesClient | ServicesAsyncClient:
+ """
+ Retrieve the connection to Google Cloud Run.
+ :param location: The location of the project.
+ :param use_regional_endpoint: If set to True, regional endpoint will
be used while creating Client.
+ If not provided, the default one is global endpoint.
+ """
+ if self._client is None:
+ sync_hook = await self.get_sync_hook()
+ client_kwargs = {
+ "credentials": sync_hook.get_credentials(),
+ "client_info": CLIENT_INFO,
+ }
+ if use_regional_endpoint:
+ if not location:
+ raise NoLocationSpecifiedException(
+ "No location was specified while using
use_regional_endpoint parameter"
+ )
+ client_kwargs["client_options"] = ClientOptions(
+ api_endpoint=f"{location}-run.googleapis.com:443"
+ )
+ self._client = ServicesAsyncClient(**client_kwargs)
return self._client
@GoogleBaseHook.fallback_to_default_project_id
async def create_service(
- self, service_name: str, service: Service | dict, region: str,
project_id: str = PROVIDE_PROJECT_ID
+ self,
+ service_name: str,
+ service: Service | dict,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
) -> AsyncOperation:
if isinstance(service, dict):
service = Service(service)
@@ -344,15 +493,21 @@ class CloudRunServiceAsyncHook(GoogleBaseAsyncHook):
service=service,
service_id=service_name,
)
+ client = await self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint)
- return await self.get_conn().create_service(create_request)
+ return await client.create_service(create_request) # type:
ignore[misc]
@GoogleBaseHook.fallback_to_default_project_id
async def delete_service(
- self, service_name: str, region: str, project_id: str =
PROVIDE_PROJECT_ID
+ self,
+ service_name: str,
+ region: str,
+ use_regional_endpoint: bool | None = False,
+ project_id: str = PROVIDE_PROJECT_ID,
) -> AsyncOperation:
delete_request = DeleteServiceRequest(
name=f"projects/{project_id}/locations/{region}/services/{service_name}"
)
+ client = await self.get_conn(location=region,
use_regional_endpoint=use_regional_endpoint)
- return await self.get_conn().delete_service(delete_request)
+ return await client.delete_service(delete_request) # type:
ignore[misc]
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
index 5c12dd4d7da..3692457038e 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
@@ -45,6 +45,8 @@ class CloudRunCreateJobOperator(GoogleCloudBaseOperator):
:param region: Required. The ID of the Google Cloud region that the
service belongs to.
:param job_name: Required. The name of the job to create.
:param job: Required. The job descriptor containing the configuration of
the job to submit.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -66,6 +68,7 @@ class CloudRunCreateJobOperator(GoogleCloudBaseOperator):
job: dict | Job,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ use_regional_endpoint: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -75,14 +78,20 @@ class CloudRunCreateJobOperator(GoogleCloudBaseOperator):
self.job = job
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ self.use_regional_endpoint = use_regional_endpoint
def execute(self, context: Context):
hook: CloudRunHook = CloudRunHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
job = hook.create_job(
- job_name=self.job_name, job=self.job, region=self.region,
project_id=self.project_id
+ job_name=self.job_name,
+ job=self.job,
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
)
+ self.log.info("Job created")
return Job.to_dict(job)
@@ -96,6 +105,8 @@ class CloudRunUpdateJobOperator(GoogleCloudBaseOperator):
:param job_name: Required. The name of the job to update.
:param job: Required. The job descriptor containing the new configuration
of the job to update.
The name field will be replaced by job_name
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -116,6 +127,7 @@ class CloudRunUpdateJobOperator(GoogleCloudBaseOperator):
job_name: str,
job: dict | Job,
gcp_conn_id: str = "google_cloud_default",
+ use_regional_endpoint: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
@@ -125,6 +137,7 @@ class CloudRunUpdateJobOperator(GoogleCloudBaseOperator):
self.job_name = job_name
self.job = job
self.gcp_conn_id = gcp_conn_id
+ self.use_regional_endpoint = use_regional_endpoint
self.impersonation_chain = impersonation_chain
def execute(self, context: Context):
@@ -132,7 +145,11 @@ class CloudRunUpdateJobOperator(GoogleCloudBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
job = hook.update_job(
- job_name=self.job_name, job=self.job, region=self.region,
project_id=self.project_id
+ job_name=self.job_name,
+ job=self.job,
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
)
return Job.to_dict(job)
@@ -146,6 +163,8 @@ class CloudRunDeleteJobOperator(GoogleCloudBaseOperator):
:param region: Required. The ID of the Google Cloud region that the
service belongs to.
:param job_name: Required. The name of the job to delete.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
of the last account in the list, which will be impersonated in the
request.
@@ -164,6 +183,7 @@ class CloudRunDeleteJobOperator(GoogleCloudBaseOperator):
region: str,
job_name: str,
gcp_conn_id: str = "google_cloud_default",
+ use_regional_endpoint: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
@@ -172,13 +192,19 @@ class CloudRunDeleteJobOperator(GoogleCloudBaseOperator):
self.region = region
self.job_name = job_name
self.gcp_conn_id = gcp_conn_id
+ self.use_regional_endpoint = use_regional_endpoint
self.impersonation_chain = impersonation_chain
def execute(self, context: Context):
hook: CloudRunHook = CloudRunHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
- job = hook.delete_job(job_name=self.job_name, region=self.region,
project_id=self.project_id)
+ job = hook.delete_job(
+ job_name=self.job_name,
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
+ )
return Job.to_dict(job)
@@ -193,6 +219,8 @@ class CloudRunListJobsOperator(GoogleCloudBaseOperator):
resources along with active ones.
:param limit: The number of jobs to list. If left empty,
all the jobs will be returned.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -217,6 +245,7 @@ class CloudRunListJobsOperator(GoogleCloudBaseOperator):
region: str,
show_deleted: bool = False,
limit: int | None = None,
+ use_regional_endpoint: bool = False,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -228,6 +257,7 @@ class CloudRunListJobsOperator(GoogleCloudBaseOperator):
self.impersonation_chain = impersonation_chain
self.show_deleted = show_deleted
self.limit = limit
+ self.use_regional_endpoint = use_regional_endpoint
if limit is not None and limit < 0:
raise AirflowException("The limit for the list jobs request should
be greater or equal to zero")
@@ -236,7 +266,11 @@ class CloudRunListJobsOperator(GoogleCloudBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
jobs = hook.list_jobs(
- region=self.region, project_id=self.project_id,
show_deleted=self.show_deleted, limit=self.limit
+ region=self.region,
+ project_id=self.project_id,
+ show_deleted=self.show_deleted,
+ limit=self.limit,
+ use_regional_endpoint=self.use_regional_endpoint,
)
return [Job.to_dict(job) for job in jobs]
@@ -254,6 +288,8 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
:param polling_period_seconds: Optional. Control the rate of the poll for
the result of deferrable run.
By default, the trigger will poll every 10 seconds.
:param timeout_seconds: Optional. The timeout for this request, in seconds.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
of the last account in the list, which will be impersonated in the
request.
@@ -289,6 +325,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
overrides: dict[str, Any] | None = None,
polling_period_seconds: float = 10,
timeout_seconds: float | None = None,
+ use_regional_endpoint: bool = False,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
@@ -305,6 +342,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
self.polling_period_seconds = polling_period_seconds
self.timeout_seconds = timeout_seconds
self.deferrable = deferrable
+ self.use_regional_endpoint = use_regional_endpoint
self.transport = transport
self.operation: operation.Operation | None = None
@@ -315,7 +353,11 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
transport=self.transport,
)
self.operation = hook.execute_job(
- region=self.region, project_id=self.project_id,
job_name=self.job_name, overrides=self.overrides
+ region=self.region,
+ project_id=self.project_id,
+ job_name=self.job_name,
+ overrides=self.overrides,
+ use_regional_endpoint=self.use_regional_endpoint,
)
if self.operation is None:
@@ -330,7 +372,12 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
if not self.deferrable:
result: Execution = self._wait_for_operation(self.operation)
self._fail_if_execution_failed(result)
- job = hook.get_job(job_name=result.job, region=self.region,
project_id=self.project_id)
+ job = hook.get_job(
+ job_name=result.job,
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
+ )
return Job.to_dict(job)
self.defer(
trigger=CloudRunJobFinishedTrigger(
@@ -338,6 +385,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
job_name=self.job_name,
project_id=self.project_id,
location=self.region,
+ use_regional_endpoint=self.use_regional_endpoint,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_period_seconds=self.polling_period_seconds,
@@ -365,7 +413,12 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
transport=self.transport,
)
- job = hook.get_job(job_name=event["job_name"], region=self.region,
project_id=self.project_id)
+ job = hook.get_job(
+ job_name=event["job_name"],
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
+ )
return Job.to_dict(job)
def _fail_if_execution_failed(self, execution: Execution):
@@ -395,6 +448,8 @@ class
CloudRunCreateServiceOperator(GoogleCloudBaseOperator):
:param region: Required. The ID of the Google Cloud region that the
service belongs to.
:param service_name: Required. The name of the service to create.
:param service: The service descriptor containing the configuration of the
service to submit.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -414,6 +469,7 @@ class
CloudRunCreateServiceOperator(GoogleCloudBaseOperator):
region: str,
service_name: str,
service: dict | Service,
+ use_regional_endpoint: bool = False,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -423,6 +479,7 @@ class
CloudRunCreateServiceOperator(GoogleCloudBaseOperator):
self.region = region
self.service = service
self.service_name = service_name
+ self.use_regional_endpoint = use_regional_endpoint
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self._validate_inputs()
@@ -446,6 +503,7 @@ class
CloudRunCreateServiceOperator(GoogleCloudBaseOperator):
service_name=self.service_name,
region=self.region,
project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
)
except AlreadyExists:
self.log.info(
@@ -454,7 +512,10 @@ class
CloudRunCreateServiceOperator(GoogleCloudBaseOperator):
self.region,
)
service = hook.get_service(
- service_name=self.service_name, region=self.region,
project_id=self.project_id
+ service_name=self.service_name,
+ region=self.region,
+ project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
)
return Service.to_dict(service)
except google.cloud.exceptions.GoogleCloudError as e:
@@ -471,6 +532,8 @@ class
CloudRunDeleteServiceOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
:param region: Required. The ID of the Google Cloud region that the
service belongs to.
:param service_name: Required. The name of the service to create.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -489,6 +552,7 @@ class
CloudRunDeleteServiceOperator(GoogleCloudBaseOperator):
project_id: str,
region: str,
service_name: str,
+ use_regional_endpoint: bool = False,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -497,6 +561,7 @@ class
CloudRunDeleteServiceOperator(GoogleCloudBaseOperator):
self.project_id = project_id
self.region = region
self.service_name = service_name
+ self.use_regional_endpoint = use_regional_endpoint
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self._validate_inputs()
@@ -519,6 +584,7 @@ class
CloudRunDeleteServiceOperator(GoogleCloudBaseOperator):
service_name=self.service_name,
region=self.region,
project_id=self.project_id,
+ use_regional_endpoint=self.use_regional_endpoint,
)
except google.cloud.exceptions.NotFound as e:
self.log.error("An error occurred. Not Found.")
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
index 746ca9e8838..87f1d5f0d89 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
@@ -19,15 +19,12 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Sequence
from enum import Enum
-from typing import TYPE_CHECKING, Any, Literal
+from typing import Any, Literal
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
-if TYPE_CHECKING:
- from google.longrunning import operations_pb2
-
DEFAULT_BATCH_LOCATION = "us-central1"
@@ -48,6 +45,8 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
:param project_id: Required. the Google Cloud project ID in which the job
was started.
:param location: Optional. the location where job is executed.
If set to None then the value of DEFAULT_BATCH_LOCATION will be used.
+ :param use_regional_endpoint: If set to True, regional endpoint will be
used while creating Client.
+ If not provided, the default one is global endpoint.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional. Service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token of the last account
@@ -70,6 +69,7 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
job_name: str,
project_id: str | None,
location: str = DEFAULT_BATCH_LOCATION,
+ use_regional_endpoint: bool = False,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_period_seconds: float = 10,
@@ -85,6 +85,7 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
self.polling_period_seconds = polling_period_seconds
self.timeout = timeout
self.impersonation_chain = impersonation_chain
+ self.use_regional_endpoint = use_regional_endpoint
self.transport = transport
def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -100,15 +101,20 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
"polling_period_seconds": self.polling_period_seconds,
"timeout": self.timeout,
"impersonation_chain": self.impersonation_chain,
+ "use_regional_endpoint": self.use_regional_endpoint,
"transport": self.transport,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
timeout = self.timeout
- hook = self._get_async_hook()
+ self.hook = self._get_async_hook()
while timeout is None or timeout > 0:
- operation: operations_pb2.Operation = await
hook.get_operation(self.operation_name)
+ operation = await self.hook.get_operation(
+ operation_name=self.operation_name,
+ location=self.location,
+ use_regional_endpoint=self.use_regional_endpoint,
+ )
if operation.done:
# An operation can only have one of those two combinations: if
it is failed, then
# the error field will be populated, else, then the response
field will be.
diff --git
a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
index 789f128dbca..1d41b24f720 100644
--- a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
+++ b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
@@ -46,13 +46,13 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "cloud_run"
-region = "us-central1"
+region = "europe-west1"
job_name_prefix = "cloudrun-system-test-job"
-job1_name = f"{job_name_prefix}1-{ENV_ID}"
-job2_name = f"{job_name_prefix}2-{ENV_ID}"
-job3_name = f"{job_name_prefix}3-{ENV_ID}"
-job4_name = f"{job_name_prefix}4-{ENV_ID}"
-job5_name = f"{job_name_prefix}5-{ENV_ID}"
+job1_name = f"{job_name_prefix}1-{ENV_ID}".replace("_", "-")
+job2_name = f"{job_name_prefix}2-{ENV_ID}".replace("_", "-")
+job3_name = f"{job_name_prefix}3-{ENV_ID}".replace("_", "-")
+job4_name = f"{job_name_prefix}4-{ENV_ID}".replace("_", "-")
+job5_name = f"{job_name_prefix}5-{ENV_ID}".replace("_", "-")
create1_task_name = "create-job1"
create2_task_name = "create-job2"
@@ -234,6 +234,7 @@ with DAG(
region=region,
job_name=job1_name,
job=_create_job_instance(),
+ use_regional_endpoint=False,
dag=dag,
)
# [END howto_operator_cloud_run_create_job]
@@ -244,6 +245,7 @@ with DAG(
region=region,
job_name=job2_name,
job=_create_job_dict(),
+ use_regional_endpoint=False,
dag=dag,
)
@@ -253,6 +255,7 @@ with DAG(
region=region,
job_name=job3_name,
job=Job.to_dict(_create_job_instance()),
+ use_regional_endpoint=False,
dag=dag,
)
diff --git
a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run_service.py
b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run_service.py
index 8d87bcbc100..b1e098fc640 100644
---
a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run_service.py
+++
b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run_service.py
@@ -35,6 +35,8 @@ from airflow.providers.google.cloud.operators.cloud_run
import (
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+REGION = "europe-west1"
+SERVICE_NAME = f"cloudrun-service-{ENV_ID}".replace("_", "-")
# [START howto_operator_cloud_run_service_creation]
@@ -60,9 +62,10 @@ with DAG(
create_cloud_run_service = CloudRunCreateServiceOperator(
task_id="create-cloud-run-service",
project_id=PROJECT_ID,
- region="us-central1",
+ region=REGION,
service=_create_service(),
- service_name="cloudrun-system-test-service",
+ service_name=SERVICE_NAME,
+ use_regional_endpoint=False,
)
# [END howto_operator_cloud_run_create_service]
@@ -70,8 +73,9 @@ with DAG(
delete_cloud_run_service = CloudRunDeleteServiceOperator(
task_id="delete-cloud-run-service",
project_id=PROJECT_ID,
- region="us-central1",
- service_name="cloudrun-system-test-service",
+ region=REGION,
+ service_name=SERVICE_NAME,
+ use_regional_endpoint=False,
dag=dag,
)
# [END howto_operator_cloud_run_delete_service]
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
index bf7ff461233..b5a691f7771 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
@@ -18,8 +18,10 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import AsyncMock
import pytest
+from google.api_core.client_options import ClientOptions
from google.cloud.run_v2 import (
CreateJobRequest,
CreateServiceRequest,
@@ -32,6 +34,7 @@ from google.cloud.run_v2 import (
ListJobsRequest,
RunJobRequest,
Service,
+ ServicesAsyncClient,
UpdateJobRequest,
)
from google.longrunning import operations_pb2
@@ -51,6 +54,8 @@ REGION = "region1"
JOB_NAME = "job1"
SERVICE_NAME = "service1"
OPERATION_NAME = "operationname"
+USE_REGIONAL_ENDPOINT = True
+BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
@pytest.mark.db_test
@@ -64,6 +69,20 @@ class TestCloudRunHook:
cloud_run_hook.get_credentials = self.dummy_get_credentials
return cloud_run_hook
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+ def test_get_conn_regional_endpoint(self, mock_jobs_client_cls):
+ hook = CloudRunHook(gcp_conn_id="google_cloud_default")
+ hook.get_credentials = mock.MagicMock(return_value=mock.Mock())
+
+ location = "us-central1"
+ hook.get_conn(location=location,
use_regional_endpoint=USE_REGIONAL_ENDPOINT)
+ assert mock_jobs_client_cls.call_count == 1
+
+ _, kwargs = mock_jobs_client_cls.call_args
+ client_options = kwargs.get("client_options")
+ assert isinstance(client_options, ClientOptions)
+ assert client_options.api_endpoint ==
f"{location}-run.googleapis.com:443"
+
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
new=mock_base_gcp_hook_default_project_id,
@@ -72,7 +91,12 @@ class TestCloudRunHook:
def test_get_job(self, mock_batch_service_client, cloud_run_hook):
get_job_request =
GetJobRequest(name=f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}")
- cloud_run_hook.get_job(job_name=JOB_NAME, region=REGION,
project_id=PROJECT_ID)
+ cloud_run_hook.get_job(
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
+ )
cloud_run_hook._client.get_job.assert_called_once_with(get_job_request)
@mock.patch(
@@ -88,7 +112,11 @@ class TestCloudRunHook:
update_request.job = job
cloud_run_hook.update_job(
- job=Job.to_dict(job), job_name=JOB_NAME, region=REGION,
project_id=PROJECT_ID
+ job=Job.to_dict(job),
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
cloud_run_hook._client.update_job.assert_called_once_with(update_request)
@@ -107,7 +135,11 @@ class TestCloudRunHook:
create_request.parent = f"projects/{PROJECT_ID}/locations/{REGION}"
cloud_run_hook.create_job(
- job=Job.to_dict(job), job_name=JOB_NAME, region=REGION,
project_id=PROJECT_ID
+ job=Job.to_dict(job),
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
cloud_run_hook._client.create_job.assert_called_once_with(create_request)
@@ -128,7 +160,11 @@ class TestCloudRunHook:
)
cloud_run_hook.execute_job(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID,
overrides=overrides
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ overrides=overrides,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)
@@ -145,7 +181,9 @@ class TestCloudRunHook:
page = self._mock_pager(number_of_jobs)
mock_batch_service_client.return_value.list_jobs.return_value = page
- jobs_list = cloud_run_hook.list_jobs(region=region,
project_id=project_id)
+ jobs_list = cloud_run_hook.list_jobs(
+ region=region, project_id=project_id,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
for i in range(number_of_jobs):
assert jobs_list[i].name == f"name{i}"
@@ -170,7 +208,12 @@ class TestCloudRunHook:
page = self._mock_pager(number_of_jobs)
mock_batch_service_client.return_value.list_jobs.return_value = page
- jobs_list = cloud_run_hook.list_jobs(region=region,
project_id=project_id, show_deleted=True)
+ jobs_list = cloud_run_hook.list_jobs(
+ region=region,
+ project_id=project_id,
+ show_deleted=True,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
+ )
for i in range(number_of_jobs):
assert jobs_list[i].name == f"name{i}"
@@ -196,7 +239,9 @@ class TestCloudRunHook:
page = self._mock_pager(number_of_jobs)
mock_batch_service_client.return_value.list_jobs.return_value = page
- jobs_list = cloud_run_hook.list_jobs(region=region,
project_id=project_id, limit=limit)
+ jobs_list = cloud_run_hook.list_jobs(
+ region=region, project_id=project_id, limit=limit,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
assert len(jobs_list) == limit
for i in range(limit):
@@ -212,7 +257,9 @@ class TestCloudRunHook:
page = self._mock_pager(number_of_jobs)
mock_batch_service_client.return_value.list_jobs.return_value = page
- jobs_list = cloud_run_hook.list_jobs(region=region,
project_id=project_id, limit=limit)
+ jobs_list = cloud_run_hook.list_jobs(
+ region=region, project_id=project_id, limit=limit,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
assert len(jobs_list) == 0
@@ -230,7 +277,9 @@ class TestCloudRunHook:
page = self._mock_pager(number_of_jobs)
mock_batch_service_client.return_value.list_jobs.return_value = page
- jobs_list = cloud_run_hook.list_jobs(region=region,
project_id=project_id, limit=limit)
+ jobs_list = cloud_run_hook.list_jobs(
+ region=region, project_id=project_id, limit=limit,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
assert len(jobs_list) == number_of_jobs
for i in range(number_of_jobs):
@@ -251,13 +300,20 @@ class TestCloudRunHook:
mock_batch_service_client.return_value.list_jobs.return_value = page
with pytest.raises(expected_exception=AirflowException):
- cloud_run_hook.list_jobs(region=region, project_id=project_id,
limit=limit)
+ cloud_run_hook.list_jobs(
+ region=region, project_id=project_id, limit=limit,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
def test_delete_job(self, mock_batch_service_client, cloud_run_hook):
delete_request =
DeleteJobRequest(name=f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}")
- cloud_run_hook.delete_job(job_name=JOB_NAME, region=REGION,
project_id=PROJECT_ID)
+ cloud_run_hook.delete_job(
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
+ )
cloud_run_hook._client.delete_job.assert_called_once_with(delete_request)
@mock.patch(
@@ -269,7 +325,7 @@ class TestCloudRunHook:
"""Test that transport parameter is passed to JobsClient."""
hook = CloudRunHook(transport="rest")
hook.get_credentials = self.dummy_get_credentials
- hook.get_conn()
+ hook.get_conn(location=REGION,
use_regional_endpoint=USE_REGIONAL_ENDPOINT)
mock_jobs_client.assert_called_once()
call_kwargs = mock_jobs_client.call_args[1]
@@ -284,7 +340,7 @@ class TestCloudRunHook:
"""Test that transport is not passed to JobsClient when None."""
hook = CloudRunHook(transport=None)
hook.get_credentials = self.dummy_get_credentials
- hook.get_conn()
+ hook.get_conn(location=REGION,
use_regional_endpoint=USE_REGIONAL_ENDPOINT)
mock_jobs_client.assert_called_once()
call_kwargs = mock_jobs_client.call_args[1]
@@ -301,9 +357,12 @@ class TestCloudRunHook:
class TestCloudRunAsyncHook:
@pytest.mark.asyncio
async def test_get_operation(self):
+ region = "us-central1"
hook = CloudRunAsyncHook()
hook.get_conn = mock.AsyncMock()
- await hook.get_operation(operation_name=OPERATION_NAME)
+ await hook.get_operation(
+ operation_name=OPERATION_NAME, location=region,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
hook.get_conn.return_value.get_operation.assert_called_once_with(
operations_pb2.GetOperationRequest(name=OPERATION_NAME),
timeout=120
)
@@ -318,7 +377,7 @@ class TestCloudRunAsyncHook:
mock_sync_hook.get_credentials.return_value = "credentials"
hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)
- await hook.get_conn()
+ await hook.get_conn(location=REGION,
use_regional_endpoint=USE_REGIONAL_ENDPOINT)
mock_async_client.assert_called_once()
call_kwargs = mock_async_client.call_args[1]
@@ -333,7 +392,7 @@ class TestCloudRunAsyncHook:
mock_sync_hook.get_credentials.return_value = "credentials"
hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)
- await hook.get_conn()
+ await hook.get_conn(location=REGION,
use_regional_endpoint=USE_REGIONAL_ENDPOINT)
mock_sync_client.assert_called_once()
call_kwargs = mock_sync_client.call_args[1]
@@ -350,7 +409,9 @@ class TestCloudRunAsyncHook:
mock_conn = mock.MagicMock(spec=JobsClient) # sync client
hook.get_conn = mock.AsyncMock(return_value=mock_conn)
- result = await hook.get_operation(operation_name=OPERATION_NAME)
+ result = await hook.get_operation(
+ operation_name=OPERATION_NAME, location=REGION,
use_regional_endpoint=USE_REGIONAL_ENDPOINT
+ )
mock_to_thread.assert_called_once_with(
mock_conn.get_operation,
@@ -367,8 +428,10 @@ class TestCloudRunServiceHook:
@pytest.fixture
def cloud_run_service_hook(self):
+ region = "us-central1"
cloud_run_service_hook = CloudRunServiceHook()
cloud_run_service_hook.get_credentials = self.dummy_get_credentials
+ cloud_run_service_hook.client_options =
ClientOptions(api_endpoint=f"{region}-run.googleapis.com:443")
return cloud_run_service_hook
@mock.patch(
@@ -381,7 +444,12 @@ class TestCloudRunServiceHook:
name=f"projects/{PROJECT_ID}/locations/{REGION}/services/{SERVICE_NAME}"
)
- cloud_run_service_hook.get_service(service_name=SERVICE_NAME,
region=REGION, project_id=PROJECT_ID)
+ cloud_run_service_hook.get_service(
+ service_name=SERVICE_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
+ )
cloud_run_service_hook._client.get_service.assert_called_once_with(get_service_request)
@mock.patch(
@@ -398,7 +466,11 @@ class TestCloudRunServiceHook:
create_request.parent = f"projects/{PROJECT_ID}/locations/{REGION}"
cloud_run_service_hook.create_service(
- service=service, service_name=SERVICE_NAME, region=REGION,
project_id=PROJECT_ID
+ service=service,
+ service_name=SERVICE_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
cloud_run_service_hook._client.create_service.assert_called_once_with(create_request)
@@ -412,7 +484,12 @@ class TestCloudRunServiceHook:
name=f"projects/{PROJECT_ID}/locations/{REGION}/services/{SERVICE_NAME}"
)
- cloud_run_service_hook.delete_service(service_name=SERVICE_NAME,
region=REGION, project_id=PROJECT_ID)
+ cloud_run_service_hook.delete_service(
+ service_name=SERVICE_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
+ )
cloud_run_service_hook._client.delete_service.assert_called_once_with(delete_request)
@@ -423,24 +500,30 @@ class TestCloudRunServiceAsyncHook:
def mock_service(self):
return mock.AsyncMock()
+ @pytest.fixture
+ def cloud_run_service_hook(self):
+ region = "us-central1"
+ cloud_run_service_hook = CloudRunServiceAsyncHook()
+ cloud_run_service_hook.get_credentials = self.dummy_get_credentials
+ cloud_run_service_hook.client_options =
ClientOptions(api_endpoint=f"{region}-run.googleapis.com:443")
+ return cloud_run_service_hook
+
@pytest.mark.asyncio
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
new=mock_base_gcp_hook_default_project_id,
)
-
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.ServicesAsyncClient")
- async def test_create_service(self, mock_client):
- mock_client.return_value = mock.MagicMock()
- mock_client.return_value.create_service = self.mock_service()
-
- hook = CloudRunServiceAsyncHook()
- hook.get_credentials = self.dummy_get_credentials
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.CloudRunServiceAsyncHook.get_conn")
+ async def test_create_service(self, mock_client, cloud_run_service_hook):
+ mock_env_client = AsyncMock(ServicesAsyncClient)
+ mock_client.return_value = mock_env_client
- await hook.create_service(
+ await cloud_run_service_hook.create_service(
service_name=SERVICE_NAME,
service=Service(),
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
expected_request = CreateServiceRequest(
@@ -456,18 +539,16 @@ class TestCloudRunServiceAsyncHook:
"airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
new=mock_base_gcp_hook_default_project_id,
)
-
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.ServicesAsyncClient")
- async def test_delete_service(self, mock_client):
- mock_client.return_value = mock.MagicMock()
- mock_client.return_value.delete_service = self.mock_service()
-
- hook = CloudRunServiceAsyncHook()
- hook.get_credentials = self.dummy_get_credentials
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.CloudRunServiceAsyncHook.get_conn")
+ async def test_delete_service(self, mock_client, cloud_run_service_hook):
+ mock_env_client = AsyncMock(ServicesAsyncClient)
+ mock_client.return_value = mock_env_client
- await hook.delete_service(
+ await cloud_run_service_hook.delete_service(
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
)
expected_request = DeleteServiceRequest(
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
index 3c389713b1a..b77d4199d6f 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
@@ -87,7 +87,11 @@ class TestCloudRunCreateJobOperator:
operator.execute(context=mock.MagicMock())
hook_mock.return_value.create_job.assert_called_once_with(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, job=JOB
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ job=JOB,
+ use_regional_endpoint=False,
)
@@ -137,11 +141,18 @@ class TestCloudRunExecuteJobOperator:
operator.execute(context=mock.MagicMock())
hook_mock.return_value.get_job.assert_called_once_with(
- job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
+ job_name=mock.ANY,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
hook_mock.return_value.execute_job.assert_called_once_with(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID,
overrides=None
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ overrides=None,
+ use_regional_endpoint=False,
)
@mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -254,7 +265,10 @@ class TestCloudRunExecuteJobOperator:
result = operator.execute_complete(mock.MagicMock(), event)
hook_mock.return_value.get_job.assert_called_once_with(
- job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
+ job_name=mock.ANY,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
assert result["name"] == JOB_NAME
@@ -276,11 +290,18 @@ class TestCloudRunExecuteJobOperator:
operator.execute(context=mock.MagicMock())
hook_mock.return_value.get_job.assert_called_once_with(
- job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
+ job_name=mock.ANY,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
hook_mock.return_value.execute_job.assert_called_once_with(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID,
overrides=overrides
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ overrides=overrides,
+ use_regional_endpoint=False,
)
@mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -363,7 +384,10 @@ class TestCloudRunDeleteJobOperator:
assert deleted_job["name"] == JOB.name
hook_mock.return_value.delete_job.assert_called_once_with(
- job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+ job_name=JOB_NAME,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
@@ -389,7 +413,11 @@ class TestCloudRunUpdateJobOperator:
assert updated_job["name"] == JOB.name
hook_mock.return_value.update_job.assert_called_once_with(
- job_name=JOB_NAME, job=JOB, region=REGION, project_id=PROJECT_ID
+ job_name=JOB_NAME,
+ job=JOB,
+ region=REGION,
+ project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
@@ -412,7 +440,11 @@ class TestCloudRunListJobsOperator:
operator.execute(context=mock.MagicMock())
hook_mock.return_value.list_jobs.assert_called_once_with(
- region=REGION, project_id=PROJECT_ID, limit=limit,
show_deleted=show_deleted
+ region=REGION,
+ project_id=PROJECT_ID,
+ limit=limit,
+ show_deleted=show_deleted,
+ use_regional_endpoint=False,
)
@mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -454,6 +486,7 @@ class TestCloudRunCreateServiceOperator:
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
@mock.patch(CLOUD_RUN_SERVICE_HOOK_PATH)
@@ -476,11 +509,13 @@ class TestCloudRunCreateServiceOperator:
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
hook_mock.return_value.get_service.assert_called_once_with(
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
@mock.patch(CLOUD_RUN_SERVICE_HOOK_PATH)
@@ -506,6 +541,7 @@ class TestCloudRunCreateServiceOperator:
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
@@ -538,4 +574,5 @@ class TestCloudRunDeleteServiceOperator:
service_name=SERVICE_NAME,
region=REGION,
project_id=PROJECT_ID,
+ use_regional_endpoint=False,
)
diff --git
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
index 3902a17885e..a906a7d0332 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
@@ -36,6 +36,7 @@ GCP_CONNECTION_ID = "gcp_connection_id"
POLL_SLEEP = 0.01
TIMEOUT = 0.02
IMPERSONATION_CHAIN = "impersonation_chain"
+USE_REGIONAL_ENDPOINT = True
@pytest.fixture
@@ -45,6 +46,7 @@ def trigger():
job_name=JOB_NAME,
project_id=PROJECT_ID,
location=LOCATION,
+ use_regional_endpoint=USE_REGIONAL_ENDPOINT,
gcp_conn_id=GCP_CONNECTION_ID,
polling_period_seconds=POLL_SLEEP,
timeout=TIMEOUT,
@@ -65,6 +67,7 @@ class TestCloudBatchJobFinishedTrigger:
"gcp_conn_id": GCP_CONNECTION_ID,
"polling_period_seconds": POLL_SLEEP,
"timeout": TIMEOUT,
+ "use_regional_endpoint": USE_REGIONAL_ENDPOINT,
"impersonation_chain": IMPERSONATION_CHAIN,
"transport": None,
}
@@ -78,7 +81,7 @@ class TestCloudBatchJobFinishedTrigger:
Tests the CloudRunJobFinishedTrigger fires once the job execution
reaches a successful state.
"""
- async def _mock_operation(name):
+ async def _mock_operation(operation_name, location,
use_regional_endpoint):
operation = mock.MagicMock()
operation.done = True
operation.name = "name"
@@ -108,7 +111,7 @@ class TestCloudBatchJobFinishedTrigger:
Tests the CloudRunJobFinishedTrigger raises an exception once the job
execution fails.
"""
- async def _mock_operation(name):
+ async def _mock_operation(operation_name, location,
use_regional_endpoint):
operation = mock.MagicMock()
operation.done = True
operation.name = "name"
@@ -138,7 +141,7 @@ class TestCloudBatchJobFinishedTrigger:
Tests the CloudRunJobFinishedTrigger fires once the job execution
times out with an error message.
"""
- async def _mock_operation(name):
+ async def _mock_operation(operation_name, location,
use_regional_endpoint):
operation = mock.MagicMock()
operation.done = False
operation.error = mock.MagicMock()