This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c6a014a370 Add `CloudBatchHook` and operators (#32606)
c6a014a370 is described below

commit c6a014a3707d2e4a5a9d2fe0b4277be09266b63b
Author: Freddy Demiane <[email protected]>
AuthorDate: Fri Aug 18 21:00:08 2023 +0200

    Add `CloudBatchHook` and operators (#32606)
---
 .../providers/google/cloud/hooks/cloud_batch.py    | 204 ++++++++++++
 .../google/cloud/operators/cloud_batch.py          | 298 ++++++++++++++++++
 .../providers/google/cloud/triggers/cloud_batch.py | 156 ++++++++++
 airflow/providers/google/provider.yaml             |  16 +
 .../operators/cloud/cloud_batch.rst                | 108 +++++++
 generated/provider_dependencies.json               |   1 +
 .../google/cloud/hooks/test_cloud_batch.py         | 343 +++++++++++++++++++++
 .../google/cloud/operators/test_cloud_batch.py     | 190 ++++++++++++
 .../google/cloud/triggers/test_cloud_batch.py      | 160 ++++++++++
 .../providers/google/cloud/cloud_batch/__init__.py |  16 +
 .../cloud/cloud_batch/example_cloud_batch.py       | 202 ++++++++++++
 11 files changed, 1694 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/cloud_batch.py 
b/airflow/providers/google/cloud/hooks/cloud_batch.py
new file mode 100644
index 0000000000..f85283047d
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/cloud_batch.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+import json
+from time import sleep
+from typing import Iterable, Sequence
+
+from google.api_core import operation  # type: ignore
+from google.cloud.batch import ListJobsRequest, ListTasksRequest
+from google.cloud.batch_v1 import (
+    BatchServiceAsyncClient,
+    BatchServiceClient,
+    CreateJobRequest,
+    Job,
+    JobStatus,
+    Task,
+)
+from google.cloud.batch_v1.services.batch_service import pagers
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import 
PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
+class CloudBatchHook(GoogleBaseHook):
+    """
+    Hook for the Google Cloud Batch service.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+    ) -> None:
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+        # Created lazily by get_conn() and reused by every later call.
+        self._client: BatchServiceClient | None = None
+
+    def get_conn(self):
+        """
+        Retrieve the connection to Google Cloud Batch.
+
+        :return: Google Batch Service client object.
+        """
+        if self._client is None:
+            self._client = BatchServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+        return self._client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_batch_job(
+        self, job_name: str, job: dict | Job, region: str, project_id: str = PROVIDE_PROJECT_ID
+    ) -> Job:
+        """
+        Create a Cloud Batch job.
+
+        :param job_name: The ID to give to the new job.
+        :param job: The job definition, either a ``Job`` object or its dictionary representation.
+        :param region: The region in which the job is created.
+        :param project_id: The project in which the job is created.
+        :return: The job returned by the create request.
+        """
+        if isinstance(job, dict):
+            # Dictionaries are converted through JSON into a ``Job`` message.
+            job = Job.from_json(json.dumps(job))
+
+        create_request = CreateJobRequest()
+        create_request.job = job
+        create_request.job_id = job_name
+        create_request.parent = f"projects/{project_id}/locations/{region}"
+
+        return self.get_conn().create_job(create_request)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_job(
+        self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
+    ) -> operation.Operation:
+        """
+        Request the deletion of a Cloud Batch job.
+
+        :param job_name: The ID of the job to delete.
+        :param region: The region the job belongs to.
+        :param project_id: The project the job belongs to.
+        :return: The operation returned by the delete request.
+        """
+        return self.get_conn().delete_job(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_jobs(
+        self,
+        region: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        filter: str | None = None,
+        limit: int | None = None,
+    ) -> Iterable[Job]:
+        """
+        List the Cloud Batch jobs of a project and region.
+
+        :param region: The region whose jobs are listed.
+        :param project_id: The project whose jobs are listed.
+        :param filter: Optional filter forwarded to the list request.
+        :param limit: Maximum number of jobs to return. ``None`` returns every job.
+        :return: A list with at most ``limit`` jobs.
+        :raises AirflowException: If ``limit`` is negative.
+        """
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+        list_jobs_request: ListJobsRequest = ListJobsRequest(
+            parent=f"projects/{project_id}/locations/{region}", filter=filter
+        )
+
+        jobs: pagers.ListJobsPager = self.get_conn().list_jobs(request=list_jobs_request)
+
+        # islice stops consuming the pager as soon as ``limit`` items were read.
+        return list(itertools.islice(jobs, limit))
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_tasks(
+        self,
+        region: str,
+        job_name: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        group_name: str = "group0",
+        filter: str | None = None,
+        limit: int | None = None,
+    ) -> Iterable[Task]:
+        """
+        List the tasks of one task group of a Cloud Batch job.
+
+        :param region: The region the job belongs to.
+        :param job_name: The ID of the job whose tasks are listed.
+        :param project_id: The project the job belongs to.
+        :param group_name: The task group that owns the tasks.
+        :param filter: Optional filter forwarded to the list request.
+        :param limit: Maximum number of tasks to return. ``None`` returns every task.
+        :return: A list with at most ``limit`` tasks.
+        :raises AirflowException: If ``limit`` is negative.
+        """
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list tasks request should be greater or equal to zero")
+
+        list_tasks_request: ListTasksRequest = ListTasksRequest(
+            parent=f"projects/{project_id}/locations/{region}/jobs/{job_name}/taskGroups/{group_name}",
+            filter=filter,
+        )
+
+        tasks: pagers.ListTasksPager = self.get_conn().list_tasks(request=list_tasks_request)
+
+        # islice stops consuming the pager as soon as ``limit`` items were read.
+        return list(itertools.islice(tasks, limit))
+
+    def wait_for_job(
+        self, job_name: str, polling_period_seconds: float = 10, timeout: float | None = None
+    ) -> Job:
+        """
+        Poll a job until it reaches a final state.
+
+        :param job_name: The name of the job, passed unchanged to the ``get_job`` request.
+        :param polling_period_seconds: Time to sleep between two consecutive status checks.
+        :param timeout: Time budget in seconds. It is reduced by ``polling_period_seconds``
+            after every poll. ``None`` polls without a limit.
+        :return: The job once its state is SUCCEEDED, FAILED or DELETION_IN_PROGRESS.
+        :raises AirflowException: If the time budget is used up before a final state is seen.
+        """
+        final_states = (
+            JobStatus.State.SUCCEEDED,
+            JobStatus.State.FAILED,
+            JobStatus.State.DELETION_IN_PROGRESS,
+        )
+        client = self.get_conn()
+        while timeout is None or timeout > 0:
+            try:
+                job = client.get_job(name=f"{job_name}")
+                status: JobStatus.State = job.status.state
+            except Exception:
+                self.log.exception("Exception occurred while checking for job completion.")
+                # Bare raise keeps the original traceback intact.
+                raise
+
+            if status in final_states:
+                return job
+
+            sleep(polling_period_seconds)
+            if timeout is not None:
+                timeout -= polling_period_seconds
+
+        raise AirflowException(f"Job with name [{job_name}] timed out")
+
+    def get_job(self, job_name: str) -> Job:
+        """
+        Fetch a job.
+
+        :param job_name: The name of the job, passed unchanged to the ``get_job`` request.
+        :return: The job returned by the request.
+        """
+        return self.get_conn().get_job(name=job_name)
+
+
+class CloudBatchAsyncHook(GoogleBaseHook):
+    """
+    Asynchronous hook for the Google Cloud Batch service.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        # The async client is only built on the first get_conn() call.
+        self._client: BatchServiceAsyncClient | None = None
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+
+    def get_conn(self):
+        """Return the cached async Batch client, building it on the first call."""
+        if self._client is not None:
+            return self._client
+
+        credentials = self.get_credentials()
+        self._client = BatchServiceAsyncClient(credentials=credentials, client_info=CLIENT_INFO)
+        return self._client
+
+    async def get_batch_job(
+        self,
+        job_name: str,
+    ) -> Job:
+        """
+        Fetch a job through the async client.
+
+        :param job_name: The name of the job, passed to the ``get_job`` request.
+        :return: The job returned by the request.
+        """
+        return await self.get_conn().get_job(name=f"{job_name}")
diff --git a/airflow/providers/google/cloud/operators/cloud_batch.py 
b/airflow/providers/google/cloud/operators/cloud_batch.py
new file mode 100644
index 0000000000..26c0af06c1
--- /dev/null
+++ b/airflow/providers/google/cloud/operators/cloud_batch.py
@@ -0,0 +1,298 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Sequence
+
+from google.api_core import operation  # type: ignore
+from google.cloud.batch_v1 import Job, Task
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import CloudBatchHook
+from airflow.providers.google.cloud.operators.cloud_base import 
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_batch import 
CloudBatchJobFinishedTrigger
+from airflow.utils.context import Context
+
+
+class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
+    """
+    Submit a job and wait for its completion.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to create.
+    :param job: Required. The job descriptor containing the configuration of the job to submit.
+    :param polling_period_seconds: Optional: Control the rate of the poll for the job status,
+        in both the blocking and the deferrable mode. By default, the status is polled every 10 seconds.
+    :param timeout_seconds: The time, in seconds, to wait for the job to finish.
+        If left empty, the job is polled without a time limit.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param deferrable: Run operator in the deferrable mode
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        job: dict | Job,
+        polling_period_seconds: float = 10,
+        timeout_seconds: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.job = job
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout_seconds = timeout_seconds
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+
+    def execute(self, context: Context):
+        """
+        Submit the job, then either block until it finishes or defer to the trigger.
+
+        :return: The finished job as a dictionary when running in the non-deferrable mode.
+        """
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        # Forward the operator's project explicitly; otherwise the hook falls back to the
+        # connection's default project and ``project_id`` would be silently ignored.
+        job = hook.submit_batch_job(
+            job_name=self.job_name, job=self.job, region=self.region, project_id=self.project_id
+        )
+
+        if not self.deferrable:
+            completed_job = hook.wait_for_job(
+                job_name=job.name,
+                polling_period_seconds=self.polling_period_seconds,
+                timeout=self.timeout_seconds,
+            )
+
+            return Job.to_dict(completed_job)
+
+        else:
+            self.defer(
+                trigger=CloudBatchJobFinishedTrigger(
+                    job_name=job.name,
+                    project_id=self.project_id,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    location=self.region,
+                    polling_period_seconds=self.polling_period_seconds,
+                    timeout=self.timeout_seconds,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict):
+        """
+        Resume after the trigger fired.
+
+        :param event: The trigger event; its ``status``, ``message`` and ``job_name`` keys are read.
+        :return: The finished job as a dictionary.
+        :raises AirflowException: If the event status is anything other than ``success``.
+        """
+        if event["status"] != "success":
+            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        job = hook.get_job(job_name=event["job_name"])
+        return Job.to_dict(job)
+
+
+class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator):
+    """
+    Delete a job and wait for the delete operation to complete.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to be deleted.
+    :param timeout: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        # Connection settings.
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        # Job coordinates.
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        # Passed to the operation's result()/exception() calls.
+        self.timeout = timeout
+
+    def execute(self, context: Context):
+        """Request the deletion of the job and wait for the returned operation."""
+        batch_hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        delete_operation = batch_hook.delete_job(
+            job_name=self.job_name, region=self.region, project_id=self.project_id
+        )
+        self._wait_for_operation(delete_operation)
+
+    def _wait_for_operation(self, operation: operation.Operation):
+        """Return the operation result, re-raising any failure as an AirflowException."""
+        try:
+            outcome = operation.result(timeout=self.timeout)
+        except Exception:
+            error = operation.exception(timeout=self.timeout)
+            raise AirflowException(error)
+        return outcome
+
+
+class CloudBatchListJobsOperator(GoogleCloudBaseOperator):
+    """
+    List the Cloud Batch jobs of a project and region.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the jobs. If left empty, all the jobs are listed.
+    :param limit: The number of jobs to list. If left empty,
+        all the jobs matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        gcp_conn_id: str = "google_cloud_default",
+        filter: str | None = None,
+        limit: int | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        # Connection settings.
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        # Scope and shape of the listing.
+        self.project_id = project_id
+        self.region = region
+        self.filter = filter
+        self.limit = limit
+        # A negative limit is rejected when the operator is constructed.
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+    def execute(self, context: Context):
+        """Return the matching jobs, each converted to a dictionary."""
+        batch_hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        found_jobs = batch_hook.list_jobs(
+            region=self.region, project_id=self.project_id, filter=self.filter, limit=self.limit
+        )
+        return [Job.to_dict(found) for found in found_jobs]
+
+
+class CloudBatchListTasksOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch tasks for a given job.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job for which to list tasks.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the tasks. If left empty, all the tasks are listed.
+    :param group_name: The name of the group that owns the task. By default it's `group0`.
+    :param limit: The number of tasks to list.
+        If left empty, all the tasks matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "job_name", "gcp_conn_id", "impersonation_chain", "group_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        group_name: str = "group0",
+        filter: str | None = None,
+        limit: int | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.group_name = group_name
+        self.filter = filter
+        self.limit = limit
+        # A negative limit is rejected when the operator is constructed.
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list tasks request should be greater or equal to zero")
+
+    def execute(self, context: Context):
+        """Return the matching tasks of the job's task group, each converted to a dictionary."""
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+
+        tasks_list = hook.list_tasks(
+            region=self.region,
+            project_id=self.project_id,
+            job_name=self.job_name,
+            group_name=self.group_name,
+            filter=self.filter,
+            limit=self.limit,
+        )
+
+        return [Task.to_dict(task) for task in tasks_list]
diff --git a/airflow/providers/google/cloud/triggers/cloud_batch.py 
b/airflow/providers/google/cloud/triggers/cloud_batch.py
new file mode 100644
index 0000000000..211e436c95
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/cloud_batch.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator, Sequence
+
+from google.cloud.batch_v1 import Job, JobStatus
+
+from airflow.providers.google.cloud.hooks.cloud_batch import 
CloudBatchAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+DEFAULT_BATCH_LOCATION = "us-central1"
+
+
+class CloudBatchJobFinishedTrigger(BaseTrigger):
+    """Cloud Batch trigger to check if templated job has been finished.
+
+    :param job_name: Required. Name of the job.
+    :param project_id: Required. the Google Cloud project ID in which the job was started.
+    :param location: Optional. the location where job is executed. If not set,
+        the value of DEFAULT_BATCH_LOCATION will be used
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_period_seconds: Polling period in seconds to check for the status
+    :param timeout: Time budget in seconds for the job to reach a final state. It is reduced by
+        ``polling_period_seconds`` after every poll. If left empty, the job is polled without a time limit.
+
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        project_id: str | None,
+        location: str = DEFAULT_BATCH_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_period_seconds: float = 10,
+        timeout: float | None = None,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_name = job_name
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout = timeout
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchJobFinishedTrigger",
+            {
+                "project_id": self.project_id,
+                "job_name": self.job_name,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "polling_period_seconds": self.polling_period_seconds,
+                "timeout": self.timeout,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Main loop of the class in where it is fetching the job status and yields certain Event.
+
+        If the job has status success then it yields TriggerEvent with success status, if job has
+        status failed - with error status and if the job is being deleted - with deleted status.
+        In any other case Trigger will wait for specified amount of time
+        stored in self.polling_period_seconds variable.
+
+        If the time budget is used up first, a TriggerEvent with "timed out" status is yielded.
+        Any exception raised while polling is reported as a TriggerEvent with error status.
+        """
+        timeout = self.timeout
+        hook = self._get_async_hook()
+        while timeout is None or timeout > 0:
+
+            try:
+                job: Job = await hook.get_batch_job(job_name=self.job_name)
+
+                status: JobStatus.State = job.status.state
+                if status == JobStatus.State.SUCCEEDED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "success",
+                            "message": "Job completed",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.FAILED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "error",
+                            "message": f"Batch job with name {self.job_name} has failed its execution",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.DELETION_IN_PROGRESS:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "deleted",
+                            "message": f"Batch job with name {self.job_name} is being deleted",
+                        }
+                    )
+                    return
+                else:
+                    self.log.info("Current job status is: %s", status)
+                    self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
+                    if timeout is not None:
+                        timeout -= self.polling_period_seconds
+
+                    # Skip the final sleep when the time budget is already used up.
+                    if timeout is None or timeout > 0:
+                        await asyncio.sleep(self.polling_period_seconds)
+
+            except Exception as e:
+                self.log.exception("Exception occurred while checking for job completion.")
+                yield TriggerEvent({"status": "error", "message": str(e)})
+                return
+
+        # No exception is being handled here, so log with error() instead of exception().
+        self.log.error("Job with name [%s] timed out", self.job_name)
+        yield TriggerEvent(
+            {
+                "job_name": self.job_name,
+                "status": "timed out",
+                "message": f"Batch job with name {self.job_name} timed out",
+            }
+        )
+
+    def _get_async_hook(self) -> CloudBatchAsyncHook:
+        """Build the async hook from the trigger's connection settings."""
+        return CloudBatchAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
diff --git a/airflow/providers/google/provider.yaml 
b/airflow/providers/google/provider.yaml
index f3ec0ecee9..29fd2c8073 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -117,6 +117,7 @@ dependencies:
   - google-cloud-videointelligence>=2.11.0
   - google-cloud-vision>=3.4.0
   - google-cloud-workflows>=1.10.0
