This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2bdb51550cb Refactor the google cloud DataprocCreateBatchOperator 
tests (#52573)
2bdb51550cb is described below

commit 2bdb51550cb29da5a3dbb03812ace55ebf26b694
Author: olegkachur-e <[email protected]>
AuthorDate: Mon Jun 30 23:14:38 2025 +0200

    Refactor the google cloud DataprocCreateBatchOperator tests (#52573)
    
    - replace un-called method mock
    - add logging checks
    - populate labels checks
    
    Co-authored-by: Oleg Kachur <[email protected]>
---
 .../providers/google/cloud/operators/dataproc.py   |   8 +-
 .../unit/google/cloud/operators/test_dataproc.py   | 204 ++++++++++++---------
 2 files changed, 125 insertions(+), 87 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py 
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 5ed6923821a..349250150a5 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -2488,13 +2488,13 @@ class 
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
         link = DATAPROC_BATCH_LINK.format(region=self.region, 
project_id=self.project_id, batch_id=batch_id)
         if state == Batch.State.FAILED:
             raise AirflowException(
-                f"Batch job {batch_id} failed with error: 
{state_message}\nDriver Logs: {link}"
+                f"Batch job {batch_id} failed with error: 
{state_message}.\nDriver logs: {link}"
             )
         if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
-            raise AirflowException(f"Batch job {batch_id} was cancelled. 
Driver logs: {link}")
+            raise AirflowException(f"Batch job {batch_id} was 
cancelled.\nDriver logs: {link}")
         if state == Batch.State.STATE_UNSPECIFIED:
-            raise AirflowException(f"Batch job {batch_id} unspecified. Driver 
logs: {link}")
-        self.log.info("Batch job %s completed. Driver logs: %s", batch_id, 
link)
+            raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver 
logs: {link}")
+        self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id, 
link)
 
     def retry_batch_creation(
         self,
diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 00923cb590a..c76c7046db1 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -39,6 +39,7 @@ from airflow.exceptions import (
 )
 from airflow.models import DAG, DagBag
 from airflow.providers.google.cloud.links.dataproc import (
+    DATAPROC_BATCH_LINK,
     DATAPROC_CLUSTER_LINK_DEPRECATED,
     DATAPROC_JOB_LINK_DEPRECATED,
 )
@@ -353,6 +354,12 @@ DEFAULT_DATE = datetime(2020, 1, 1)
 TEST_JOB_ID = "test-job"
 TEST_WORKFLOW_ID = "test-workflow"
 
+EXPECTED_LABELS = {
+    "airflow-dag-id": TEST_DAG_ID,
+    "airflow-dag-display-name": TEST_DAG_ID,
+    "airflow-task-id": TASK_ID,
+}
+
 DATAPROC_JOB_LINK_EXPECTED = (
     
f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?region={GCP_REGION}&project={GCP_PROJECT}";
 )
@@ -3187,9 +3194,10 @@ class TestDataprocCreateWorkflowTemplateOperator:
 
 
 class TestDataprocCreateBatchOperator:
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    def test_execute(self, mock_hook, to_dict_mock, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3203,7 +3211,10 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
batch_state_succeeded
+
         op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3216,6 +3227,16 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        mock_log.info.assert_has_calls(
+            [
+                mock.call("Starting batch %s", BATCH_ID),
+                mock.call("The batch %s was created.", BATCH_ID),
+                mock.call("Waiting for the completion of batch job %s", 
BATCH_ID),
+                mock.call("Batch job %s completed.\nDriver logs: %s", 
BATCH_ID, logs_link),
+            ]
+        )
 
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -3234,7 +3255,7 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3268,8 +3289,9 @@ class TestDataprocCreateBatchOperator:
         with pytest.raises(AirflowException):
             op.execute(context=MagicMock())
 
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_succeeds(self, mock_hook):
+    def test_execute_batch_already_exists_succeeds(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3283,9 +3305,10 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = 
AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+
         op.execute(context=MagicMock())
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
@@ -3295,9 +3318,23 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check for succeeded run
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+
+        mock_log.info.assert_has_calls(
+            [
+                mock.call(
+                    "Batch with given id already exists.",
+                ),
+                mock.call("Attaching to the job %s if it is still running.", 
BATCH_ID),
+                mock.call("Waiting for the completion of batch job %s", 
BATCH_ID),
+                mock.call("Batch job %s completed.\nDriver logs: %s", 
BATCH_ID, logs_link),
+            ]
+        )
 
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_fails(self, mock_hook):
+    def test_execute_batch_already_exists_fails(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3311,11 +3348,15 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = 
AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.FAILED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
-        with pytest.raises(AirflowException):
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.FAILED)
+
+        with pytest.raises(AirflowException) as exc:
             op.execute(context=MagicMock())
+        # Check msg for FAILED batch state
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        assert str(exc.value) == (f"Batch job {BATCH_ID} failed with error: 
.\nDriver logs: {logs_link}")
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
             region=GCP_REGION,
@@ -3324,9 +3365,12 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check logs for AlreadyExists being called
+        mock_log.info.assert_any_call("Batch with given id already exists.")
 
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_cancelled(self, mock_hook):
+    def test_execute_batch_already_exists_cancelled(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3340,11 +3384,16 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = 
AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.CANCELLED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
-        with pytest.raises(AirflowException):
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.CANCELLED)
+
+        with pytest.raises(AirflowException) as exc:
             op.execute(context=MagicMock())
+        # Check msg for CANCELLED batch state
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        assert str(exc.value) == f"Batch job {BATCH_ID} was cancelled.\nDriver 
logs: {logs_link}"
+
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
             region=GCP_REGION,
@@ -3353,21 +3402,29 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check logs for AlreadyExists being called
+        mock_log.info.assert_any_call("Batch with given id already exists.")
 
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_parent_job_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
+        self,
+        mock_hook,
+        to_dict_mock,
+        mock_ol_accessible,
+        mock_static_uuid,
+        mock_log,
     ):
         mock_ol_accessible.return_value = True
         mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {"properties": 
OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3381,9 +3438,13 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
batch_state_succeeded
+        mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
         op.execute(context=EXAMPLE_CONTEXT)
+
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
             project_id=GCP_PROJECT,
@@ -3394,14 +3455,19 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        # Check SUCCEED run from the logs
+        mock_log.info.assert_any_call("Batch job %s completed.\nDriver logs: 
%s", BATCH_ID, logs_link)
 
+    @mock.patch.object(DataprocCreateBatchOperator, "log", 
new_callable=mock.MagicMock)
     
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, 
mock_static_uuid
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, 
mock_static_uuid, mock_log
     ):
         mock_ol_accessible.return_value = True
         mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
@@ -3410,9 +3476,9 @@ class TestDataprocCreateBatchOperator:
         )
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {"properties": 
OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3426,9 +3492,13 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_transport_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
batch_state_succeeded
+        mock_hook.return_value.create_batch.return_value.metadata.batch = 
f"prefix/{BATCH_ID}"
         op.execute(context=EXAMPLE_CONTEXT)
+
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
             project_id=GCP_PROJECT,
@@ -3439,6 +3509,14 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, 
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        # Verify logs for successful run
+        mock_log.info.assert_any_call(
+            "Batch job %s completed.\nDriver logs: %s",
+            BATCH_ID,
+            logs_link,
+        )
 
     
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@@ -3455,6 +3533,7 @@ class TestDataprocCreateBatchOperator:
         )
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES,
@@ -3462,7 +3541,6 @@ class TestDataprocCreateBatchOperator:
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3477,8 +3555,9 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
             openlineage_inject_transport_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3498,22 +3577,15 @@ class TestDataprocCreateBatchOperator:
         self, mock_hook, to_dict_mock, mock_ol_accessible
     ):
         mock_ol_accessible.return_value = True
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "task-id",
-        }
-
         batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     "spark.openlineage.parentJobName": "dag_id.task_id",
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3527,9 +3599,9 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
-            dag=DAG(dag_id="test_dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3553,23 +3625,15 @@ class TestDataprocCreateBatchOperator:
         
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport
 = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
-
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "task-id",
-        }
-
         batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     "spark.openlineage.transport.type": "console",
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3583,7 +3647,7 @@ class TestDataprocCreateBatchOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_transport_info=True,
-            dag=DAG(dag_id="test_dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
         mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
@@ -3609,7 +3673,6 @@ class TestDataprocCreateBatchOperator:
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3624,7 +3687,7 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
             # not passing openlineage_inject_parent_job_info, should be False 
by default
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3652,7 +3715,6 @@ class TestDataprocCreateBatchOperator:
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3667,7 +3729,7 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
             # not passing openlineage_inject_transport_info, should be False 
by default
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3691,7 +3753,6 @@ class TestDataprocCreateBatchOperator:
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3706,7 +3767,7 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3734,7 +3795,6 @@ class TestDataprocCreateBatchOperator:
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3749,7 +3809,7 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
             openlineage_inject_transport_info=True,
         )
-        mock_hook.return_value.wait_for_operation.return_value = 
Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = 
Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3778,20 +3838,13 @@ class TestDataprocCreateBatchOperator:
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
         }
-
         DataprocCreateBatchOperator(
-            task_id="test-task",
-            dag=DAG(dag_id="test_dag"),
+            task_id=TASK_ID,
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3801,20 +3854,13 @@ class TestDataprocCreateBatchOperator:
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook, 
to_dict_mock):
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
         }
-
         DataprocCreateBatchOperator(
-            task_id="test-TASK",
-            dag=DAG(dag_id="Test_dag"),
+            task_id=TASK_ID,
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3826,7 +3872,7 @@ class TestDataprocCreateBatchOperator:
     def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook, 
to_dict_mock):
         DataprocCreateBatchOperator(
             task_id=".task-id",
-            dag=DAG(dag_id="test-dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3838,7 +3884,7 @@ class TestDataprocCreateBatchOperator:
     def test_create_batch_long_taskid_labels_ignored(self, mock_hook, 
to_dict_mock):
         DataprocCreateBatchOperator(
             task_id="a" * 65,
-            dag=DAG(dag_id="test-dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3850,21 +3896,13 @@ class TestDataprocCreateBatchOperator:
     def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
         batch = Batch(name="test")
         batch.labels["foo"] = "bar"
-        dag = DAG(dag_id="test_dag")
-
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = deepcopy(batch)
-        expected_batch.labels.update(expected_labels)
+        expected_batch.labels.update(EXPECTED_LABELS)
+        dag = DAG(dag_id=TEST_DAG_ID)
 
-        DataprocCreateBatchOperator(task_id="test-task", batch=batch, 
region=GCP_REGION, dag=dag).execute(
+        DataprocCreateBatchOperator(task_id=TASK_ID, batch=batch, 
region=GCP_REGION, dag=dag).execute(
             context=EXAMPLE_CONTEXT
         )
-
         TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, 
expected_batch)
 
 

Reply via email to