This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 29c6504f24f Add direct-to-triggerer support for 
DataprocSubmitJobOperator #50563 (#62331)
29c6504f24f is described below

commit 29c6504f24f8502d1dec24ef2d38230e0608cddf
Author: Haseeb Malik <[email protected]>
AuthorDate: Wed Mar 11 09:24:19 2026 -0400

    Add direct-to-triggerer support for DataprocSubmitJobOperator #50563 
(#62331)
---
 .../providers/google/cloud/operators/dataproc.py   |  32 +++++
 .../providers/google/cloud/triggers/dataproc.py    | 149 ++++++++++++++++++++
 .../example_dataproc_start_from_trigger.py         | 132 ++++++++++++++++++
 .../unit/google/cloud/operators/test_dataproc.py   |  55 ++++++++
 .../unit/google/cloud/triggers/test_dataproc.py    | 154 +++++++++++++++++++++
 5 files changed, 522 insertions(+)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py 
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 6dc6e0ab775..38ac0ff8fc4 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -63,6 +63,7 @@ from airflow.providers.google.cloud.triggers.dataproc import (
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import 
PROVIDE_PROJECT_ID
+from airflow.triggers.base import StartTriggerArgs
 
 if TYPE_CHECKING:
     from google.api_core import operation
@@ -1880,6 +1881,8 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
         The value is considered only when running in deferrable mode. Must be 
greater than 0.
     :param cancel_on_kill: Flag which indicates whether cancel the hook's job 
or not, when on_kill is called
     :param wait_timeout: How many seconds wait for job to be ready. Used only 
if ``asynchronous`` is False
+    :param start_from_trigger: If True and deferrable is True, the operator 
will start directly
+        from the triggerer without occupying a worker slot.
     """
 
     template_fields: Sequence[str] = (
@@ -1894,6 +1897,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
 
     operator_extra_links = (DataprocJobLink(),)
 
+    start_trigger_args = StartTriggerArgs(
+        
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+        trigger_kwargs={},
+        next_method="execute_complete",
+        next_kwargs=None,
+        timeout=None,
+    )
+    start_from_trigger = False
+
     def __init__(
         self,
         *,
@@ -1911,6 +1923,7 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
         polling_interval_seconds: int = 10,
         cancel_on_kill: bool = True,
         wait_timeout: int | None = None,
+        start_from_trigger: bool = False,
         openlineage_inject_parent_job_info: bool = conf.getboolean(
             "openlineage", "spark_inject_parent_job_info", fallback=False
         ),
@@ -1938,9 +1951,28 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
         self.hook: DataprocHook | None = None
         self.job_id: str | None = None
         self.wait_timeout = wait_timeout
+        self.start_from_trigger = start_from_trigger
         self.openlineage_inject_parent_job_info = 
openlineage_inject_parent_job_info
         self.openlineage_inject_transport_info = 
openlineage_inject_transport_info
 
+        if self.deferrable and self.start_from_trigger:
+            self.start_trigger_args = StartTriggerArgs(
+                
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+                trigger_kwargs={
+                    "job": self.job,
+                    "project_id": self.project_id,
+                    "region": self.region,
+                    "gcp_conn_id": self.gcp_conn_id,
+                    "impersonation_chain": self.impersonation_chain,
+                    "polling_interval_seconds": self.polling_interval_seconds,
+                    "cancel_on_kill": self.cancel_on_kill,
+                    "request_id": self.request_id,
+                },
+                next_method="execute_complete",
+                next_kwargs=None,
+                timeout=None,
+            )
+
     def execute(self, context: Context):
         self.log.info("Submitting job")
         self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
diff --git 
a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py 
b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 73dd18c4c29..76aa2f5021f 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -230,6 +230,155 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                 raise e
 
 
+class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
+    """
+    Trigger that submits a Dataproc job and polls for its completion.
+
+    Used for direct-to-triggerer functionality where job submission and polling
+    are handled entirely by the triggerer without requiring a worker.
+
+    :param job: The job resource dict to submit.
+    :param project_id: Google Cloud Project where the job is running.
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud 
Platform.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term credentials.
+    :param polling_interval_seconds: Polling period in seconds to check for 
the status.
+    :param cancel_on_kill: Flag indicating whether to cancel the job when 
on_kill is called.
+    :param request_id: Optional unique id used to identify the request.
+    """
+
+    def __init__(self, job: dict, request_id: str | None = None, **kwargs):
+        """
+        Store the job definition and hand every other argument to the base trigger.
+
+        :param job: The job resource dict to submit.
+        :param request_id: Optional unique id used to identify the request.
+        :param kwargs: Forwarded unchanged to the parent class initializer.
+        """
+        # Starts empty; ``run`` fills it in from the submitted job's reference.
+        self.job_id: str | None = None
+        self.request_id = request_id
+        self.job = job
+        super().__init__(**kwargs)
+
+    def serialize(self):
+        """Return the trigger's classpath together with the kwargs needed to rebuild it."""
+        classpath = "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+        init_kwargs = {
+            "job": self.job,
+            "request_id": self.request_id,
+            "project_id": self.project_id,
+            "region": self.region,
+            "gcp_conn_id": self.gcp_conn_id,
+            "impersonation_chain": self.impersonation_chain,
+            "polling_interval_seconds": self.polling_interval_seconds,
+            "cancel_on_kill": self.cancel_on_kill,
+        }
+        return classpath, init_kwargs
+
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            :return: The row matching this trigger's dag_id, task_id, run_id and map_index.
+            :raises RuntimeError: If no matching task instance is found.
+            """
+            ti = self.task_instance
+            task_instance = session.scalar(
+                select(TaskInstance).where(
+                    TaskInstance.dag_id == ti.dag_id,
+                    TaskInstance.task_id == ti.task_id,
+                    TaskInstance.run_id == ti.run_id,
+                    TaskInstance.map_index == ti.map_index,
+                )
+            )
+            if task_instance is None:
+                # Exceptions do not interpolate logging-style ``%s`` arguments, so the
+                # message is formatted here instead of being passed as template + args.
+                raise RuntimeError(
+                    f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
+                    f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """
+        Fetch the current state of this trigger's task instance.
+
+        :return: The entry stored under this task's ``run_id`` and ``task_id`` in the response.
+        :raises RuntimeError: If the response cannot be indexed by this task's run_id and task_id.
+        """
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        ti = self.task_instance
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=ti.dag_id,
+            task_ids=[ti.task_id],
+            run_ids=[ti.run_id],
+            map_index=ti.map_index,
+        )
+        try:
+            task_state = task_states_response[ti.run_id][ti.task_id]
+        except Exception as err:
+            # Exceptions do not interpolate logging-style ``%s`` arguments, so the message
+            # is formatted eagerly; the lookup failure is chained to keep the root cause.
+            raise RuntimeError(
+                f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
+                f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
+            ) from err
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Tell whether the external job run by this trigger may be cancelled.
+
+        `asyncio.CancelledError` can also be raised because the trigger itself is being
+        stopped; in that situation the external job must NOT be cancelled. Cancelling is
+        therefore reported as safe only when the task is no longer in the DEFERRED state.
+        """
+        if not AIRFLOW_V_3_0_PLUS:
+            current_state = self.get_task_instance().state  # type: ignore[call-arg]
+        else:
+            current_state = await self.get_task_state()
+        return current_state != TaskInstanceState.DEFERRED
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Submit the job, poll it until it reaches a terminal state and emit the outcome.
+
+        One event holding ``job_id``, ``job_state`` and the job as a dict is yielded once the
+        job is DONE, CANCELLED or ERROR. When the coroutine is cancelled after the job id is
+        known, the job is cancelled only if ``cancel_on_kill`` is set and ``safe_to_cancel``
+        returns True.
+        """
+        try:
+            hook = self.get_async_hook()
+            self.log.info("Submitting Dataproc job.")
+            submitted = await hook.submit_job(
+                project_id=self.project_id,
+                region=self.region,
+                job=self.job,
+                request_id=self.request_id,
+            )
+            self.job_id = submitted.reference.job_id
+            self.log.info("Dataproc job %s submitted successfully.", self.job_id)
+
+            terminal_states = (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR)
+            while True:
+                job = await hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
+                state = job.status.state
+                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+                if state in terminal_states:
+                    break
+                await asyncio.sleep(self.polling_interval_seconds)
+
+            payload = {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
+            yield TriggerEvent(payload)
+        except asyncio.CancelledError:
+            self.log.info("Task got cancelled.")
+            try:
+                # Evaluated left to right, so the task state is only queried when a job id
+                # exists and cancel_on_kill is enabled.
+                should_cancel = self.job_id and self.cancel_on_kill and await self.safe_to_cancel()
+                if should_cancel:
+                    self.log.info("Cancelling the job: %s", self.job_id)
+                    # NOTE(review): synchronous hook call inside a coroutine; presumably it
+                    # blocks the event loop while the request is in flight - confirm.
+                    self.get_sync_hook().cancel_job(
+                        job_id=self.job_id, project_id=self.project_id, region=self.region
+                    )
+                    self.log.info("Job: %s is cancelled", self.job_id)
+                    # NOTE(review): the state name comes from ClusterStatus, not JobStatus, and
+                    # this event has no "job" key - confirm the consumer copes with both.
+                    yield TriggerEvent(
+                        {
+                            "job_id": self.job_id,
+                            "job_state": ClusterStatus.State.DELETING.name,  # type: ignore[attr-defined]
+                        }
+                    )
+            except Exception as e:
+                self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
+                raise e
+
+
 class DataprocClusterTrigger(DataprocBaseTrigger):
     """
     DataprocClusterTrigger run on the trigger worker to perform create Build 
operation.
diff --git 
a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
 
b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
new file mode 100644
index 00000000000..23f1d93cd9e
--- /dev/null
+++ 
b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.api_core.retry import Retry
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocSubmitJobOperator,
+)
+
+try:
+    from airflow.sdk import TriggerRule
+except ImportError:
+    # Compatibility for Airflow < 3.1
+    from airflow.utils.trigger_rule import TriggerRule  # type: 
ignore[no-redef,attr-defined]
+
+from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "dataproc_start_from_trigger"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or 
DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+
+CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-")
+CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-")
+CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else 
CLUSTER_NAME_FULL
+
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+}
+
+# Jobs definitions
+SPARK_JOB = {
+    "reference": {"project_id": PROJECT_ID},
+    "placement": {"cluster_name": CLUSTER_NAME},
+    "spark_job": {
+        "jar_file_uris": 
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+        "main_class": "org.apache.spark.examples.SparkPi",
+    },
+}
+
+
+with DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "dataproc", "start_from_trigger"],
+) as dag:
+    create_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+        retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
+        num_retries_if_resource_is_not_ready=3,
+    )
+
+    spark_task = DataprocSubmitJobOperator(
+        task_id="spark_task",
+        job=SPARK_JOB,
+        region=REGION,
+        project_id=PROJECT_ID,
+        deferrable=True,
+        start_from_trigger=True,
+    )
+
+    delete_cluster = DataprocDeleteClusterOperator(
+        task_id="delete_cluster",
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        region=REGION,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        create_cluster
+        # TEST BODY
+        >> spark_task
+        # TEST TEARDOWN
+        >> delete_cluster
+    )
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
contributing-docs/testing/system_tests.rst)
+test_run = get_test_run(dag)
diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index d284c875548..edb2bbe6c1c 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -2248,6 +2248,61 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
                 impersonation_chain=IMPERSONATION_CHAIN,
             )
 
