This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2bdb51550cb Refactor the google cloud DataprocCreateBatchOperator
tests (#52573)
2bdb51550cb is described below
commit 2bdb51550cb29da5a3dbb03812ace55ebf26b694
Author: olegkachur-e <[email protected]>
AuthorDate: Mon Jun 30 23:14:38 2025 +0200
Refactor the google cloud DataprocCreateBatchOperator tests (#52573)
- replace uncalled method mock
- add logging checks
- populate labels checks
Co-authored-by: Oleg Kachur <[email protected]>
---
.../providers/google/cloud/operators/dataproc.py | 8 +-
.../unit/google/cloud/operators/test_dataproc.py | 204 ++++++++++++---------
2 files changed, 125 insertions(+), 87 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 5ed6923821a..349250150a5 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -2488,13 +2488,13 @@ class
DataprocCreateBatchOperator(GoogleCloudBaseOperator):
link = DATAPROC_BATCH_LINK.format(region=self.region,
project_id=self.project_id, batch_id=batch_id)
if state == Batch.State.FAILED:
raise AirflowException(
- f"Batch job {batch_id} failed with error:
{state_message}\nDriver Logs: {link}"
+ f"Batch job {batch_id} failed with error:
{state_message}.\nDriver logs: {link}"
)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
- raise AirflowException(f"Batch job {batch_id} was cancelled.
Driver logs: {link}")
+ raise AirflowException(f"Batch job {batch_id} was
cancelled.\nDriver logs: {link}")
if state == Batch.State.STATE_UNSPECIFIED:
- raise AirflowException(f"Batch job {batch_id} unspecified. Driver
logs: {link}")
- self.log.info("Batch job %s completed. Driver logs: %s", batch_id,
link)
+ raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver
logs: {link}")
+ self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id,
link)
def retry_batch_creation(
self,
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 00923cb590a..c76c7046db1 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -39,6 +39,7 @@ from airflow.exceptions import (
)
from airflow.models import DAG, DagBag
from airflow.providers.google.cloud.links.dataproc import (
+ DATAPROC_BATCH_LINK,
DATAPROC_CLUSTER_LINK_DEPRECATED,
DATAPROC_JOB_LINK_DEPRECATED,
)
@@ -353,6 +354,12 @@ DEFAULT_DATE = datetime(2020, 1, 1)
TEST_JOB_ID = "test-job"
TEST_WORKFLOW_ID = "test-workflow"
+EXPECTED_LABELS = {
+ "airflow-dag-id": TEST_DAG_ID,
+ "airflow-dag-display-name": TEST_DAG_ID,
+ "airflow-task-id": TASK_ID,
+}
+
DATAPROC_JOB_LINK_EXPECTED = (
f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?region={GCP_REGION}&project={GCP_PROJECT}"
)
@@ -3187,9 +3194,10 @@ class TestDataprocCreateWorkflowTemplateOperator:
class TestDataprocCreateBatchOperator:
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute(self, mock_hook, to_dict_mock):
+ def test_execute(self, mock_hook, to_dict_mock, mock_log):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3203,7 +3211,10 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
+ batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
batch_state_succeeded
+
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3216,6 +3227,16 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ to_dict_mock.assert_called_once_with(batch_state_succeeded)
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+ mock_log.info.assert_has_calls(
+ [
+ mock.call("Starting batch %s", BATCH_ID),
+ mock.call("The batch %s was created.", BATCH_ID),
+ mock.call("Waiting for the completion of batch job %s",
BATCH_ID),
+ mock.call("Batch job %s completed.\nDriver logs: %s",
BATCH_ID, logs_link),
+ ]
+ )
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -3234,7 +3255,7 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3268,8 +3289,9 @@ class TestDataprocCreateBatchOperator:
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute_batch_already_exists_succeeds(self, mock_hook):
+ def test_execute_batch_already_exists_succeeds(self, mock_hook, mock_log):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3283,9 +3305,10 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- mock_hook.return_value.wait_for_operation.side_effect =
AlreadyExists("")
- mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
+
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
@@ -3295,9 +3318,23 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ # Check for succeeded run
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+
+ mock_log.info.assert_has_calls(
+ [
+ mock.call(
+ "Batch with given id already exists.",
+ ),
+ mock.call("Attaching to the job %s if it is still running.",
BATCH_ID),
+ mock.call("Waiting for the completion of batch job %s",
BATCH_ID),
+ mock.call("Batch job %s completed.\nDriver logs: %s",
BATCH_ID, logs_link),
+ ]
+ )
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute_batch_already_exists_fails(self, mock_hook):
+ def test_execute_batch_already_exists_fails(self, mock_hook, mock_log):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3311,11 +3348,15 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- mock_hook.return_value.wait_for_operation.side_effect =
AlreadyExists("")
- mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.FAILED)
+ mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
- with pytest.raises(AirflowException):
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.FAILED)
+
+ with pytest.raises(AirflowException) as exc:
op.execute(context=MagicMock())
+ # Check msg for FAILED batch state
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+ assert str(exc.value) == (f"Batch job {BATCH_ID} failed with error:
.\nDriver logs: {logs_link}")
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
@@ -3324,9 +3365,12 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ # Check logs for AlreadyExists being called
+ mock_log.info.assert_any_call("Batch with given id already exists.")
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute_batch_already_exists_cancelled(self, mock_hook):
+ def test_execute_batch_already_exists_cancelled(self, mock_hook, mock_log):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3340,11 +3384,16 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- mock_hook.return_value.wait_for_operation.side_effect =
AlreadyExists("")
- mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.CANCELLED)
+ mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
- with pytest.raises(AirflowException):
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.CANCELLED)
+
+ with pytest.raises(AirflowException) as exc:
op.execute(context=MagicMock())
+ # Check msg for CANCELLED batch state
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+ assert str(exc.value) == f"Batch job {BATCH_ID} was cancelled.\nDriver
logs: {logs_link}"
+
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
@@ -3353,21 +3402,29 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ # Check logs for AlreadyExists being called
+ mock_log.info.assert_any_call("Batch with given id already exists.")
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_parent_job_info_injection(
- self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
+ self,
+ mock_hook,
+ to_dict_mock,
+ mock_ol_accessible,
+ mock_static_uuid,
+ mock_log,
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
expected_batch = {
**BATCH,
+ "labels": EXPECTED_LABELS,
"runtime_config": {"properties":
OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3381,9 +3438,13 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_parent_job_info=True,
+ dag=DAG(dag_id=TEST_DAG_ID),
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
batch_state_succeeded
+ mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
op.execute(context=EXAMPLE_CONTEXT)
+
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
project_id=GCP_PROJECT,
@@ -3394,14 +3455,19 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ to_dict_mock.assert_called_once_with(batch_state_succeeded)
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+ # Check SUCCEED run from the logs
+ mock_log.info.assert_any_call("Batch job %s completed.\nDriver logs:
%s", BATCH_ID, logs_link)
+ @mock.patch.object(DataprocCreateBatchOperator, "log",
new_callable=mock.MagicMock)
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_transport_info_injection(
- self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener,
mock_static_uuid
+ self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener,
mock_static_uuid, mock_log
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
@@ -3410,9 +3476,9 @@ class TestDataprocCreateBatchOperator:
)
expected_batch = {
**BATCH,
+ "labels": EXPECTED_LABELS,
"runtime_config": {"properties":
OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3426,9 +3492,13 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_transport_info=True,
+ dag=DAG(dag_id=TEST_DAG_ID),
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
batch_state_succeeded
+ mock_hook.return_value.create_batch.return_value.metadata.batch =
f"prefix/{BATCH_ID}"
op.execute(context=EXAMPLE_CONTEXT)
+
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
project_id=GCP_PROJECT,
@@ -3439,6 +3509,14 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+ to_dict_mock.assert_called_once_with(batch_state_succeeded)
+ logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION,
project_id=GCP_PROJECT, batch_id=BATCH_ID)
+ # Verify logs for successful run
+ mock_log.info.assert_any_call(
+ "Batch job %s completed.\nDriver logs: %s",
+ BATCH_ID,
+ logs_link,
+ )
@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@@ -3455,6 +3533,7 @@ class TestDataprocCreateBatchOperator:
)
expected_batch = {
**BATCH,
+ "labels": EXPECTED_LABELS,
"runtime_config": {
"properties": {
**OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES,
@@ -3462,7 +3541,6 @@ class TestDataprocCreateBatchOperator:
}
},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3477,8 +3555,9 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
openlineage_inject_parent_job_info=True,
openlineage_inject_transport_info=True,
+ dag=DAG(dag_id=TEST_DAG_ID),
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3498,22 +3577,15 @@ class TestDataprocCreateBatchOperator:
self, mock_hook, to_dict_mock, mock_ol_accessible
):
mock_ol_accessible.return_value = True
- expected_labels = {
- "airflow-dag-id": "test_dag",
- "airflow-dag-display-name": "test_dag",
- "airflow-task-id": "task-id",
- }
-
batch = {
**BATCH,
- "labels": expected_labels,
+ "labels": EXPECTED_LABELS,
"runtime_config": {
"properties": {
"spark.openlineage.parentJobName": "dag_id.task_id",
}
},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3527,9 +3599,9 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_parent_job_info=True,
- dag=DAG(dag_id="test_dag"),
+ dag=DAG(dag_id=TEST_DAG_ID),
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3553,23 +3625,15 @@ class TestDataprocCreateBatchOperator:
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport
= HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
-
- expected_labels = {
- "airflow-dag-id": "test_dag",
- "airflow-dag-display-name": "test_dag",
- "airflow-task-id": "task-id",
- }
-
batch = {
**BATCH,
- "labels": expected_labels,
+ "labels": EXPECTED_LABELS,
"runtime_config": {
"properties": {
"spark.openlineage.transport.type": "console",
}
},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3583,7 +3647,7 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_transport_info=True,
- dag=DAG(dag_id="test_dag"),
+ dag=DAG(dag_id=TEST_DAG_ID),
)
mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
@@ -3609,7 +3673,6 @@ class TestDataprocCreateBatchOperator:
**BATCH,
"runtime_config": {"properties": {}},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3624,7 +3687,7 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
# not passing openlineage_inject_parent_job_info, should be False
by default
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3652,7 +3715,6 @@ class TestDataprocCreateBatchOperator:
**BATCH,
"runtime_config": {"properties": {}},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3667,7 +3729,7 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
# not passing openlineage_inject_transport_info, should be False
by default
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3691,7 +3753,6 @@ class TestDataprocCreateBatchOperator:
**BATCH,
"runtime_config": {"properties": {}},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3706,7 +3767,7 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
openlineage_inject_parent_job_info=True,
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3734,7 +3795,6 @@ class TestDataprocCreateBatchOperator:
**BATCH,
"runtime_config": {"properties": {}},
}
-
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -3749,7 +3809,7 @@ class TestDataprocCreateBatchOperator:
metadata=METADATA,
openlineage_inject_transport_info=True,
)
- mock_hook.return_value.wait_for_operation.return_value =
Batch(state=Batch.State.SUCCEEDED)
+ mock_hook.return_value.wait_for_batch.return_value =
Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
@@ -3778,20 +3838,13 @@ class TestDataprocCreateBatchOperator:
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
- expected_labels = {
- "airflow-dag-id": "test_dag",
- "airflow-dag-display-name": "test_dag",
- "airflow-task-id": "test-task",
- }
-
expected_batch = {
**BATCH,
- "labels": expected_labels,
+ "labels": EXPECTED_LABELS,
}
-
DataprocCreateBatchOperator(
- task_id="test-task",
- dag=DAG(dag_id="test_dag"),
+ task_id=TASK_ID,
+ dag=DAG(dag_id=TEST_DAG_ID),
batch=BATCH,
region=GCP_REGION,
).execute(context=EXAMPLE_CONTEXT)
@@ -3801,20 +3854,13 @@ class TestDataprocCreateBatchOperator:
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook,
to_dict_mock):
- expected_labels = {
- "airflow-dag-id": "test_dag",
- "airflow-dag-display-name": "test_dag",
- "airflow-task-id": "test-task",
- }
-
expected_batch = {
**BATCH,
- "labels": expected_labels,
+ "labels": EXPECTED_LABELS,
}
-
DataprocCreateBatchOperator(
- task_id="test-TASK",
- dag=DAG(dag_id="Test_dag"),
+ task_id=TASK_ID,
+ dag=DAG(dag_id=TEST_DAG_ID),
batch=BATCH,
region=GCP_REGION,
).execute(context=EXAMPLE_CONTEXT)
@@ -3826,7 +3872,7 @@ class TestDataprocCreateBatchOperator:
def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook,
to_dict_mock):
DataprocCreateBatchOperator(
task_id=".task-id",
- dag=DAG(dag_id="test-dag"),
+ dag=DAG(dag_id=TEST_DAG_ID),
batch=BATCH,
region=GCP_REGION,
).execute(context=EXAMPLE_CONTEXT)
@@ -3838,7 +3884,7 @@ class TestDataprocCreateBatchOperator:
def test_create_batch_long_taskid_labels_ignored(self, mock_hook,
to_dict_mock):
DataprocCreateBatchOperator(
task_id="a" * 65,
- dag=DAG(dag_id="test-dag"),
+ dag=DAG(dag_id=TEST_DAG_ID),
batch=BATCH,
region=GCP_REGION,
).execute(context=EXAMPLE_CONTEXT)
@@ -3850,21 +3896,13 @@ class TestDataprocCreateBatchOperator:
def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
batch = Batch(name="test")
batch.labels["foo"] = "bar"
- dag = DAG(dag_id="test_dag")
-
- expected_labels = {
- "airflow-dag-id": "test_dag",
- "airflow-dag-display-name": "test_dag",
- "airflow-task-id": "test-task",
- }
-
expected_batch = deepcopy(batch)
- expected_batch.labels.update(expected_labels)
+ expected_batch.labels.update(EXPECTED_LABELS)
+ dag = DAG(dag_id=TEST_DAG_ID)
- DataprocCreateBatchOperator(task_id="test-task", batch=batch,
region=GCP_REGION, dag=dag).execute(
+ DataprocCreateBatchOperator(task_id=TASK_ID, batch=batch,
region=GCP_REGION, dag=dag).execute(
context=EXAMPLE_CONTEXT
)
-
TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook,
expected_batch)