uranusjr commented on code in PR #55542:
URL: https://github.com/apache/airflow/pull/55542#discussion_r2342400111


##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -831,6 +837,132 @@ def fake_collect_dags(self, *args, **kwargs):
         # Should log warning about no callback found
         log.warning.assert_called_once_with("Callback requested, but dag didn't have any", dag_id="test_dag")
 
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="direct_value", task_ids="test_task", map_indexes=None),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                ),
+            ),
+        ],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """Test that different XCom operations send correct message types"""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict.keys())
+                xcom_operation(ti, task_ids)
+
+        with DAG(dag_id="test_dag", on_success_callback=TestNotifier()) as dag:
+            BaseOperator(task_id="test_task")
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        current_time = timezone.utcnow()
+        request = DagCallbackRequest(
+            filepath="test.py",
+            dag_id="test_dag",
+            run_id="test_run",
+            bundle_name="testing",
+            bundle_version=None,
+            context_from_server=DagRunContext(
+                dag_run=DRDataModel(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    logical_date=current_time,
+                    data_interval_start=current_time,
+                    data_interval_end=current_time,
+                    run_after=current_time,
+                    start_date=current_time,
+                    end_date=None,
+                    run_type="manual",
+                    state="success",
+                    consumed_asset_events=[],
+                ),
+                last_ti=TIDataModel(
+                    id=uuid.uuid4(),
+                    dag_id="test_dag",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    dag_version_id=uuid.uuid4(),
+                ),
+            ),
+            is_failure_callback=False,
+            msg="Test success message",
+        )
+
+        _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+        mock_supervisor_comms.send.assert_called()
+        mock_supervisor_comms.send.assert_called_with(msg=expected_message)

Review Comment:
   ```suggestion
           assert mock_supervisor_comms.send.mock_calls == [mock.call(msg=expected_message)]
   ```



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -831,6 +837,132 @@ def fake_collect_dags(self, *args, **kwargs):
         # Should log warning about no callback found
         log.warning.assert_called_once_with("Callback requested, but dag didn't have any", dag_id="test_dag")
 
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="direct_value", task_ids="test_task", map_indexes=None),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                ),
+            ),
+        ],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """Test that different XCom operations send correct message types"""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict.keys())

Review Comment:
   ```suggestion
                   task_ids = list(dag.task_dict)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to