+  - google-cloud-batch>=0.13.0
   - grpcio-gcp>=0.2.2
   - httpx
   - json-merge-patch>=0.2
@@ -182,6 +183,11 @@ integrations:
     how-to-guide:
       - 
/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
     tags: [google]
+  - integration-name: Google Cloud Batch
+    external-doc-url: https://cloud.google.com/batch
+    how-to-guide:
+      - /docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
+    tags: [google]
   - integration-name: Google Cloud Dataform
     external-doc-url: https://cloud.google.com/dataform/
     how-to-guide:
@@ -611,6 +617,10 @@ operators:
   - integration-name: Google Cloud Dataform
     python-modules:
       - airflow.providers.google.cloud.operators.dataform
+  - integration-name: Google Cloud Batch
+    python-modules:
+      - airflow.providers.google.cloud.operators.cloud_batch
+
 
 sensors:
   - integration-name: Google BigQuery
@@ -850,6 +860,9 @@ hooks:
   - integration-name: Google Cloud Dataform
     python-modules:
       - airflow.providers.google.cloud.hooks.dataform
+  - integration-name: Google Cloud Batch
+    python-modules:
+      - airflow.providers.google.cloud.hooks.cloud_batch
 
 triggers:
   - integration-name: Google BigQuery Data Transfer Service
