amoghrajesh commented on code in PR #64297:
URL: https://github.com/apache/airflow/pull/64297#discussion_r3008071381


##########
task-sdk/tests/task_sdk/execution_time/test_task_runner.py:
##########
@@ -2058,6 +2027,42 @@ def mock_get_all_side_effect(task_id, **kwargs):
                 assert mock_get_one.called
                 assert not mock_get_all.called
 
+    @pytest.mark.parametrize(
+        ("task_ids", "default", "expected_value"),
+        [
+            pytest.param("task_a", "fallback", "fallback", 
id="single_task_str_default"),
+            pytest.param("task_a", NOTSET, NOTSET, 
id="single_task_NOTSET_default"),
+            pytest.param(["task_a"], "fallback", ["fallback"], 
id="list_task_str_default"),
+            pytest.param(
+                ["task_a", "task_b"],
+                "fallback",
+                ["fallback", "fallback"],
+                id="multiple_tasks_str_default",
+            ),
+        ],
+    )
+    def test_xcom_pull_default_with_notset_map_indexes(

Review Comment:
   Can you add or edit tests so that they also cover these scenarios:
   
   1. Some of the pulled tasks return non-None XComs.
   2. `XCom.get_all()` returns an empty list instead of `None`.



##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -415,7 +417,7 @@ def xcom_pull(
                 )
 
                 if values is None:
-                    xcoms.append(None)
+                    xcoms.append(default)

Review Comment:
   Thanks, this looks good.



##########
task-sdk/tests/task_sdk/execution_time/test_task_runner.py:
##########
@@ -2058,6 +2027,42 @@ def mock_get_all_side_effect(task_id, **kwargs):
                 assert mock_get_one.called
                 assert not mock_get_all.called
 
+    @pytest.mark.parametrize(
+        ("task_ids", "default", "expected_value"),
+        [
+            pytest.param("task_a", "fallback", "fallback", 
id="single_task_str_default"),
+            pytest.param("task_a", NOTSET, NOTSET, 
id="single_task_NOTSET_default"),
+            pytest.param(["task_a"], "fallback", ["fallback"], 
id="list_task_str_default"),
+            pytest.param(
+                ["task_a", "task_b"],
+                "fallback",
+                ["fallback", "fallback"],
+                id="multiple_tasks_str_default",
+            ),
+        ],
+    )
+    def test_xcom_pull_default_with_notset_map_indexes(
+        self,
+        create_runtime_ti,
+        mock_supervisor_comms,
+        task_ids,
+        default,
+        expected_value,

Review Comment:
   ```suggestion
           expected_default,
   ```



##########
task-sdk/tests/task_sdk/execution_time/test_task_runner.py:
##########
@@ -2058,6 +2027,42 @@ def mock_get_all_side_effect(task_id, **kwargs):
                 assert mock_get_one.called
                 assert not mock_get_all.called
 
+    @pytest.mark.parametrize(
+        ("task_ids", "default", "expected_value"),
+        [
+            pytest.param("task_a", "fallback", "fallback", 
id="single_task_str_default"),
+            pytest.param("task_a", NOTSET, NOTSET, 
id="single_task_NOTSET_default"),
+            pytest.param(["task_a"], "fallback", ["fallback"], 
id="list_task_str_default"),
+            pytest.param(
+                ["task_a", "task_b"],
+                "fallback",
+                ["fallback", "fallback"],
+                id="multiple_tasks_str_default",
+            ),
+        ],
+    )
+    def test_xcom_pull_default_with_notset_map_indexes(
+        self,
+        create_runtime_ti,
+        mock_supervisor_comms,
+        task_ids,
+        default,
+        expected_value,
+    ):
+        """Test that xcom_pull returns `default` when no XCom is found and 
map_indexes is NOTSET."""
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                print("This is a custom operator")
+
+        task = CustomOperator(task_id="pull_task")
+        runtime_ti = create_runtime_ti(task=task)
+
+        with patch.object(XCom, "get_all", return_value=None) as mock_get_all:
+            result = runtime_ti.xcom_pull(key="key", task_ids=task_ids, 
default=default)
+            assert result == expected_value

Review Comment:
   ```suggestion
               assert result == expected_default
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to