This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 86e613029b Implement CloudComposerDAGRunSensor (#40088)
86e613029b is described below

commit 86e613029b871b0a8327d64c040da56f537c0727
Author: Maksim <maks...@google.com>
AuthorDate: Fri Jun 7 06:51:46 2024 -0700

    Implement CloudComposerDAGRunSensor (#40088)
---
 .../google/cloud/sensors/cloud_composer.py         | 173 ++++++++++++++++++++-
 .../google/cloud/triggers/cloud_composer.py        | 115 ++++++++++++++
 .../operators/cloud/cloud_composer.rst             |  20 +++
 .../google/cloud/sensors/test_cloud_composer.py    |  63 +++++++-
 .../google/cloud/triggers/test_cloud_composer.py   |  61 +++++++-
 .../cloud/composer/example_cloud_composer.py       |  25 +++
 6 files changed, 447 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py b/airflow/providers/google/cloud/sensors/cloud_composer.py
index 22d16e8f33..0301466eac 100644
--- a/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -19,13 +19,24 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
+from dateutil import parser
 from deprecated import deprecated
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerDAGRunTrigger,
+    CloudComposerExecutionTrigger,
+)
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -117,3 +128,161 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
         if self.soft_fail:
             raise AirflowSkipException(message)
         raise AirflowException(message)
+
+
+class CloudComposerDAGRunSensor(BaseSensorOperator):
+    """
+    Check if a DAG run has completed.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The name of the Composer environment.
+    :param composer_dag_id: The ID of the DAG to check.
+    :param allowed_states: Iterable of allowed states, default is ``['success']``.
+    :param execution_range: The time range in which the sensor checks DAG run states; only runs
+        started within this range are considered. To look back one day, use a positive
+        ``datetime.timedelta(days=1)``; to look one day ahead, use a negative
+        ``datetime.timedelta(days=-1)``. For a specific window, pass a list of two datetimes,
+        e.g. ``[datetime(2024, 3, 22, 11, 0, 0), datetime(2024, 3, 22, 12, 0, 0)]``. A single-element
+        list such as ``[datetime(2024, 3, 22, 0, 0, 0)]`` makes the sensor check states from that
+        time up to the current execution time. Defaults to ``datetime.timedelta(days=1)``.
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poll_interval: Optional. Controls the rate of polling for the result of a deferrable run.
+    :param deferrable: Run sensor in deferrable mode.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        allowed_states: Iterable[str] | None = None,
+        execution_range: timedelta | list[datetime] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
+        self.execution_range = execution_range
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
+        if isinstance(self.execution_range, timedelta):
+            if self.execution_range < timedelta(0):
+                return context["logical_date"], context["logical_date"] - self.execution_range
+            else:
+                return context["logical_date"] - self.execution_range, context["logical_date"]
+        elif isinstance(self.execution_range, list) and len(self.execution_range) > 0:
+            if len(self.execution_range) > 1:
+                return self.execution_range[0], self.execution_range[1]
+            return self.execution_range[0], context["logical_date"]
+        else:
+            return context["logical_date"] - timedelta(1), context["logical_date"]
+
+    def poke(self, context: Context) -> bool:
+        start_date, end_date = self._get_execution_dates(context)
+
+        if datetime.now(end_date.tzinfo) < end_date:
+            return False
+
+        dag_runs = self._pull_dag_runs()
+
+        self.log.info("Sensor waits for allowed states: %s", 
self.allowed_states)
+        allowed_states_status = self._check_dag_runs_states(
+            dag_runs=dag_runs,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        return allowed_states_status
+
+    def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of dag runs."""
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        dag_runs_cmd = hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    def execute(self, context: Context) -> None:
+        if self.deferrable:
+            start_date, end_date = self._get_execution_dates(context)
+            self.defer(
+                trigger=CloudComposerDAGRunTrigger(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    composer_dag_id=self.composer_dag_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    allowed_states=self.allowed_states,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+            )
+        super().execute(context)
+
+    def execute_complete(self, context: Context, event: dict):
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info("DAG %s has executed successfully.", 
self.composer_dag_id)
diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py b/airflow/providers/google/cloud/triggers/cloud_composer.py
index ac5a00c60f..2334d038e6 100644
--- a/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -19,8 +19,13 @@
 from __future__ import annotations
 
 import asyncio
+import json
+from datetime import datetime
 from typing import Any, Sequence
 
+from dateutil import parser
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
+
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -146,3 +151,113 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
             }
         )
         return
+
+
+class CloudComposerDAGRunTrigger(BaseTrigger):
+    """The trigger wait for the DAG run completion."""
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        start_date: datetime,
+        end_date: datetime,
+        allowed_states: list[str],
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: int = 10,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.start_date = start_date
+        self.end_date = end_date
+        self.allowed_states = allowed_states
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+
+        self.gcp_hook = CloudComposerAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+            {
+                "project_id": self.project_id,
+                "region": self.region,
+                "environment_id": self.environment_id,
+                "composer_dag_id": self.composer_dag_id,
+                "start_date": self.start_date,
+                "end_date": self.end_date,
+                "allowed_states": self.allowed_states,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of dag runs."""
+        dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = await self.gcp_hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    async def run(self):
+        try:
+            while True:
+                if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
+                    dag_runs = await self._pull_dag_runs()
+
+                    self.log.info("Sensor waits for allowed states: %s", 
self.allowed_states)
+                    if self._check_dag_runs_states(
+                        dag_runs=dag_runs,
+                        start_date=self.start_date,
+                        end_date=self.end_date,
+                    ):
+                        yield TriggerEvent({"status": "success"})
+                        return
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
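
The window check in ``_check_dag_runs_states`` is identical in the sensor and the trigger: every
run that started inside the [start_date, end_date] window must be in an allowed state. Below is a
standalone sketch of the same logic, using a fabricated ``dags list-runs -o json`` payload purely
for illustration:

    import json
    from datetime import datetime, timezone

    from dateutil import parser

    # Fabricated sample in the shape parsed from cmd_result["output"][0]["content"].
    raw = json.dumps(
        [{"dag_id": "my_dag", "state": "success", "execution_date": "2024-03-22T11:30:00+00:00"}]
    )

    start_date = datetime(2024, 3, 22, 11, 0, 0, tzinfo=timezone.utc)
    end_date = datetime(2024, 3, 22, 12, 0, 0, tzinfo=timezone.utc)
    allowed_states = ["success"]

    # Equivalent to _check_dag_runs_states: runs inside the window must be in an allowed state.
    ok = all(
        dag_run["state"] in allowed_states
        for dag_run in json.loads(raw)
        if start_date.timestamp()
        < parser.parse(dag_run["execution_date"]).timestamp()
        < end_date.timestamp()
    )
    print(ok)  # True
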
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
index cdb9cb2931..f8f00fbe6c 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
@@ -177,3 +177,23 @@ or you can define the same operator in the deferrable mode:
     :dedent: 4
     :start-after: [START howto_operator_run_airflow_cli_command_deferrable_mode]
     :end-before: [END howto_operator_run_airflow_cli_command_deferrable_mode]
+
+Check if a DAG run has completed
+--------------------------------
+
+To check whether a DAG run has completed in your environment, use the sensor:
+:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_dag_run]
+    :end-before: [END howto_sensor_dag_run]
+
+or you can define the same sensor in the deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_dag_run_deferrable_mode]
+    :end-before: [END howto_sensor_dag_run_deferrable_mode]
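
For quick reference outside the rendered docs, a deferrable-mode instantiation would look roughly
like the sketch below (project, region, and environment values are placeholders; ``poll_interval``
controls how often the trigger re-checks):

    from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

    defer_dag_run_sensor = CloudComposerDAGRunSensor(
        task_id="defer_dag_run_sensor",
        project_id="my-project",          # placeholder
        region="us-central1",             # placeholder
        environment_id="my-environment",  # placeholder
        composer_dag_id="airflow_monitoring",
        allowed_states=["success"],
        deferrable=True,
        poll_interval=30,  # seconds between trigger checks
    )
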
diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py b/tests/providers/google/cloud/sensors/test_cloud_composer.py
index 5241ff551e..c22eb90fde 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_composer.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py
@@ -17,17 +17,42 @@
 
 from __future__ import annotations
 
+import json
+from datetime import datetime
 from unittest import mock
 
 import pytest
 
 from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
-from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerEnvironmentSensor
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+    CloudComposerDAGRunSensor,
+    CloudComposerEnvironmentSensor,
+)
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerExecutionTrigger,
+)
 
 TEST_PROJECT_ID = "test_project_id"
 TEST_OPERATION_NAME = "test_operation_name"
 TEST_REGION = "region"
+TEST_ENVIRONMENT_ID = "test_env_id"
+TEST_JSON_RESULT = lambda state: json.dumps(
+    [
+        {
+            "dag_id": "test_dag_id",
+            "run_id": "scheduled__2024-05-22T11:10:00+00:00",
+            "state": state,
+            "execution_date": "2024-05-22T11:10:00+00:00",
+            "start_date": "2024-05-22T11:20:01.531988+00:00",
+            "end_date": "2024-05-22T11:20:11.997479+00:00",
+        }
+    ]
+)
+TEST_EXEC_RESULT = lambda state: {
+    "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}],
+    "output_end": True,
+    "exit_info": {"exit_code": 0, "error": ""},
+}
 
 
 class TestCloudComposerEnvironmentSensor:
@@ -76,3 +101,37 @@ class TestCloudComposerEnvironmentSensor:
             task.execute_complete(
                 context={}, event={"operation_done": True, "operation_name": 
TEST_OPERATION_NAME}
             )
+
+
+class TestCloudComposerDAGRunSensor:
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_ready(self, mock_hook, to_dict_mock):
+        mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("success")
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            allowed_states=["success"],
+        )
+
+        assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_not_ready(self, mock_hook, to_dict_mock):
+        mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("running")
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            allowed_states=["success"],
+        )
+
+        assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
diff --git a/tests/providers/google/cloud/triggers/test_cloud_composer.py b/tests/providers/google/cloud/triggers/test_cloud_composer.py
index 99daaf83bd..00d109ed97 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_composer.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_composer.py
@@ -17,12 +17,16 @@
 
 from __future__ import annotations
 
+from datetime import datetime
 from unittest import mock
 
 import pytest
 
 from airflow.models import Connection
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerAirflowCLICommandTrigger
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerAirflowCLICommandTrigger,
+    CloudComposerDAGRunTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 TEST_PROJECT_ID = "test-project-id"
@@ -34,6 +38,10 @@ TEST_EXEC_CMD_INFO = {
     "pod_namespace": "test_namespace",
     "error": "test_error",
 }
+TEST_COMPOSER_DAG_ID = "test_dag_id"
+TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
+TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
+TEST_STATES = ["success"]
 TEST_GCP_CONN_ID = "test_gcp_conn_id"
 TEST_POLL_INTERVAL = 10
 TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
@@ -49,7 +57,7 @@ TEST_EXEC_RESULT = {
     "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
     return_value=Connection(conn_id="test_conn"),
 )
-def trigger(mock_conn):
+def cli_command_trigger(mock_conn):
     return CloudComposerAirflowCLICommandTrigger(
         project_id=TEST_PROJECT_ID,
         region=TEST_LOCATION,
@@ -61,9 +69,29 @@ def trigger(mock_conn):
     )
 
 
+@pytest.fixture
+@mock.patch(
+    "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+    return_value=Connection(conn_id="test_conn"),
+)
+def dag_run_trigger(mock_conn):
+    return CloudComposerDAGRunTrigger(
+        project_id=TEST_PROJECT_ID,
+        region=TEST_LOCATION,
+        environment_id=TEST_ENVIRONMENT_ID,
+        composer_dag_id=TEST_COMPOSER_DAG_ID,
+        start_date=TEST_START_DATE,
+        end_date=TEST_END_DATE,
+        allowed_states=TEST_STATES,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        poll_interval=TEST_POLL_INTERVAL,
+    )
+
+
 class TestCloudComposerAirflowCLICommandTrigger:
-    def test_serialize(self, trigger):
-        actual_data = trigger.serialize()
+    def test_serialize(self, cli_command_trigger):
+        actual_data = cli_command_trigger.serialize()
         expected_data = (
             "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger",
             {
@@ -82,7 +110,7 @@ class TestCloudComposerAirflowCLICommandTrigger:
     @mock.patch(
         "airflow.providers.google.cloud.hooks.cloud_composer.CloudComposerAsyncHook.wait_command_execution_result"
     )
-    async def test_run(self, mock_exec_result, trigger):
+    async def test_run(self, mock_exec_result, cli_command_trigger):
         mock_exec_result.return_value = TEST_EXEC_RESULT
 
         expected_event = TriggerEvent(
@@ -91,6 +119,27 @@ class TestCloudComposerAirflowCLICommandTrigger:
                 "result": TEST_EXEC_RESULT,
             }
         )
-        actual_event = await trigger.run().asend(None)
+        actual_event = await cli_command_trigger.run().asend(None)
 
         assert actual_event == expected_event
+
+
+class TestCloudComposerDAGRunTrigger:
+    def test_serialize(self, dag_run_trigger):
+        actual_data = dag_run_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+            {
+                "project_id": TEST_PROJECT_ID,
+                "region": TEST_LOCATION,
+                "environment_id": TEST_ENVIRONMENT_ID,
+                "composer_dag_id": TEST_COMPOSER_DAG_ID,
+                "start_date": TEST_START_DATE,
+                "end_date": TEST_END_DATE,
+                "allowed_states": TEST_STATES,
+                "gcp_conn_id": TEST_GCP_CONN_ID,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+                "poll_interval": TEST_POLL_INTERVAL,
+            },
+        )
+        assert actual_data == expected_data
diff --git a/tests/system/providers/google/cloud/composer/example_cloud_composer.py b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
index fe60c56ddf..52404fa375 100644
--- a/tests/system/providers/google/cloud/composer/example_cloud_composer.py
+++ b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
@@ -31,6 +31,7 @@ from airflow.providers.google.cloud.operators.cloud_composer import (
     CloudComposerRunAirflowCLICommandOperator,
     CloudComposerUpdateEnvironmentOperator,
 )
+from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
@@ -158,6 +159,29 @@ with DAG(
     )
     # [END howto_operator_run_airflow_cli_command_deferrable_mode]
 
+    # [START howto_sensor_dag_run]
+    dag_run_sensor = CloudComposerDAGRunSensor(
+        task_id="dag_run_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID,
+        composer_dag_id="airflow_monitoring",
+        allowed_states=["success"],
+    )
+    # [END howto_sensor_dag_run]
+
+    # [START howto_sensor_dag_run_deferrable_mode]
+    defer_dag_run_sensor = CloudComposerDAGRunSensor(
+        task_id="defer_dag_run_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID_ASYNC,
+        composer_dag_id="airflow_monitoring",
+        allowed_states=["success"],
+        deferrable=True,
+    )
+    # [END howto_sensor_dag_run_deferrable_mode]
+
     # [START howto_operator_delete_composer_environment]
     delete_env = CloudComposerDeleteEnvironmentOperator(
         task_id="delete_env",
@@ -186,6 +210,7 @@ with DAG(
         get_env,
         [update_env, defer_update_env],
         [run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
+        [dag_run_sensor, defer_dag_run_sensor],
         [delete_env, defer_delete_env],
     )
 