@@ -891,6 +904,9 @@ triggers:
   - integration-name: Google Cloud Pub/Sub
     python-modules:
       - airflow.providers.google.cloud.triggers.pubsub
+  - integration-name: Google Cloud Batch
+    python-modules:
+      - airflow.providers.google.cloud.triggers.cloud_batch
 
 transfers:
   - source-integration-name: Presto
diff --git 
a/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst 
b/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
new file mode 100644
index 0000000000..2254ead25b
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
@@ -0,0 +1,108 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Google Cloud Batch Operators
+===============================
+
+Cloud Batch is a fully managed batch service to schedule, queue, and execute 
batch jobs on Google's infrastructure.
+
+For more information about the service visit `Google Cloud Batch documentation 
<https://cloud.google.com/batch>`__.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+Submit a job
+---------------------
+
+Before you submit a job in Cloud Batch, you need to define it.
+For more information about the Job object fields, visit `Google Cloud Batch 
Job description 
<https://cloud.google.com/python/docs/reference/batch/latest/google.cloud.batch_v1.types.Job>`__.
+
+A simple job configuration can look as follows:
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_operator_batch_job_creation]
+    :end-before: [END howto_operator_batch_job_creation]
+
+With this configuration we can submit the job:
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchSubmitJobOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_batch_submit_job]
+    :end-before: [END howto_operator_batch_submit_job]
+
+or you can define the same operator in the deferrable mode:
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchSubmitJobOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_batch_submit_job_deferrable_mode]
+    :end-before: [END howto_operator_batch_submit_job_deferrable_mode]
+
+Note that this operator waits for the job to complete its execution, and the Job's dictionary representation is pushed to XCom.
+
+List a job's tasks
+------------------
+
+To list the tasks of a certain job, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchListTasksOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_batch_list_tasks]
+    :end-before: [END howto_operator_batch_list_tasks]
+
+The operator takes two optional parameters: "limit" to limit the number of 
tasks returned, and "filter" to only list the tasks matching the `filter 
<https://cloud.google.com/sdk/gcloud/reference/topic/filters>`__.
+
+List jobs
+----------------------
+
+To list the jobs, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchListJobsOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_batch_list_jobs]
+    :end-before: [END howto_operator_batch_list_jobs]
+
+The operator takes two optional parameters: "limit" to limit the number of jobs returned, and "filter" to only list the jobs matching the `filter <https://cloud.google.com/sdk/gcloud/reference/topic/filters>`__.
+
+Delete a job
+-----------------
+
+To delete a job you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchDeleteJobOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_delete_job]
+    :end-before: [END howto_operator_delete_job]
+
+
+Note that this operator waits for the job to be deleted, and the deleted Job's 
dictionary representation is pushed to XCom.
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index ee7a382267..b96c6baf1c 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -420,6 +420,7 @@
       "google-auth>=1.0.0",
       "google-cloud-aiplatform>=1.22.1",
       "google-cloud-automl>=2.11.0",
