This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 29c6504f24f Add direct-to-triggerer support for
DataprocSubmitJobOperator #50563 (#62331)
29c6504f24f is described below
commit 29c6504f24f8502d1dec24ef2d38230e0608cddf
Author: Haseeb Malik <[email protected]>
AuthorDate: Wed Mar 11 09:24:19 2026 -0400
Add direct-to-triggerer support for DataprocSubmitJobOperator #50563
(#62331)
---
.../providers/google/cloud/operators/dataproc.py | 32 +++++
.../providers/google/cloud/triggers/dataproc.py | 149 ++++++++++++++++++++
.../example_dataproc_start_from_trigger.py | 132 ++++++++++++++++++
.../unit/google/cloud/operators/test_dataproc.py | 55 ++++++++
.../unit/google/cloud/triggers/test_dataproc.py | 154 +++++++++++++++++++++
5 files changed, 522 insertions(+)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 6dc6e0ab775..38ac0ff8fc4 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -63,6 +63,7 @@ from airflow.providers.google.cloud.triggers.dataproc import (
)
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
+from airflow.triggers.base import StartTriggerArgs
if TYPE_CHECKING:
from google.api_core import operation
@@ -1880,6 +1881,8 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
The value is considered only when running in deferrable mode. Must be
greater than 0.
:param cancel_on_kill: Flag which indicates whether to cancel the hook's job
or not, when on_kill is called
:param wait_timeout: How many seconds wait for job to be ready. Used only
if ``asynchronous`` is False
+ :param start_from_trigger: If True and deferrable is True, the operator
will start directly
+ from the triggerer without occupying a worker slot.
"""
template_fields: Sequence[str] = (
@@ -1894,6 +1897,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
operator_extra_links = (DataprocJobLink(),)
+ start_trigger_args = StartTriggerArgs(
+
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ next_kwargs=None,
+ timeout=None,
+ )
+ start_from_trigger = False
+
def __init__(
self,
*,
@@ -1911,6 +1923,7 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
wait_timeout: int | None = None,
+ start_from_trigger: bool = False,
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
@@ -1938,9 +1951,28 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
self.hook: DataprocHook | None = None
self.job_id: str | None = None
self.wait_timeout = wait_timeout
+ self.start_from_trigger = start_from_trigger
self.openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
self.openlineage_inject_transport_info =
openlineage_inject_transport_info
+ if self.deferrable and self.start_from_trigger:
+ self.start_trigger_args = StartTriggerArgs(
+
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+ trigger_kwargs={
+ "job": self.job,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ "cancel_on_kill": self.cancel_on_kill,
+ "request_id": self.request_id,
+ },
+ next_method="execute_complete",
+ next_kwargs=None,
+ timeout=None,
+ )
+
def execute(self, context: Context):
self.log.info("Submitting job")
self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 73dd18c4c29..76aa2f5021f 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -230,6 +230,155 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
raise e
+class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
+ """
+ Trigger that submits a Dataproc job and polls for its completion.
+
+ Used for direct-to-triggerer functionality where job submission and polling
+ are handled entirely by the triggerer without requiring a worker.
+
+ :param job: The job resource dict to submit.
+ :param project_id: Google Cloud Project where the job is running.
+ :param region: The Cloud Dataproc region in which to handle the request.
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud
Platform.
+ :param impersonation_chain: Optional service account to impersonate using
short-term credentials.
+ :param polling_interval_seconds: Polling period in seconds to check for
the status.
+ :param cancel_on_kill: Flag indicating whether to cancel the job when
on_kill is called.
+ :param request_id: Optional unique id used to identify the request.
+ """
+
+ def __init__(
+ self,
+ job: dict,
+ request_id: str | None = None,
+ **kwargs,
+ ):
+ self.job = job
+ self.request_id = request_id
+ self.job_id: str | None = None
+ super().__init__(**kwargs)
+
+ def serialize(self):
+ return (
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+ {
+ "job": self.job,
+ "request_id": self.request_id,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ "cancel_on_kill": self.cancel_on_kill,
+ },
+ )
+
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """
+ Get the task instance for the current task.
+
+ :param session: Sqlalchemy session
+ """
+ task_instance = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ )
+ if task_instance is None:
+ raise RuntimeError(
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise RuntimeError(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
+ """
+ Whether it is safe to cancel the external job which is being executed
by this trigger.
+
+    This is to avoid the case where `asyncio.CancelledError` is raised
because the trigger itself is stopped.
+ Because in those cases, we should NOT cancel the external job.
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ try:
+ hook = self.get_async_hook()
+ self.log.info("Submitting Dataproc job.")
+ job_object = await hook.submit_job(
+ project_id=self.project_id,
+ region=self.region,
+ job=self.job,
+ request_id=self.request_id,
+ )
+ self.job_id = job_object.reference.job_id
+ self.log.info("Dataproc job %s submitted successfully.",
self.job_id)
+
+ while True:
+ job = await hook.get_job(project_id=self.project_id,
region=self.region, job_id=self.job_id)
+ state = job.status.state
+ self.log.info("Dataproc job: %s is in state: %s", self.job_id,
state)
+ if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED,
JobStatus.State.ERROR):
+ break
+ await asyncio.sleep(self.polling_interval_seconds)
+
+ yield TriggerEvent(
+ {"job_id": self.job_id, "job_state":
JobStatus.State(state).name, "job": Job.to_dict(job)}
+ )
+ except asyncio.CancelledError:
+ self.log.info("Task got cancelled.")
+ try:
+ if self.job_id and self.cancel_on_kill and await
self.safe_to_cancel():
+ self.log.info("Cancelling the job: %s", self.job_id)
+ self.get_sync_hook().cancel_job(
+ job_id=self.job_id, project_id=self.project_id,
region=self.region
+ )
+ self.log.info("Job: %s is cancelled", self.job_id)
+ yield TriggerEvent(
+ {
+ "job_id": self.job_id,
+ "job_state": ClusterStatus.State.DELETING.name, #
type: ignore[attr-defined]
+ }
+ )
+ except Exception as e:
+ self.log.error("Failed to cancel the job: %s with error : %s",
self.job_id, str(e))
+ raise e
+
+
class DataprocClusterTrigger(DataprocBaseTrigger):
"""
DataprocClusterTrigger runs on the trigger worker to perform the create cluster
operation.
diff --git
a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
new file mode 100644
index 00000000000..23f1d93cd9e
--- /dev/null
+++
b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.api_core.retry import Retry
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocSubmitJobOperator,
+)
+
+try:
+ from airflow.sdk import TriggerRule
+except ImportError:
+ # Compatibility for Airflow < 3.1
+ from airflow.utils.trigger_rule import TriggerRule # type:
ignore[no-redef,attr-defined]
+
+from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "dataproc_start_from_trigger"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or
DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+
+CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-")
+CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-")
+CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else
CLUSTER_NAME_FULL
+
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+}
+
+# Jobs definitions
+SPARK_JOB = {
+ "reference": {"project_id": PROJECT_ID},
+ "placement": {"cluster_name": CLUSTER_NAME},
+ "spark_job": {
+ "jar_file_uris":
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+ "main_class": "org.apache.spark.examples.SparkPi",
+ },
+}
+
+
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "dataproc", "start_from_trigger"],
+) as dag:
+ create_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
+ num_retries_if_resource_is_not_ready=3,
+ )
+
+ spark_task = DataprocSubmitJobOperator(
+ task_id="spark_task",
+ job=SPARK_JOB,
+ region=REGION,
+ project_id=PROJECT_ID,
+ deferrable=True,
+ start_from_trigger=True,
+ )
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ cluster_name=CLUSTER_NAME,
+ region=REGION,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_cluster
+ # TEST BODY
+ >> spark_task
+ # TEST TEARDOWN
+ >> delete_cluster
+ )
+
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
contributing-docs/testing/system_tests.rst)
+test_run = get_test_run(dag)
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index d284c875548..edb2bbe6c1c 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -2248,6 +2248,61 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
impersonation_chain=IMPERSONATION_CHAIN,
)
+ def test_start_from_trigger_default_false(self):
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ job={},
+ gcp_conn_id=GCP_CONN_ID,
+ )
+ assert op.start_from_trigger is False
+
+ def test_start_from_trigger_sets_start_trigger_args(self):
+ job = {"placement": {"cluster_name": "test-cluster"}}
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ job=job,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ start_from_trigger=True,
+ polling_interval_seconds=15,
+ cancel_on_kill=False,
+ request_id=REQUEST_ID,
+ )
+ assert op.start_from_trigger is True
+ assert (
+ op.start_trigger_args.trigger_cls
+ ==
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+ )
+ assert op.start_trigger_args.trigger_kwargs == {
+ "job": job,
+ "project_id": GCP_PROJECT,
+ "region": GCP_REGION,
+ "gcp_conn_id": GCP_CONN_ID,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ "polling_interval_seconds": 15,
+ "cancel_on_kill": False,
+ "request_id": REQUEST_ID,
+ }
+ assert op.start_trigger_args.next_method == "execute_complete"
+
+ def test_start_from_trigger_without_deferrable_does_not_set_args(self):
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ job={},
+ gcp_conn_id=GCP_CONN_ID,
+ deferrable=False,
+ start_from_trigger=True,
+ )
+ assert op.start_from_trigger is True
+ assert op.start_trigger_args.trigger_kwargs == {}
+
@pytest.mark.db_test
@pytest.mark.need_serialized_dag
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index aa66d2237ed..ed3e5c08f36 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -31,6 +31,7 @@ from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
DataprocClusterTrigger,
DataprocOperationTrigger,
+ DataprocSubmitJobDirectTrigger,
DataprocSubmitTrigger,
)
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
@@ -138,6 +139,26 @@ def submit_trigger():
)
+TEST_JOB = {
+ "placement": {"cluster_name": "test-cluster"},
+ "pyspark_job": {"main_python_file_uri": "gs://test"},
+}
+TEST_REQUEST_ID = "test-request-id"
+
+
[email protected]
+def submit_job_direct_trigger():
+ return DataprocSubmitJobDirectTrigger(
+ job=TEST_JOB,
+ request_id=TEST_REQUEST_ID,
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ polling_interval_seconds=TEST_POLL_INTERVAL,
+ cancel_on_kill=True,
+ )
+
+
@pytest.fixture
def async_get_batch():
def func(**kwargs):
@@ -661,3 +682,136 @@ class TestDataprocSubmitTrigger:
# Clean up the generator
await async_gen.aclose()
+
+
+class TestDataprocSubmitJobDirectTrigger:
+ def test_serialization(self, submit_job_direct_trigger):
+ classpath, kwargs = submit_job_direct_trigger.serialize()
+ assert classpath ==
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+ assert kwargs == {
+ "job": TEST_JOB,
+ "request_id": TEST_REQUEST_ID,
+ "project_id": TEST_PROJECT_ID,
+ "region": TEST_REGION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "polling_interval_seconds": TEST_POLL_INTERVAL,
+ "cancel_on_kill": True,
+ "impersonation_chain": None,
+ }
+
+ @pytest.mark.asyncio
+ @mock.patch(
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+ )
+ async def test_run_submits_and_polls_success(self, mock_get_async_hook,
submit_job_direct_trigger):
+ mock_hook = mock_get_async_hook.return_value
+
+ mock_submitted_job = mock.MagicMock()
+ mock_submitted_job.reference.job_id = TEST_JOB_ID
+ submit_future = asyncio.Future()
+ submit_future.set_result(mock_submitted_job)
+ mock_hook.submit_job.return_value = submit_future
+
+ mock_done_job = Job(status=JobStatus(state=JobStatus.State.DONE))
+ get_future = asyncio.Future()
+ get_future.set_result(mock_done_job)
+ mock_hook.get_job.return_value = get_future
+
+ async_gen = submit_job_direct_trigger.run()
+ event = await async_gen.asend(None)
+
+ expected_event = TriggerEvent(
+ {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE.name,
"job": Job.to_dict(mock_done_job)}
+ )
+ assert event.payload == expected_event.payload
+
+ mock_hook.submit_job.assert_called_once_with(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ job=TEST_JOB,
+ request_id=TEST_REQUEST_ID,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+ )
+ async def test_run_submits_and_polls_error(self, mock_get_async_hook,
submit_job_direct_trigger):
+ mock_hook = mock_get_async_hook.return_value
+
+ mock_submitted_job = mock.MagicMock()
+ mock_submitted_job.reference.job_id = TEST_JOB_ID
+ submit_future = asyncio.Future()
+ submit_future.set_result(mock_submitted_job)
+ mock_hook.submit_job.return_value = submit_future
+
+ mock_error_job = Job(status=JobStatus(state=JobStatus.State.ERROR))
+ get_future = asyncio.Future()
+ get_future.set_result(mock_error_job)
+ mock_hook.get_job.return_value = get_future
+
+ async_gen = submit_job_direct_trigger.run()
+ event = await async_gen.asend(None)
+
+ expected_event = TriggerEvent(
+ {
+ "job_id": TEST_JOB_ID,
+ "job_state": JobStatus.State.ERROR.name,
+ "job": Job.to_dict(mock_error_job),
+ }
+ )
+ assert event.payload == expected_event.payload
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
+ @mock.patch(
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+ )
+ @mock.patch(
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_sync_hook"
+ )
+ @mock.patch(
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.safe_to_cancel"
+ )
+ async def test_run_cancelled_after_submit(
+ self,
+ mock_safe_to_cancel,
+ mock_get_sync_hook,
+ mock_get_async_hook,
+ submit_job_direct_trigger,
+ is_safe_to_cancel,
+ ):
+ mock_safe_to_cancel.return_value = is_safe_to_cancel
+ mock_hook = mock_get_async_hook.return_value
+
+ mock_submitted_job = mock.MagicMock()
+ mock_submitted_job.reference.job_id = TEST_JOB_ID
+ submit_future = asyncio.Future()
+ submit_future.set_result(mock_submitted_job)
+ mock_hook.submit_job.return_value = submit_future
+
+ mock_hook.get_job.side_effect = asyncio.CancelledError
+
+ mock_sync_hook = mock_get_sync_hook.return_value
+ mock_sync_hook.cancel_job = mock.MagicMock()
+
+ async_gen = submit_job_direct_trigger.run()
+
+ try:
+ await async_gen.asend(None)
+ await async_gen.asend(None)
+ except (asyncio.CancelledError, StopAsyncIteration):
+ pass
+ except Exception as e:
+ pytest.fail(f"Unexpected exception raised: {e}")
+
+ if submit_job_direct_trigger.cancel_on_kill and is_safe_to_cancel:
+ mock_sync_hook.cancel_job.assert_called_once_with(
+ job_id=TEST_JOB_ID,
+ project_id=submit_job_direct_trigger.project_id,
+ region=submit_job_direct_trigger.region,
+ )
+ else:
+ mock_sync_hook.cancel_job.assert_not_called()
+
+ await async_gen.aclose()