This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7e9d7b177d7 AIP-72: Adding missing supervisor handler for RTIF (#45102)
7e9d7b177d7 is described below
commit 7e9d7b177d7e2a3e55763f90a4008850fc2e33f4
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 20 13:58:47 2024 +0530
AIP-72: Adding missing supervisor handler for RTIF (#45102)
---
.../src/airflow/sdk/execution_time/supervisor.py | 3 +
task_sdk/tests/execution_time/test_supervisor.py | 9 +++
task_sdk/tests/execution_time/test_task_runner.py | 65 +++++++---------------
3 files changed, 33 insertions(+), 44 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 932e6ead379..2cdcbd84237 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -68,6 +68,7 @@ from airflow.sdk.execution_time.comms import (
GetXCom,
PutVariable,
RescheduleTask,
+ SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
@@ -733,6 +734,8 @@ class WatchedSubprocess:
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
+ elif isinstance(msg, SetRenderedFields):
+ self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 73ac6dea630..03d2bab9396 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -46,6 +46,7 @@ from airflow.sdk.execution_time.comms import (
GetXCom,
PutVariable,
RescheduleTask,
+ SetRenderedFields,
SetXCom,
TaskState,
VariableResult,
@@ -882,6 +883,14 @@ class TestHandleRequest:
"",
id="patch_task_instance_to_skipped",
),
+ pytest.param(
+ SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
+ b"",
+ "task_instances.set_rtif",
+ (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}),
+ {"ok": True},
+ id="set_rtif",
+ ),
],
)
def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 96ac89db5cd..2c08a9b97cd 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -248,43 +248,6 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)
-def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
- """Test running a DAG with templated task."""
- from airflow.providers.standard.operators.bash import BashOperator
-
- task = BashOperator(
- task_id="templated_task",
- bash_command="echo 'Logical date is {{ logical_date }}'",
- )
-
- what = StartupDetails(
- ti=TaskInstance(
- id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
- ),
- file="",
- requests_fd=0,
- ti_context=make_ti_context(),
- )
- mocked_parse(what, "basic_templated_dag", task)
-
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = what
- startup()
-
- mock_supervisor_comms.send_request.assert_called_once_with(
- msg=SetRenderedFields(
- rendered_fields={
- "bash_command": "echo 'Logical date is {{ logical_date }}'",
- "cwd": None,
- "env": None,
- }
- ),
- log=mock.ANY,
- )
-
-
@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
@@ -311,8 +274,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
),
],
)
-def test_startup_dag_with_templated_fields(
- mocked_parse, task_params, expected_rendered_fields, make_ti_context
+def test_startup_and_run_dag_with_templated_fields(
+ mocked_parse, task_params, expected_rendered_fields, make_ti_context, time_machine
):
"""Test startup of a DAG with various templated fields."""
@@ -324,6 +287,10 @@ def test_startup_dag_with_templated_fields(
for key, value in task_params.items():
setattr(self, key, value)
+ def execute(self, context):
+ for key in self.template_fields:
+ print(key, getattr(self, key))
+
task = CustomOperator(task_id="templated_task")
what = StartupDetails(
@@ -332,18 +299,28 @@ def test_startup_dag_with_templated_fields(
requests_fd=0,
ti_context=make_ti_context(),
)
- mocked_parse(what, "basic_dag", task)
+ ti = mocked_parse(what, "basic_dag", task)
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = what
startup()
- mock_supervisor_comms.send_request.assert_called_once_with(
- msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
- log=mock.ANY,
- )
+ run(ti, log=mock.MagicMock())
+ expected_calls = [
+ mock.call.send_request(
+ msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+ log=mock.ANY,
+ ),
+ mock.call.send_request(
+ msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
+ log=mock.ANY,
+ ),
+ ]
+ mock_supervisor_comms.assert_has_calls(expected_calls)
@pytest.mark.parametrize(