This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 10ad8d9e38 Add operator to diagnose cluster (#36899)
10ad8d9e38 is described below

commit 10ad8d9e38351427acfa30c58a7702f0f4d66f05
Author: Flavia Nshemerirwe <flac...@users.noreply.github.com>
AuthorDate: Thu Jan 25 17:08:57 2024 +0000

    Add operator to diagnose cluster (#36899)
    
    1. Make the diagnose_cluster hook return an Operation object, just like the rest of the hooks.
    2. Rename DataprocWorkflowTrigger to DataprocOperationTrigger so that it handles all types of
       operations for get_operation.
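    
    A minimal usage sketch of the new operator (the project, region, and cluster name below are
    placeholders; the bundled example DAG in this commit shows the full system test):
    
        from airflow.providers.google.cloud.operators.dataproc import (
            DataprocDiagnoseClusterOperator,
        )
    
        # Collects a diagnostic tarball for an existing cluster; on completion the task
        # logs the Cloud Storage URI of the diagnostic output report.
        diagnose_cluster = DataprocDiagnoseClusterOperator(
            task_id="diagnose_cluster",
            region="europe-west1",               # placeholder
            project_id="my-project",             # placeholder
            cluster_name="my-dataproc-cluster",  # placeholder
        )
    
        # The same operator can defer and poll via the new DataprocOperationTrigger.
        diagnose_cluster_deferrable = DataprocDiagnoseClusterOperator(
            task_id="diagnose_cluster_deferrable",
            region="europe-west1",
            project_id="my-project",
            cluster_name="my-dataproc-cluster",
            deferrable=True,
        )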
---
 airflow/providers/google/cloud/hooks/dataproc.py   |  58 ++++++--
 .../providers/google/cloud/operators/dataproc.py   | 150 ++++++++++++++++++++-
 .../providers/google/cloud/triggers/dataproc.py    |  44 ++++--
 airflow/providers/google/cloud/utils/dataproc.py   |  25 ++++
 .../operators/cloud/dataproc.rst                   |  25 ++++
 docs/apache-airflow/img/airflow_erd.svg            |   2 +-
 docs/spelling_wordlist.txt                         |   1 +
 .../providers/google/cloud/hooks/test_dataproc.py  |  16 ++-
 .../google/cloud/operators/test_dataproc.py        |  80 ++++++++++-
 .../google/cloud/triggers/test_dataproc.py         |  63 +++++++--
 .../providers/google/cloud/utils/test_dataproc.py  |  32 +++++
 .../dataproc/example_dataproc_cluster_diagnose.py  | 118 ++++++++++++++++
 12 files changed, 565 insertions(+), 49 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataproc.py 
b/airflow/providers/google/cloud/hooks/dataproc.py
index 10459118c0..dae5535e40 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 
 import time
 import uuid
+from collections.abc import MutableSequence
 from typing import TYPE_CHECKING, Any, Sequence
 
 from google.api_core.client_options import ClientOptions
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
     from google.api_core.retry_async import AsyncRetry
     from google.protobuf.duration_pb2 import Duration
     from google.protobuf.field_mask_pb2 import FieldMask
+    from google.type.interval_pb2 import Interval
 
 
 class DataProcJobBuilder:
@@ -386,17 +388,25 @@ class DataprocHook(GoogleBaseHook):
         region: str,
         cluster_name: str,
         project_id: str,
+        tarball_gcs_dir: str | None = None,
+        diagnosis_interval: dict | Interval | None = None,
+        jobs: MutableSequence[str] | None = None,
+        yarn_application_ids: MutableSequence[str] | None = None,
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> str:
+    ) -> Operation:
         """Get cluster diagnostic information.
 
-        After the operation completes, the GCS URI to diagnose is returned.
+        After the operation completes, the response contains the Cloud Storage 
URI of the diagnostic output report containing a summary of collected 
diagnostics.
 
         :param project_id: Google Cloud project ID that the cluster belongs to.
         :param region: Cloud Dataproc region in which to handle the request.
         :param cluster_name: Name of the cluster.
+        :param tarball_gcs_dir:  The output Cloud Storage directory for the 
diagnostic tarball. If not specified, a task-specific directory in the 
cluster's staging bucket will be used.
+        :param diagnosis_interval: Time interval in which diagnosis should be 
carried out on the cluster.
+        :param jobs: Specifies a list of jobs on which diagnosis is to be 
performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
+        :param yarn_application_ids: Specifies a list of yarn applications on 
which diagnosis is to be performed.
         :param retry: A retry object used to retry requests. If *None*, 
requests
             will not be retried.
         :param timeout: The amount of time, in seconds, to wait for the request
@@ -405,15 +415,21 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         """
         client = self.get_cluster_client(region=region)
-        operation = client.diagnose_cluster(
-            request={"project_id": project_id, "region": region, 
"cluster_name": cluster_name},
+        result = client.diagnose_cluster(
+            request={
+                "project_id": project_id,
+                "region": region,
+                "cluster_name": cluster_name,
+                "tarball_gcs_dir": tarball_gcs_dir,
+                "diagnosis_interval": diagnosis_interval,
+                "jobs": jobs,
+                "yarn_application_ids": yarn_application_ids,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
         )
-        operation.result()
-        gcs_uri = str(operation.operation.response.value)
-        return gcs_uri
+        return result
 
     @GoogleBaseHook.fallback_to_default_project_id
     def get_cluster(
@@ -1243,17 +1259,25 @@ class DataprocAsyncHook(GoogleBaseHook):
         region: str,
         cluster_name: str,
         project_id: str,
+        tarball_gcs_dir: str | None = None,
+        diagnosis_interval: dict | Interval | None = None,
+        jobs: MutableSequence[str] | None = None,
+        yarn_application_ids: MutableSequence[str] | None = None,
         retry: AsyncRetry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> str:
+    ) -> AsyncOperation:
         """Get cluster diagnostic information.
 
-        After the operation completes, the GCS URI to diagnose is returned.
+        After the operation completes, the response contains the Cloud Storage 
URI of the diagnostic output report containing a summary of collected 
diagnostics.
 
         :param project_id: Google Cloud project ID that the cluster belongs to.
         :param region: Cloud Dataproc region in which to handle the request.
         :param cluster_name: Name of the cluster.
+        :param tarball_gcs_dir:  The output Cloud Storage directory for the 
diagnostic tarball. If not specified, a task-specific directory in the 
cluster's staging bucket will be used.
+        :param diagnosis_interval: Time interval in which diagnosis should be 
carried out on the cluster.
+        :param jobs: Specifies a list of jobs on which diagnosis is to be 
performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
+        :param yarn_application_ids: Specifies a list of yarn applications on 
which diagnosis is to be performed.
         :param retry: A retry object used to retry requests. If *None*, 
requests
             will not be retried.
         :param timeout: The amount of time, in seconds, to wait for the request
@@ -1262,15 +1286,21 @@ class DataprocAsyncHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         """
         client = self.get_cluster_client(region=region)
-        operation = await client.diagnose_cluster(
-            request={"project_id": project_id, "region": region, 
"cluster_name": cluster_name},
+        result = await client.diagnose_cluster(
+            request={
+                "project_id": project_id,
+                "region": region,
+                "cluster_name": cluster_name,
+                "tarball_gcs_dir": tarball_gcs_dir,
+                "diagnosis_interval": diagnosis_interval,
+                "jobs": jobs,
+                "yarn_application_ids": yarn_application_ids,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
         )
-        operation.result()
-        gcs_uri = str(operation.operation.response.value)
-        return gcs_uri
+        return result
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def get_cluster(
diff --git a/airflow/providers/google/cloud/operators/dataproc.py 
b/airflow/providers/google/cloud/operators/dataproc.py
index b14121139d..7f3fcd5d01 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -25,6 +25,7 @@ import re
 import time
 import uuid
 import warnings
+from collections.abc import MutableSequence
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from enum import Enum
@@ -56,9 +57,10 @@ from airflow.providers.google.cloud.triggers.dataproc import 
(
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocDeleteClusterTrigger,
+    DataprocOperationTrigger,
     DataprocSubmitTrigger,
-    DataprocWorkflowTrigger,
 )
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.utils import timezone
 
 if TYPE_CHECKING:
@@ -66,6 +68,7 @@ if TYPE_CHECKING:
     from google.api_core.retry_async import AsyncRetry
     from google.protobuf.duration_pb2 import Duration
     from google.protobuf.field_mask_pb2 import FieldMask
+    from google.type.interval_pb2 import Interval
 
     from airflow.utils.context import Context
 
@@ -681,10 +684,13 @@ class 
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
             return
         self.log.info("Cluster is in ERROR state")
         self.log.info("Gathering diagnostic information.")
-        gcs_uri = hook.diagnose_cluster(
+        operation = hook.diagnose_cluster(
             region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
         )
+        operation.result()
+        gcs_uri = str(operation.operation.response.value)
         self.log.info("Diagnostic information for cluster %s available at: 
%s", self.cluster_name, gcs_uri)
+
         if self.delete_on_error:
             self._delete_cluster(hook)
             # The delete op is asynchronous and can cause further failure if 
the cluster finishes
@@ -2054,7 +2060,7 @@ class 
DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
             self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
-                trigger=DataprocWorkflowTrigger(
+                trigger=DataprocOperationTrigger(
                     name=operation_name,
                     project_id=self.project_id,
                     region=self.region,
@@ -2196,7 +2202,7 @@ class 
DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
             self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
-                trigger=DataprocWorkflowTrigger(
+                trigger=DataprocOperationTrigger(
                     name=operation_name,
                     project_id=self.project_id or hook.project_id,
                     region=self.region,
@@ -2530,6 +2536,142 @@ class 
DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
         self.log.info("%s completed successfully.", self.task_id)
 
 
+class DataprocDiagnoseClusterOperator(GoogleCloudBaseOperator):
+    """Diagnose a cluster in a project.
+
+    After the operation completes, the response contains the Cloud Storage URI 
of the diagnostic output report containing a summary of collected diagnostics.
+
+    :param region: Required. The Cloud Dataproc region in which to handle the 
request (templated).
+    :param project_id: Optional. The ID of the Google Cloud project that the 
cluster belongs to (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param tarball_gcs_dir:  The output Cloud Storage directory for the 
diagnostic tarball. If not specified, a task-specific directory in the 
cluster's staging bucket will be used.
+    :param diagnosis_interval: Time interval in which diagnosis should be 
carried out on the cluster.
+    :param jobs: Specifies a list of jobs on which diagnosis is to be 
performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
+    :param yarn_application_ids: Specifies a list of yarn applications on 
which diagnosis is to be performed.
+    :param metadata: Additional metadata that is provided to the method.
+    :param retry: A retry object used to retry requests. If ``None`` is 
specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to 
complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to 
check the cluster status.
+    """
+
+    template_fields: Sequence[str] = (
+        "project_id",
+        "region",
+        "cluster_name",
+        "impersonation_chain",
+        "tarball_gcs_dir",
+        "diagnosis_interval",
+        "jobs",
+        "yarn_application_ids",
+    )
+
+    def __init__(
+        self,
+        *,
+        region: str,
+        cluster_name: str,
+        project_id: str | None = None,
+        tarball_gcs_dir: str | None = None,
+        diagnosis_interval: dict | Interval | None = None,
+        jobs: MutableSequence[str] | None = None,
+        yarn_application_ids: MutableSequence[str] | None = None,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float = 1 * 60 * 60,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        polling_interval_seconds: int = 10,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. 
Expected value greater than 0")
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.tarball_gcs_dir = tarball_gcs_dir
+        self.diagnosis_interval = diagnosis_interval
+        self.jobs = jobs
+        self.yarn_application_ids = yarn_application_ids
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
+
+    def execute(self, context: Context):
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
+        self.log.info("Collecting diagnostic tarball for cluster: %s", 
self.cluster_name)
+        operation = hook.diagnose_cluster(
+            region=self.region,
+            cluster_name=self.cluster_name,
+            project_id=self.project_id,
+            tarball_gcs_dir=self.tarball_gcs_dir,
+            diagnosis_interval=self.diagnosis_interval,
+            jobs=self.jobs,
+            yarn_application_ids=self.yarn_application_ids,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+        if not self.deferrable:
+            result = hook.wait_for_operation(
+                timeout=self.timeout, result_retry=self.retry, 
operation=operation
+            )
+            self.log.info(
+                "The diagnostic output for cluster %s is available at: %s",
+                self.cluster_name,
+                result.output_uri,
+            )
+        else:
+            self.defer(
+                trigger=DataprocOperationTrigger(
+                    name=operation.operation.name,
+                    operation_type=DataprocOperationType.DIAGNOSE.value,
+                    project_id=self.project_id,
+                    region=self.region,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    polling_interval_seconds=self.polling_interval_seconds,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
+        """Callback for when the trigger fires.
+
+        This returns immediately. It relies on the trigger to throw an exception;
+        otherwise it assumes execution was successful.
+        """
+        if event:
+            status = event.get("status")
+            if status in ("failed", "error"):
+                self.log.exception("Unexpected error in the operation.")
+                raise AirflowException(event.get("message"))
+
+            self.log.info(
+                "The diagnostic output for cluster %s is available at: %s",
+                self.cluster_name,
+                event.get("output_uri"),
+            )
+
+
 class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
     """Create a batch workload.
 
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py 
b/airflow/providers/google/cloud/triggers/dataproc.py
index e03f7a14ca..7a612cdc29 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import asyncio
+import re
 import time
 from typing import Any, AsyncIterator, Sequence
 
@@ -26,6 +27,7 @@ from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
 
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -281,22 +283,24 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
             yield TriggerEvent({"status": "error", "message": "Timeout"})
 
 
-class DataprocWorkflowTrigger(DataprocBaseTrigger):
+class DataprocOperationTrigger(DataprocBaseTrigger):
     """
-    Trigger that periodically polls information from Dataproc API to verify 
status.
+    Trigger that periodically polls information on a long-running operation from the Dataproc API to verify status.
 
     Implementation leverages asynchronous transport.
     """
 
-    def __init__(self, name: str, **kwargs: Any):
+    def __init__(self, name: str, operation_type: str | None = None, **kwargs: 
Any):
         super().__init__(**kwargs)
         self.name = name
+        self.operation_type = operation_type
 
     def serialize(self):
         return (
-            
"airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
+            
"airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger",
             {
                 "name": self.name,
+                "operation_type": self.operation_type,
                 "project_id": self.project_id,
                 "region": self.region,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -317,14 +321,30 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
                     else:
                         status = "success"
                         message = "Operation is successfully ended."
-                    yield TriggerEvent(
-                        {
-                            "operation_name": operation.name,
-                            "operation_done": operation.done,
-                            "status": status,
-                            "message": message,
-                        }
-                    )
+                    if self.operation_type == 
DataprocOperationType.DIAGNOSE.value:
+                        gcs_regex = 
rb"gs:\/\/[a-z0-9][a-z0-9_-]{1,61}[a-z0-9_\-\/]*"
+                        gcs_uri_value = operation.response.value
+                        match = re.search(gcs_regex, gcs_uri_value)
+                        if match:
+                            output_uri = match.group(0).decode("utf-8", 
"ignore")
+                        else:
+                            output_uri = gcs_uri_value
+                        yield TriggerEvent(
+                            {
+                                "status": status,
+                                "message": message,
+                                "output_uri": output_uri,
+                            }
+                        )
+                    else:
+                        yield TriggerEvent(
+                            {
+                                "operation_name": operation.name,
+                                "operation_done": operation.done,
+                                "status": status,
+                                "message": message,
+                            }
+                        )
                     return
                 else:
                     self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
diff --git a/airflow/providers/google/cloud/utils/dataproc.py 
b/airflow/providers/google/cloud/utils/dataproc.py
new file mode 100644
index 0000000000..d71792df3f
--- /dev/null
+++ b/airflow/providers/google/cloud/utils/dataproc.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+
+
+class DataprocOperationType(Enum):
+    """Contains types of long running operations."""
+
+    DIAGNOSE = "DIAGNOSE"
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst 
b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index db4ef47d76..67c2831a1a 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -145,6 +145,31 @@ You can generate and use config as followed:
     :start-after: [START 
how_to_cloud_dataproc_create_cluster_generate_cluster_config]
     :end-before: [END 
how_to_cloud_dataproc_create_cluster_generate_cluster_config]
 
+Diagnose a cluster
+------------------
+Dataproc supports the collection of `cluster diagnostic information <https://cloud.google.com/dataproc/docs/support/diagnose-cluster-command#diagnostic_summary_and_archive_contents>`_,
+such as system, Spark, Hadoop, and Dataproc logs and cluster configuration files, which can be used to troubleshoot a Dataproc cluster or job.
+Note that this information can only be collected before the cluster is deleted.
+For more information about the available fields to pass when diagnosing a cluster, see the
+`Dataproc diagnose cluster API <https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.clusters/diagnose>`_ reference.
+
+To diagnose a Dataproc cluster, use
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocDiagnoseClusterOperator`.
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_diagnose.py
+    :language: python
+    :dedent: 0
+    :start-after: [START how_to_cloud_dataproc_diagnose_cluster]
+    :end-before: [END how_to_cloud_dataproc_diagnose_cluster]
+
+You can also use deferrable mode in order to run the operator asynchronously:
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_diagnose.py
+    :language: python
+    :dedent: 0
+    :start-after: [START how_to_cloud_dataproc_diagnose_cluster_deferrable]
+    :end-before: [END how_to_cloud_dataproc_diagnose_cluster_deferrable]
+
 Update a cluster
 ----------------
 You can scale the cluster up or down by providing a cluster config and an updateMask.
diff --git a/docs/apache-airflow/img/airflow_erd.svg 
b/docs/apache-airflow/img/airflow_erd.svg
index 497ef76975..2c3a7a33cc 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1363,7 +1363,7 @@
 <g id="edge44" class="edge">
 <title>task_instance&#45;&#45;xcom</title>
 <path fill="none" stroke="#7f7f7f" stroke-dasharray="5,2" 
d="M1166.1,-816.29C1196.72,-811.66 1228.55,-806.13 1258.36,-800.24"/>
-<text text-anchor="start" x="1248.36" y="-804.04" font-family="Times,serif" 
font-size="14.00">1</text>
+<text text-anchor="start" x="1227.36" y="-804.04" font-family="Times,serif" 
font-size="14.00">0..N</text>
 <text text-anchor="start" x="1166.1" y="-820.09" font-family="Times,serif" 
font-size="14.00">1</text>
 </g>
 <!-- rendered_task_instance_fields -->
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e5129c16fd..1dc8d5693b 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -387,6 +387,7 @@ datapoint
 Dataprep
 Dataproc
 dataproc
+DataprocDiagnoseClusterOperator
 DataScan
 dataScans
 Dataset
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py 
b/tests/providers/google/cloud/hooks/test_dataproc.py
index c6eb6dcab3..1a82fc8a1c 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -197,19 +197,26 @@ class TestDataprocHook:
 
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
     def test_diagnose_cluster(self, mock_client):
-        self.hook.diagnose_cluster(project_id=GCP_PROJECT, 
region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
+        self.hook.diagnose_cluster(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            cluster_name=CLUSTER_NAME,
+        )
         mock_client.assert_called_once_with(region=GCP_LOCATION)
         mock_client.return_value.diagnose_cluster.assert_called_once_with(
             request=dict(
                 project_id=GCP_PROJECT,
                 region=GCP_LOCATION,
                 cluster_name=CLUSTER_NAME,
+                tarball_gcs_dir=None,
+                jobs=None,
+                yarn_application_ids=None,
+                diagnosis_interval=None,
             ),
             metadata=(),
             retry=DEFAULT,
             timeout=None,
         )
-        
mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()
 
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
     def test_get_cluster(self, mock_client):
@@ -646,12 +653,15 @@ class TestDataprocAsyncHook:
                 project_id=GCP_PROJECT,
                 region=GCP_LOCATION,
                 cluster_name=CLUSTER_NAME,
+                tarball_gcs_dir=None,
+                jobs=None,
+                yarn_application_ids=None,
+                diagnosis_interval=None,
             ),
             metadata=(),
             retry=DEFAULT,
             timeout=None,
         )
-        
mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()
 
     @pytest.mark.asyncio
     @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py 
b/tests/providers/google/cloud/operators/test_dataproc.py
index 00f45ca8b3..d0b04a6fa9 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -47,6 +47,7 @@ from airflow.providers.google.cloud.operators.dataproc import 
(
     DataprocCreateWorkflowTemplateOperator,
     DataprocDeleteBatchOperator,
     DataprocDeleteClusterOperator,
+    DataprocDiagnoseClusterOperator,
     DataprocGetBatchOperator,
     DataprocInstantiateInlineWorkflowTemplateOperator,
     DataprocInstantiateWorkflowTemplateOperator,
@@ -68,8 +69,8 @@ from airflow.providers.google.cloud.triggers.dataproc import (
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocDeleteClusterTrigger,
+    DataprocOperationTrigger,
     DataprocSubmitTrigger,
-    DataprocWorkflowTrigger,
 )
 from airflow.providers.google.common.consts import 
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.serialization.serialized_objects import SerializedDAG
@@ -792,6 +793,7 @@ class 
TestDataprocCreateClusterOperator(DataprocClusterTestBase):
         mock_hook.return_value.diagnose_cluster.assert_called_once_with(
             region=GCP_REGION, project_id=GCP_PROJECT, 
cluster_name=CLUSTER_NAME
         )
+        
mock_hook.return_value.diagnose_cluster.return_value.result.assert_called_once_with()
         mock_hook.return_value.delete_cluster.assert_called_once_with(
             region=GCP_REGION, project_id=GCP_PROJECT, 
cluster_name=CLUSTER_NAME
         )
@@ -1742,7 +1744,7 @@ class TestDataprocInstantiateWorkflowTemplateOperator:
         
mock_hook.return_value.instantiate_workflow_template.assert_called_once()
 
         mock_hook.return_value.wait_for_operation.assert_not_called()
-        assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
+        assert isinstance(exc.value.trigger, DataprocOperationTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1872,7 +1874,7 @@ class 
TestDataprocWorkflowTemplateInstantiateInlineOperator:
 
         
mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once()
 
-        assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
+        assert isinstance(exc.value.trigger, DataprocOperationTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -2685,3 +2687,75 @@ class TestDataprocListBatchesOperator:
 
         assert isinstance(exc.value.trigger, DataprocBatchTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
+
+class TestDataprocDiagnoseClusterOperator:
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute(self, mock_hook):
+        op = DataprocDiagnoseClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={})
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.diagnose_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            tarball_gcs_dir=None,
+            diagnosis_interval=None,
+            jobs=None,
+            yarn_application_ids=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_defer_method(self, mock_trigger_hook, 
mock_hook):
+        mock_hook.return_value.create_cluster.return_value = None
+        operator = DataprocDiagnoseClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.diagnose_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            tarball_gcs_dir=None,
+            diagnosis_interval=None,
+            jobs=None,
+            yarn_application_ids=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert isinstance(exc.value.trigger, DataprocOperationTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py 
b/tests/providers/google/cloud/triggers/test_dataproc.py
index 6f5e2782ad..8bf0d8d006 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -23,13 +23,15 @@ from unittest import mock
 
 import pytest
 from google.cloud.dataproc_v1 import Batch, ClusterStatus
+from google.protobuf.any_pb2 import Any
 from google.rpc.status_pb2 import Status
 
 from airflow.providers.google.cloud.triggers.dataproc import (
     DataprocBatchTrigger,
     DataprocClusterTrigger,
-    DataprocWorkflowTrigger,
+    DataprocOperationTrigger,
 )
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import TriggerEvent
 
 TEST_PROJECT_ID = "project-id"
@@ -73,8 +75,8 @@ def batch_trigger():
 
 
 @pytest.fixture
-def workflow_trigger():
-    return DataprocWorkflowTrigger(
+def operation_trigger():
+    return DataprocOperationTrigger(
         name=TEST_OPERATION_NAME,
         project_id=TEST_PROJECT_ID,
         region=TEST_REGION,
@@ -84,6 +86,19 @@ def workflow_trigger():
     )
 
 
+@pytest.fixture
+def diagnose_operation_trigger():
+    return DataprocOperationTrigger(
+        name=TEST_OPERATION_NAME,
+        operation_type=DataprocOperationType.DIAGNOSE.value,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=None,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+
+
 @pytest.fixture()
 def async_get_cluster():
     def func(**kwargs):
@@ -286,13 +301,14 @@ class TestDataprocBatchTrigger:
         assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
 
 
-class TestDataprocWorkflowTrigger:
-    def 
test_async_cluster_trigger_serialization_should_execute_successfully(self, 
workflow_trigger):
-        classpath, kwargs = workflow_trigger.serialize()
-        assert classpath == 
"airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger"
+class TestDataprocOperationTrigger:
+    def 
test_async_cluster_trigger_serialization_should_execute_successfully(self, 
operation_trigger):
+        classpath, kwargs = operation_trigger.serialize()
+        assert classpath == 
"airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger"
         assert kwargs == {
             "name": TEST_OPERATION_NAME,
             "project_id": TEST_PROJECT_ID,
+            "operation_type": None,
             "region": TEST_REGION,
             "gcp_conn_id": TEST_GCP_CONN_ID,
             "impersonation_chain": None,
@@ -301,8 +317,8 @@ class TestDataprocWorkflowTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
-    async def 
test_async_workflow_triggers_on_success_should_execute_successfully(
-        self, mock_hook, workflow_trigger, async_get_operation
+    async def 
test_async_operation_triggers_on_success_should_execute_successfully(
+        self, mock_hook, operation_trigger, async_get_operation
     ):
         mock_hook.return_value.get_operation.return_value = 
async_get_operation(
             name=TEST_OPERATION_NAME, done=True, response={}, 
error=Status(message="")
@@ -316,12 +332,35 @@ class TestDataprocWorkflowTrigger:
                 "message": "Operation is successfully ended.",
             }
         )
-        actual_event = await workflow_trigger.run().asend(None)
+        actual_event = await operation_trigger.run().asend(None)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
+    async def 
test_async_diagnose_operation_triggers_on_success_should_execute_successfully(
+        self, mock_hook, diagnose_operation_trigger, async_get_operation
+    ):
+        gcs_uri = "gs://test-tarball-gcs-dir-bucket"
+        mock_hook.return_value.get_operation.return_value = 
async_get_operation(
+            name=TEST_OPERATION_NAME,
+            done=True,
+            response=Any(value=gcs_uri.encode("utf-8")),
+            error=Status(message=""),
+        )
+
+        expected_event = TriggerEvent(
+            {
+                "output_uri": gcs_uri,
+                "status": "success",
+                "message": "Operation is successfully ended.",
+            }
+        )
+        actual_event = await diagnose_operation_trigger.run().asend(None)
         assert expected_event == actual_event
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
-    async def test_async_workflow_triggers_on_error(self, mock_hook, 
workflow_trigger, async_get_operation):
+    async def test_async_operation_triggers_on_error(self, mock_hook, 
operation_trigger, async_get_operation):
         mock_hook.return_value.get_operation.return_value = 
async_get_operation(
             name=TEST_OPERATION_NAME, done=True, response={}, 
error=Status(message="test_error")
         )
@@ -334,5 +373,5 @@ class TestDataprocWorkflowTrigger:
                 "message": "test_error",
             }
         )
-        actual_event = await workflow_trigger.run().asend(None)
+        actual_event = await operation_trigger.run().asend(None)
         assert expected_event == actual_event
diff --git a/tests/providers/google/cloud/utils/test_dataproc.py 
b/tests/providers/google/cloud/utils/test_dataproc.py
new file mode 100644
index 0000000000..c6bb5080f5
--- /dev/null
+++ b/tests/providers/google/cloud/utils/test_dataproc.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
+
+
+class TestDataprocOperationType:
+    @pytest.mark.parametrize(
+        "str_value, expected_item",
+        [
+            ("DIAGNOSE", DataprocOperationType.DIAGNOSE.value),
+        ],
+    )
+    def test_diagnose_operation(self, str_value, expected_item):
+        assert str_value == expected_item
diff --git 
a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_diagnose.py
 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_diagnose.py
new file mode 100644
index 0000000000..5ed41d77da
--- /dev/null
+++ 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_diagnose.py
@@ -0,0 +1,118 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocDiagnoseClusterOperator.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocDiagnoseClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_diagnose_cluster"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+
+CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-")
+REGION = "europe-west1"
+
+
+# Cluster definition
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+}
+
+
+with DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "dataproc", "diagnose", "cluster"],
+) as dag:
+    create_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+    )
+
+    # [START how_to_cloud_dataproc_diagnose_cluster]
+    diagnose_cluster = DataprocDiagnoseClusterOperator(
+        task_id="diagnose_cluster",
+        region=REGION,
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+    )
+    # [END how_to_cloud_dataproc_diagnose_cluster]
+
+    # [START how_to_cloud_dataproc_diagnose_cluster_deferrable]
+    diagnose_cluster_deferrable = DataprocDiagnoseClusterOperator(
+        task_id="diagnose_cluster_deferrable",
+        region=REGION,
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        deferrable=True,
+    )
+    # [END how_to_cloud_dataproc_diagnose_cluster_deferrable]
+
+    delete_cluster = DataprocDeleteClusterOperator(
+        task_id="delete_cluster",
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        region=REGION,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        create_cluster
+        # TEST BODY
+        >> diagnose_cluster
+        # TEST TEARDOWN
+        >> delete_cluster
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
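
Since DataprocHook.diagnose_cluster now returns the long-running Operation instead of the diagnostic
tarball's Cloud Storage URI, callers that consumed the old return value must resolve the operation
themselves. A minimal sketch of the synchronous pattern this commit uses inside
DataprocCreateClusterOperator (the connection ID and cluster details are placeholders):

    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    hook = DataprocHook(gcp_conn_id="google_cloud_default")

    # diagnose_cluster now returns a google.api_core Operation object
    operation = hook.diagnose_cluster(
        region="europe-west1",               # placeholder
        cluster_name="my-dataproc-cluster",  # placeholder
        project_id="my-project",             # placeholder
    )
    operation.result()  # block until the diagnose operation completes
    gcs_uri = str(operation.operation.response.value)
    print(f"Diagnostic information available at: {gcs_uri}")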

