dabla commented on code in PR #68299:
URL: https://github.com/apache/airflow/pull/68299#discussion_r3388708373
##########
task-sdk/tests/task_sdk/bases/test_xcom.py:
##########
@@ -70,3 +77,190 @@ def
test_delete_includes_map_index_in_delete_xcom_message(self, map_index, mock_
assert sent_message.task_id == "test_task"
assert sent_message.run_id == "test_run"
assert sent_message.map_index == map_index
+
+ @pytest.mark.asyncio
+ async def test_aget_one_returns_value(self, mock_supervisor_comms):
+ """aget_one awaits asend and returns the deserialized value."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value="test_value")
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=0,
+ )
+
+ assert result == "test_value"
+ mock_supervisor_comms.asend.assert_called_once_with(
+ GetXCom(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=0,
+ include_prior_dates=False,
+ )
+ )
+ mock_supervisor_comms.send.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_aget_one_returns_none_when_not_found(self,
mock_supervisor_comms):
+ """aget_one returns None when XCom value is not found."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value=None)
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_aget_one_with_include_prior_dates(self,
mock_supervisor_comms):
+ """aget_one passes include_prior_dates parameter correctly."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value="prior_value")
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ include_prior_dates=True,
+ )
+
+ assert result == "prior_value"
+ mock_supervisor_comms.asend.assert_called_once_with(
+ GetXCom(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=None,
+ include_prior_dates=True,
+ )
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_one_raises_on_invalid_response(self,
mock_supervisor_comms):
+ """aget_one raises TypeError when receiving unexpected response
type."""
+ mock_supervisor_comms.asend.return_value = "invalid_response"
+
+ with pytest.raises(TypeError, match="Expected XComResult"):
+ await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_all_returns_values(self, mock_supervisor_comms):
+ """aget_all awaits asend and returns deserialized values from all map
indexes."""
+ mock_supervisor_comms.asend.return_value = XComSequenceSliceResult(
+ root=["value1", "value2", "value3"]
+ )
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result == ["value1", "value2", "value3"]
+ mock_supervisor_comms.asend.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=False,
+ )
+ )
+ mock_supervisor_comms.send.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_aget_all_returns_none_when_empty(self,
mock_supervisor_comms):
+ """aget_all returns None when no XCom values are found."""
+ mock_supervisor_comms.asend.return_value =
XComSequenceSliceResult(root=[])
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_aget_all_with_include_prior_dates(self,
mock_supervisor_comms):
+ """aget_all passes include_prior_dates parameter correctly."""
+ mock_supervisor_comms.asend.return_value =
XComSequenceSliceResult(root=["prior_value"])
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ include_prior_dates=True,
+ )
+
+ assert result == ["prior_value"]
+ mock_supervisor_comms.asend.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=True,
+ )
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_all_raises_on_invalid_response(self,
mock_supervisor_comms):
+ """aget_all raises TypeError when receiving unexpected response
type."""
+ mock_supervisor_comms.asend.return_value = "invalid_response"
+
+ with pytest.raises(TypeError, match="Expected
XComSequenceSliceResult"):
+ await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ @pytest.mark.asyncio
Review Comment:
No, we need to mark those, or you can mark them at the class/module level.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]