+    def test_start_from_trigger_default_false(self):
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job={},
+            gcp_conn_id=GCP_CONN_ID,
+        )
+        assert op.start_from_trigger is False
+
+    def test_start_from_trigger_sets_start_trigger_args(self):
+        job = {"placement": {"cluster_name": "test-cluster"}}
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job=job,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+            start_from_trigger=True,
+            polling_interval_seconds=15,
+            cancel_on_kill=False,
+            request_id=REQUEST_ID,
+        )
+        assert op.start_from_trigger is True
+        assert (
+            op.start_trigger_args.trigger_cls
+            == 
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+        )
+        assert op.start_trigger_args.trigger_kwargs == {
+            "job": job,
+            "project_id": GCP_PROJECT,
+            "region": GCP_REGION,
+            "gcp_conn_id": GCP_CONN_ID,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+            "polling_interval_seconds": 15,
+            "cancel_on_kill": False,
+            "request_id": REQUEST_ID,
+        }
+        assert op.start_trigger_args.next_method == "execute_complete"
+
+    def test_start_from_trigger_without_deferrable_does_not_set_args(self):
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job={},
+            gcp_conn_id=GCP_CONN_ID,
+            deferrable=False,
+            start_from_trigger=True,
+        )
+        assert op.start_from_trigger is True
+        assert op.start_trigger_args.trigger_kwargs == {}
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index aa66d2237ed..ed3e5c08f36 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -31,6 +31,7 @@ from airflow.providers.google.cloud.triggers.dataproc import (
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocOperationTrigger,
+    DataprocSubmitJobDirectTrigger,
     DataprocSubmitTrigger,
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
@@ -138,6 +139,26 @@ def submit_trigger():
     )
 
 