+      "google-cloud-batch>=0.13.0",
       "google-cloud-bigquery-datatransfer>=3.11.0",
       "google-cloud-bigtable>=2.17.0",
       "google-cloud-build>=3.13.0",
diff --git a/tests/providers/google/cloud/hooks/test_cloud_batch.py 
b/tests/providers/google/cloud/hooks/test_cloud_batch.py
new file mode 100644
index 0000000000..d4c000067f
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/test_cloud_batch.py
@@ -0,0 +1,343 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.batch import ListJobsRequest
+from google.cloud.batch_v1 import CreateJobRequest, Job, JobStatus
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import 
CloudBatchAsyncHook, CloudBatchHook
+from tests.providers.google.cloud.utils.base_gcp_mock import 
mock_base_gcp_hook_default_project_id
+
+
+class TestCloudBathHook:
+    """Tests for ``CloudBatchHook`` run against a mocked ``BatchServiceClient``."""
+
+    def dummy_get_credentials(self):
+        """No-op that the ``cloud_batch_hook`` fixture installs as the hook's ``get_credentials``."""
+        pass
+
+    @pytest.fixture
+    def cloud_batch_hook(self):
+        """Create a ``CloudBatchHook`` whose ``get_credentials`` is ``dummy_get_credentials``."""
+        cloud_batch_hook = CloudBatchHook()
+        cloud_batch_hook.get_credentials = self.dummy_get_credentials
+        return cloud_batch_hook
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_submit(self, mock_batch_service_client, cloud_batch_hook):
+        """``submit_batch_job`` calls ``create_job`` with the expected ``CreateJobRequest``."""
+        job = Job()
+        job_name = "jobname"
+        project_id = "test_project_id"
+        region = "us-central1"
+
+        cloud_batch_hook.submit_batch_job(
+            job_name=job_name, job=Job.to_dict(job), region=region, project_id=project_id
+        )
+
+        # Request the hook is expected to assemble from the arguments above.
+        create_request = CreateJobRequest()
+        create_request.job = job
+        create_request.job_id = job_name
+        create_request.parent = f"projects/{project_id}/locations/{region}"
+
+        cloud_batch_hook._client.create_job.assert_called_with(create_request)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_get_job(self, mock_batch_service_client, cloud_batch_hook):
+        """``get_job`` passes the given name to the client's ``get_job``."""
+        cloud_batch_hook.get_job("job1")
+        cloud_batch_hook._client.get_job.assert_called_once_with(name="job1")
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_delete_job(self, mock_batch_service_client, cloud_batch_hook):
+        """``delete_job`` calls the client with the fully qualified job resource name."""
+        job_name = "job1"
+        region = "us-east1"
+        project_id = "test_project_id"
+        cloud_batch_hook.delete_job(job_name=job_name, region=region, project_id=project_id)
+        cloud_batch_hook._client.delete_job.assert_called_once_with(
+            name=f"projects/{project_id}/locations/{region}/jobs/{job_name}"
+        )
+
+    @pytest.mark.parametrize(
+        "state", [JobStatus.State.SUCCEEDED, JobStatus.State.FAILED, JobStatus.State.DELETION_IN_PROGRESS]
+    )
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_wait_job_succeeded(self, mock_batch_service_client, state, cloud_batch_hook):
+        """``wait_for_job`` returns the fetched job for each of the parametrized states."""
+        mock_job = self._mock_job_with_status(state)
+        mock_batch_service_client.return_value.get_job.return_value = mock_job
+        actual_job = cloud_batch_hook.wait_for_job("job1")
+        assert actual_job == mock_job
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_wait_job_timeout(self, mock_batch_service_client, cloud_batch_hook):
+        """``wait_for_job`` raises ``AirflowException`` when the job is still RUNNING at the timeout."""
+        mock_job = self._mock_job_with_status(JobStatus.State.RUNNING)
+        mock_batch_service_client.return_value.get_job.return_value = mock_job
+
+        # pytest.raises replaces the manual try/except flag and matches the sibling tests.
+        with pytest.raises(expected_exception=AirflowException):
+            cloud_batch_hook.wait_for_job("job1", polling_period_seconds=0.01, timeout=0.02)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_jobs(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_jobs`` returns the client's jobs and sends a ``ListJobsRequest`` with parent and filter."""
+        number_of_jobs = 3
+        region = "us-central1"
+        project_id = "test_project_id"
+        job_filter = "filter_description"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_batch_service_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_batch_hook.list_jobs(region=region, project_id=project_id, filter=job_filter)
+
+        for i in range(number_of_jobs):
+            assert jobs_list[i].name == f"name{i}"
+
+        expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
+            parent=f"projects/{project_id}/locations/{region}", filter=job_filter
+        )
+        mock_batch_service_client.return_value.list_jobs.assert_called_once_with(
+            request=expected_list_jobs_request
+        )
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_jobs_with_limit(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_jobs`` returns only the first ``limit`` jobs when more are available."""
+        number_of_jobs = 3
+        limit = 2
+        region = "us-central1"
+        project_id = "test_project_id"
+        job_filter = "filter_description"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_batch_service_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_batch_hook.list_jobs(
+            region=region, project_id=project_id, filter=job_filter, limit=limit
+        )
+
+        assert len(jobs_list) == limit
+        for i in range(limit):
+            assert jobs_list[i].name == f"name{i}"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_jobs_with_limit_zero(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_jobs`` with ``limit=0`` returns an empty list."""
+        number_of_jobs = 3
+        limit = 0
+        region = "us-central1"
+        project_id = "test_project_id"
+        job_filter = "filter_description"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_batch_service_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_batch_hook.list_jobs(
+            region=region, project_id=project_id, filter=job_filter, limit=limit
+        )
+
+        assert len(jobs_list) == 0
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_jobs_with_limit_greater_then_range(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_jobs`` returns every job when ``limit`` exceeds the number available."""
+        number_of_jobs = 3
+        limit = 5
+        region = "us-central1"
+        project_id = "test_project_id"
+        job_filter = "filter_description"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_batch_service_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_batch_hook.list_jobs(
+            region=region, project_id=project_id, filter=job_filter, limit=limit
+        )
+
+        assert len(jobs_list) == number_of_jobs
+        for i in range(number_of_jobs):
+            assert jobs_list[i].name == f"name{i}"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_jobs_with_limit_less_than_zero(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_jobs`` raises ``AirflowException`` for a negative ``limit``."""
+        number_of_jobs = 3
+        limit = -1
+        region = "us-central1"
+        project_id = "test_project_id"
+        job_filter = "filter_description"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_batch_service_client.return_value.list_jobs.return_value = page
+
+        with pytest.raises(expected_exception=AirflowException):
+            cloud_batch_hook.list_jobs(region=region, project_id=project_id, filter=job_filter, limit=limit)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_tasks_with_limit(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_tasks`` returns only the first ``limit`` items when more are available."""
+        number_of_tasks = 3
+        limit = 2
+        region = "us-central1"
+        project_id = "test_project_id"
+        task_filter = "filter_description"
+        job_name = "test_job"
+
+        page = self._mock_pager(number_of_tasks)
+        mock_batch_service_client.return_value.list_tasks.return_value = page
+
+        tasks_list = cloud_batch_hook.list_tasks(
+            region=region, project_id=project_id, job_name=job_name, filter=task_filter, limit=limit
+        )
+
+        assert len(tasks_list) == limit
+        for i in range(limit):
+            assert tasks_list[i].name == f"name{i}"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_tasks_with_limit_greater_then_range(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_tasks`` returns every item when ``limit`` exceeds the number available."""
+        number_of_tasks = 3
+        limit = 5
+        region = "us-central1"
+        project_id = "test_project_id"
+        task_filter = "filter_description"
+        job_name = "test_job"
+
+        page = self._mock_pager(number_of_tasks)
+        mock_batch_service_client.return_value.list_tasks.return_value = page
+
+        tasks_list = cloud_batch_hook.list_tasks(
+            region=region, project_id=project_id, filter=task_filter, job_name=job_name, limit=limit
+        )
+
+        assert len(tasks_list) == number_of_tasks
+        for i in range(number_of_tasks):
+            assert tasks_list[i].name == f"name{i}"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+    def test_list_tasks_with_limit_less_than_zero(self, mock_batch_service_client, cloud_batch_hook):
+        """``list_tasks`` raises ``AirflowException`` for a negative ``limit``."""
+        number_of_tasks = 3
+        limit = -1
+        region = "us-central1"
+        project_id = "test_project_id"
+        task_filter = "filter_description"
+        job_name = "test_job"
+
+        page = self._mock_pager(number_of_tasks)
+        mock_batch_service_client.return_value.list_tasks.return_value = page
+
+        with pytest.raises(expected_exception=AirflowException):
+            cloud_batch_hook.list_tasks(
+                region=region, project_id=project_id, job_name=job_name, filter=task_filter, limit=limit
+            )
+
+    def _mock_job_with_status(self, status: JobStatus.State):
+        """Return a ``MagicMock`` job whose ``status.state`` is ``status``."""
+        job = mock.MagicMock()
+        job.status.state = status
+        return job
+
+    def _mock_pager(self, number_of_jobs):
+        """Return a list of ``number_of_jobs`` ``Job`` messages named ``name0``, ``name1``, ..."""
+        return [Job(name=f"name{i}") for i in range(number_of_jobs)]
+
+
+class TestCloudBatchAsyncHook:
+    """Tests for ``CloudBatchAsyncHook`` with ``BatchServiceAsyncClient`` patched out."""
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceAsyncClient")
+    async def test_get_job(self, mock_client):
+        """``get_batch_job`` returns whatever the async client's ``get_job`` coroutine yields."""
+        expected_job = {"name": "somename"}
+
+        async def _fake_get_job(name):
+            return expected_job
+
+        # Client instance handed out by the patched class; its get_job is a real coroutine function.
+        fake_client = mock.MagicMock()
+        fake_client.get_job = _fake_get_job
+        mock_client.return_value = fake_client
+
+        hook = CloudBatchAsyncHook()
+        hook.get_credentials = self._dummy_get_credentials
+
+        returned_operation = await hook.get_batch_job(job_name="jobname")
+
+        assert returned_operation == expected_job
+
+    def _dummy_get_credentials(self):
+        """No-op installed as the hook's ``get_credentials`` in ``test_get_job``."""
+        pass
diff --git a/tests/providers/google/cloud/operators/test_cloud_batch.py 
b/tests/providers/google/cloud/operators/test_cloud_batch.py
new file mode 100644
index 0000000000..a4377eebdc
--- /dev/null
+++ b/tests/providers/google/cloud/operators/test_cloud_batch.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud import batch_v1
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.google.cloud.operators.cloud_batch import (
+    CloudBatchDeleteJobOperator,
+    CloudBatchListJobsOperator,
+    CloudBatchListTasksOperator,
+    CloudBatchSubmitJobOperator,
+)
+
+# Dotted path handed to mock.patch: the CloudBatchHook name inside the operators module.
+CLOUD_BATCH_HOOK_PATH = 
"airflow.providers.google.cloud.operators.cloud_batch.CloudBatchHook"
+# Operator arguments shared by the test classes in this module.
+TASK_ID = "test"
+PROJECT_ID = "testproject"
+REGION = "us-central1"
+JOB_NAME = "test"
+# Job message used both as operator input and as the mocked hook's return value.
+JOB = batch_v1.Job()
+JOB.name = JOB_NAME
+
+
+class TestCloudBatchSubmitJobOperator:
+    """Tests for ``CloudBatchSubmitJobOperator`` with ``CloudBatchHook`` patched in the operators module."""
+
+    # The injected patch object is named ``hook_mock`` (not ``mock``) so that ``mock.MagicMock()``
+    # below refers to ``unittest.mock`` instead of an attribute of the patched hook.
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        """Non-deferrable ``execute`` submits the job, waits for it and returns it as a mapping."""
+        hook_mock.return_value.wait_for_job.return_value = JOB
+        operator = CloudBatchSubmitJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB
+        )
+
+        completed_job = operator.execute(context=mock.MagicMock())
+
+        assert completed_job["name"] == JOB_NAME
+
+        hook_mock.return_value.submit_batch_job.assert_called_with(
+            job_name=JOB_NAME, job=JOB, region=REGION
+        )
+        hook_mock.return_value.wait_for_job.assert_called()
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute_deferrable(self, hook_mock):
+        """With ``deferrable=True``, ``execute`` raises ``TaskDeferred``."""
+        operator = CloudBatchSubmitJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB, deferrable=True
+        )
+
+        with pytest.raises(expected_exception=TaskDeferred):
+            operator.execute(context=mock.MagicMock())
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute_complete(self, hook_mock):
+        """``execute_complete`` with a ``success`` event fetches the job by name and returns it."""
+        hook_mock.return_value.get_job.return_value = JOB
+        operator = CloudBatchSubmitJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB, deferrable=True
+        )
+
+        event = {"status": "success", "job_name": JOB_NAME, "message": "test error"}
+        completed_job = operator.execute_complete(context=mock.MagicMock(), event=event)
+
+        assert completed_job["name"] == JOB_NAME
+
+        hook_mock.return_value.get_job.assert_called_once_with(job_name=JOB_NAME)
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute_complete_exception(self, hook_mock):
+        """``execute_complete`` with an ``error`` event raises ``AirflowException`` with the event message."""
+        operator = CloudBatchSubmitJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB, deferrable=True
+        )
+
+        event = {"status": "error", "job_name": JOB_NAME, "message": "test error"}
+        with pytest.raises(
+            expected_exception=AirflowException, match="Unexpected error in the operation: test error"
+        ):
+            operator.execute_complete(context=mock.MagicMock(), event=event)
+
+
+class TestCloudBatchDeleteJobOperator:
+    """Tests for ``CloudBatchDeleteJobOperator`` with ``CloudBatchHook`` patched."""
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        """``execute`` calls the hook's ``delete_job`` once and calls ``result()`` on what it returns."""
+        hook_instance = hook_mock.return_value
+        delete_operation_mock = self._delete_operation_mock()
+        hook_instance.delete_job.return_value = delete_operation_mock
+
+        job_kwargs = {"job_name": JOB_NAME, "region": REGION, "project_id": PROJECT_ID}
+        operator = CloudBatchDeleteJobOperator(task_id=TASK_ID, **job_kwargs)
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_instance.delete_job.assert_called_once_with(**job_kwargs)
+        delete_operation_mock.result.assert_called_once()
+
+    def _delete_operation_mock(self):
+        """Return a ``MagicMock`` operation whose ``result()`` returns another ``MagicMock``."""
+        result_value = mock.MagicMock()
+        operation = mock.MagicMock()
+        operation.result.return_value = result_value
+        return operation
+
+
+class TestCloudBatchListJobsOperator:
+    """Tests for ``CloudBatchListJobsOperator`` with ``CloudBatchHook`` patched."""
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        """``execute`` forwards region, project, filter and limit to the hook's ``list_jobs``."""
+        list_kwargs = {
+            "region": REGION,
+            "project_id": PROJECT_ID,
+            "filter": "filter_description",
+            "limit": 2,
+        }
+        operator = CloudBatchListJobsOperator(task_id=TASK_ID, **list_kwargs)
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_mock.return_value.list_jobs.assert_called_once_with(**list_kwargs)
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute_with_invalid_limit(self, hook_mock):
+        """Constructing the operator with a negative ``limit`` raises ``AirflowException``."""
+        invalid_kwargs = {
+            "task_id": TASK_ID,
+            "project_id": PROJECT_ID,
+            "region": REGION,
+            "filter": "filter_description",
+            "limit": -1,
+        }
+
+        with pytest.raises(expected_exception=AirflowException):
+            CloudBatchListJobsOperator(**invalid_kwargs)
+
+
+class TestCloudBatchListTasksOperator:
+    """Tests for ``CloudBatchListTasksOperator`` with ``CloudBatchHook`` patched."""
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        """``execute`` forwards its arguments to ``list_tasks`` together with ``group_name="group0"``."""
+        operator_kwargs = {
+            "project_id": PROJECT_ID,
+            "region": REGION,
+            "job_name": "test_job",
+            "filter": "filter_description",
+            "limit": 2,
+        }
+        operator = CloudBatchListTasksOperator(task_id=TASK_ID, **operator_kwargs)
+
+        operator.execute(context=mock.MagicMock())
+
+        # The hook receives everything the operator was given plus the group name.
+        hook_mock.return_value.list_tasks.assert_called_once_with(group_name="group0", **operator_kwargs)
+
+    @mock.patch(CLOUD_BATCH_HOOK_PATH)
+    def test_execute_with_invalid_limit(self, hook_mock):
+        """Constructing the operator with a negative ``limit`` raises ``AirflowException``."""
+        invalid_kwargs = {
+            "task_id": TASK_ID,
+            "project_id": PROJECT_ID,
+            "region": REGION,
+            "job_name": "test_job",
+            "filter": "filter_description",
+            "limit": -1,
+        }
+
+        with pytest.raises(expected_exception=AirflowException):
+            CloudBatchListTasksOperator(**invalid_kwargs)
diff --git a/tests/providers/google/cloud/triggers/test_cloud_batch.py 
b/tests/providers/google/cloud/triggers/test_cloud_batch.py
new file mode 100644
index 0000000000..8da083f17e
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_batch.py
@@ -0,0 +1,160 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.batch_v1 import Job, JobStatus
+
+from airflow.providers.google.cloud.triggers.cloud_batch import 
CloudBatchJobFinishedTrigger
+from airflow.triggers.base import TriggerEvent
+
+# Arguments used to build the trigger under test and to check its serialized form.
+JOB_NAME = "jobName"
+PROJECT_ID = "projectId"
+LOCATION = "us-central1"
+GCP_CONNECTION_ID = "gcp_connection_id"
+# Passed as polling_period_seconds and timeout; kept tiny — presumably both in seconds, confirm
+# against the trigger implementation.
+POLL_SLEEP = 0.01
+TIMEOUT = 0.02
+IMPERSONATION_CHAIN = "impersonation_chain"
+
+
+@pytest.fixture
+def trigger():
+    """Provide a ``CloudBatchJobFinishedTrigger`` built from the module-level test constants."""
+    return CloudBatchJobFinishedTrigger(
+        job_name=JOB_NAME,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONNECTION_ID,
+        polling_period_seconds=POLL_SLEEP,
+        timeout=TIMEOUT,
+        impersonation_chain=IMPERSONATION_CHAIN,
+    )
+
+
+class TestCloudBatchJobFinishedTrigger:
+    def test_serialization(self, trigger):
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchJobFinishedTrigger"
+        assert kwargs == {
+            "project_id": PROJECT_ID,
+            "job_name": JOB_NAME,
+            "location": LOCATION,
+            "gcp_conn_id": GCP_CONNECTION_ID,
+            "polling_period_seconds": POLL_SLEEP,
+            "timeout": TIMEOUT,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+        }
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+    async def test_trigger_on_success_yield_successfully(
+        self, mock_hook, trigger: CloudBatchJobFinishedTrigger
+    ):
+        """
+        Tests the CloudBatchJobFinishedTrigger fires once the job execution 
reaches a successful state.
+        """
+        state = JobStatus.State.SUCCEEDED
+        mock_hook.return_value.get_batch_job.return_value = 
self._mock_job_with_state(state)
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "job_name": JOB_NAME,
+                    "status": "success",
+                    "message": "Job completed",
+                }
+            )
+            == actual
+        )
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+    async def test_trigger_on_deleted_yield_successfully(
+        self, mock_hook, trigger: CloudBatchJobFinishedTrigger
+    ):
+        """
+        Tests the CloudBatchJobFinishedTrigger fires once the job execution 
reaches a successful state.
+        """
+        state = JobStatus.State.DELETION_IN_PROGRESS
+        mock_hook.return_value.get_batch_job.return_value = 
self._mock_job_with_state(state)
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "job_name": JOB_NAME,
+                    "status": "deleted",
+                    "message": f"Batch job with name {JOB_NAME} is being 
deleted",
+                }
+            )
+            == actual
+        )
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+    async def test_trigger_on_deleted_yield_exception(self, mock_hook, 
trigger: CloudBatchJobFinishedTrigger):
+        """
+        Tests the CloudBatchJobFinishedTrigger fires once the job execution
+        reaches an state with an error message.
+        """
+        mock_hook.return_value.get_batch_job.side_effect = Exception("Test 
Exception")
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "status": "error",
+                    "message": "Test Exception",
+                }
+            )
+            == actual
+        )
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+    async def test_trigger_timeout(self, mock_hook, trigger: 
CloudBatchJobFinishedTrigger):
+        """
+        Tests the CloudBatchJobFinishedTrigger fires a "timed out" event when the job never finishes.
+        """
+
+        async def _mock_job(job_name):
+            job = mock.MagicMock()
+            job.status.state = JobStatus.State.RUNNING
+            return job
+
+        mock_hook.return_value.get_batch_job = _mock_job
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "job_name": JOB_NAME,
+                    "status": "timed out",
+                    "message": f"Batch job with name {JOB_NAME} timed out",
+                }
+            )
+            == actual
+        )
+
+    async def _mock_job_with_state(self, state: JobStatus.State):
+        job: Job = mock.MagicMock()
+        job.status.state = state
+        return job
diff --git a/tests/system/providers/google/cloud/cloud_batch/__init__.py 
b/tests/system/providers/google/cloud/cloud_batch/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_batch/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py 
b/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
new file mode 100644
index 0000000000..d3f3d752a8
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google Cloud Batch Operators.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.cloud import batch_v1
+
+from airflow import models
+from airflow.operators.python import PythonOperator
+from airflow.providers.google.cloud.operators.cloud_batch import (
+    CloudBatchDeleteJobOperator,
+    CloudBatchListJobsOperator,
+    CloudBatchListTasksOperator,
+    CloudBatchSubmitJobOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "example_cloud_batch"
+
+region = "us-central1"
+job_name_prefix = "batch-system-test-job"
+job1_name = f"{job_name_prefix}1"
+job2_name = f"{job_name_prefix}2"
+
+submit1_task_name = "submit-job1"
+submit2_task_name = "submit-job2"
+
+delete1_task_name = "delete-job1"
+delete2_task_name = "delete-job2"
+
+list_jobs_task_name = "list-jobs"
+list_tasks_task_name = "list-tasks"
+
+clean1_task_name = "clean-job1"
+clean2_task_name = "clean-job2"
+
+
+def _assert_jobs(ti):
+    job_names = ti.xcom_pull(task_ids=[list_jobs_task_name], 
key="return_value")
+    job_names_str = job_names[0][0]["name"].split("/")[-1] + " " + 
job_names[0][1]["name"].split("/")[-1]
+    assert job1_name in job_names_str
+    assert job2_name in job_names_str
+
+
+def _assert_tasks(ti):
+    tasks_names = ti.xcom_pull(task_ids=[list_tasks_task_name], 
key="return_value")
+    assert len(tasks_names[0]) == 2
+    assert "tasks/0" in tasks_names[0][0]["name"]
+    assert "tasks/1" in tasks_names[0][1]["name"]
+
+
+# [START howto_operator_batch_job_creation]
+def _create_job():
+    runnable = batch_v1.Runnable()
+    runnable.container = batch_v1.Runnable.Container()
+    runnable.container.image_uri = "gcr.io/google-containers/busybox"
+    runnable.container.entrypoint = "/bin/sh"
+    runnable.container.commands = [
+        "-c",
+        "echo Hello world! This is task ${BATCH_TASK_INDEX}.\
+          This job has a total of ${BATCH_TASK_COUNT} tasks.",
+    ]
+
+    task = batch_v1.TaskSpec()
+    task.runnables = [runnable]
+
+    resources = batch_v1.ComputeResource()
+    resources.cpu_milli = 2000
+    resources.memory_mib = 16
+    task.compute_resource = resources
+    task.max_retry_count = 2
+
+    group = batch_v1.TaskGroup()
+    group.task_count = 2
+    group.task_spec = task
+    policy = batch_v1.AllocationPolicy.InstancePolicy()
+    policy.machine_type = "e2-standard-4"
+    instances = batch_v1.AllocationPolicy.InstancePolicyOrTemplate()
+    instances.policy = policy
+    allocation_policy = batch_v1.AllocationPolicy()
+    allocation_policy.instances = [instances]
+
+    job = batch_v1.Job()
+    job.task_groups = [group]
+    job.allocation_policy = allocation_policy
+    job.labels = {"env": "testing", "type": "container"}
+
+    job.logs_policy = batch_v1.LogsPolicy()
+    job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING
+
+    return job
+
+
+# [END howto_operator_batch_job_creation]
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "batch"],
+) as dag:
+
+    # [START howto_operator_batch_submit_job]
+    submit1 = CloudBatchSubmitJobOperator(
+        task_id=submit1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        job=_create_job(),
+        dag=dag,
+        deferrable=False,
+    )
+    # [END howto_operator_batch_submit_job]
+
+    # [START howto_operator_batch_submit_job_deferrable_mode]
+    submit2 = CloudBatchSubmitJobOperator(
+        task_id=submit2_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        job=batch_v1.Job.to_dict(_create_job()),
+        dag=dag,
+        deferrable=True,
+    )
+    # [END howto_operator_batch_submit_job_deferrable_mode]
+
+    # [START howto_operator_batch_list_tasks]
+    list_tasks = CloudBatchListTasksOperator(
+        task_id=list_tasks_task_name, project_id=PROJECT_ID, region=region, 
job_name=job1_name, dag=dag
+    )
+    # [END howto_operator_batch_list_tasks]
+
+    assert_tasks = PythonOperator(task_id="assert-tasks", 
python_callable=_assert_tasks, dag=dag)
+
+    # [START howto_operator_batch_list_jobs]
+    list_jobs = CloudBatchListJobsOperator(
+        task_id=list_jobs_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        limit=2,
+        
filter=f"name:projects/{PROJECT_ID}/locations/{region}/jobs/{job_name_prefix}*",
+        dag=dag,
+    )
+    # [END howto_operator_batch_list_jobs]
+
+    get_name = PythonOperator(task_id="assert-jobs", 
python_callable=_assert_jobs, dag=dag)
+
+    # [START howto_operator_delete_job]
+    delete_job1 = CloudBatchDeleteJobOperator(
+        task_id="delete-job1",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+    # [END howto_operator_delete_job]
+
+    delete_job2 = CloudBatchDeleteJobOperator(
+        task_id="delete-job2",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    ([submit1, submit2] >> list_tasks >> assert_tasks >> list_jobs >> get_name 
>> [delete_job1, delete_job2])
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when a "tearDown" task with a trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to