kaxil commented on code in PR #55542:
URL: https://github.com/apache/airflow/pull/55542#discussion_r2342381835
##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -831,6 +837,132 @@ def fake_collect_dags(self, *args, **kwargs):
# Should log warning about no callback found
log.warning.assert_called_once_with("Callback requested, but dag
didn't have any", dag_id="test_dag")
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            # Expected: pulling with a list of task_ids (read from the DAG inside the
+            # callback) sends a GetXComSequenceSlice request.
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            # Same expectation when the list of task_ids is a literal instead of
+            # being derived from the DAG.
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            # Expected: a single task_id string with map_indexes=None sends a plain
+            # GetXCom request with map_index=None.
+            (
+                lambda ti, task_ids: ti.xcom_pull(
+                    key="direct_value", task_ids="test_task", map_indexes=None
+                ),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                ),
+            ),
+        ],
+        ids=["task_ids_from_dag", "task_ids_literal_list", "single_task_id_no_map_index"],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """
+        Test that XCom pulls made from a DAG success callback send the expected supervisor message.
+
+        The callback runs through ``_execute_dag_callbacks``. Every ``send`` on the mocked
+        supervisor comms returns ``mock_response``. The last message sent must equal
+        ``expected_message``, and its class name must be ``expected_message_type``.
+        """
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict.keys())
+                xcom_operation(ti, task_ids)
+
+        with DAG(dag_id="test_dag", on_success_callback=TestNotifier()) as dag:
+            BaseOperator(task_id="test_task")
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        # Replace DagBag.collect_dags so the bag contains only the DAG defined above.
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        current_time = timezone.utcnow()
+        request = DagCallbackRequest(
+            filepath="test.py",
+            dag_id="test_dag",
+            run_id="test_run",
+            bundle_name="testing",
+            bundle_version=None,
+            context_from_server=DagRunContext(
+                dag_run=DRDataModel(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    logical_date=current_time,
+                    data_interval_start=current_time,
+                    data_interval_end=current_time,
+                    run_after=current_time,
+                    start_date=current_time,
+                    end_date=None,
+                    run_type="manual",
+                    state="success",
+                    consumed_asset_events=[],
+                ),
+                last_ti=TIDataModel(
+                    id=uuid.uuid4(),
+                    dag_id="test_dag",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    dag_version_id=uuid.uuid4(),
+                ),
+            ),
+            is_failure_callback=False,
+            msg="Test success message",
+        )
+
+        _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+        # assert_called_with already fails if send() was never called, so a separate
+        # assert_called() is redundant. It inspects only the most recent call.
+        mock_supervisor_comms.send.assert_called_with(msg=expected_message)
+        # Use the expected_message_type column (previously unused) to pin the message class.
+        sent_msg = mock_supervisor_comms.send.call_args.kwargs["msg"]
+        assert type(sent_msg).__name__ == expected_message_type
Review Comment:
```suggestion
mock_supervisor_comms.send.assert_called_with(msg=expected_message)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]