+TEST_JOB = {
+    "placement": {"cluster_name": "test-cluster"},
+    "pyspark_job": {"main_python_file_uri": "gs://test"},
+}
+TEST_REQUEST_ID = "test-request-id"
+
+
[email protected]
+def submit_job_direct_trigger():
+    return DataprocSubmitJobDirectTrigger(
+        job=TEST_JOB,
+        request_id=TEST_REQUEST_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+        cancel_on_kill=True,
+    )
+
+
 @pytest.fixture
 def async_get_batch():
     def func(**kwargs):
@@ -661,3 +682,136 @@ class TestDataprocSubmitTrigger:
 
         # Clean up the generator
         await async_gen.aclose()
+
+
+class TestDataprocSubmitJobDirectTrigger:
+    def test_serialization(self, submit_job_direct_trigger):
+        classpath, kwargs = submit_job_direct_trigger.serialize()
+        assert classpath == 
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+        assert kwargs == {
+            "job": TEST_JOB,
+            "request_id": TEST_REQUEST_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+            "cancel_on_kill": True,
+            "impersonation_chain": None,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    async def test_run_submits_and_polls_success(self, mock_get_async_hook, 
submit_job_direct_trigger):
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_done_job = Job(status=JobStatus(state=JobStatus.State.DONE))
+        get_future = asyncio.Future()
+        get_future.set_result(mock_done_job)
+        mock_hook.get_job.return_value = get_future
+
+        async_gen = submit_job_direct_trigger.run()
+        event = await async_gen.asend(None)
+
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE.name, 
"job": Job.to_dict(mock_done_job)}
+        )
+        assert event.payload == expected_event.payload
+
+        mock_hook.submit_job.assert_called_once_with(
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            job=TEST_JOB,
+            request_id=TEST_REQUEST_ID,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    async def test_run_submits_and_polls_error(self, mock_get_async_hook, 
submit_job_direct_trigger):
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_error_job = Job(status=JobStatus(state=JobStatus.State.ERROR))
+        get_future = asyncio.Future()
+        get_future.set_result(mock_error_job)
+        mock_hook.get_job.return_value = get_future
+
+        async_gen = submit_job_direct_trigger.run()
+        event = await async_gen.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "job_id": TEST_JOB_ID,
+                "job_state": JobStatus.State.ERROR.name,
+                "job": Job.to_dict(mock_error_job),
+            }
+        )
+        assert event.payload == expected_event.payload
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
+    @mock.patch(
+        
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    @mock.patch(
+        
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_sync_hook"
+    )
+    @mock.patch(
+        
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.safe_to_cancel"
+    )
+    async def test_run_cancelled_after_submit(
+        self,
+        mock_safe_to_cancel,
+        mock_get_sync_hook,
+        mock_get_async_hook,
+        submit_job_direct_trigger,
+        is_safe_to_cancel,
+    ):
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_hook.get_job.side_effect = asyncio.CancelledError
+
+        mock_sync_hook = mock_get_sync_hook.return_value
+        mock_sync_hook.cancel_job = mock.MagicMock()
+
+        async_gen = submit_job_direct_trigger.run()
+
+        try:
+            await async_gen.asend(None)
+            await async_gen.asend(None)
+        except (asyncio.CancelledError, StopAsyncIteration):
+            pass
+        except Exception as e:
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+        if submit_job_direct_trigger.cancel_on_kill and is_safe_to_cancel:
+            mock_sync_hook.cancel_job.assert_called_once_with(
+                job_id=TEST_JOB_ID,
+                project_id=submit_job_direct_trigger.project_id,
+                region=submit_job_direct_trigger.region,
+            )
+        else:
+            mock_sync_hook.cancel_job.assert_not_called()
+
+        await async_gen.aclose()

Reply via email to