This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e9d7b177d7 AIP-72: Adding missing supervisor handler for RTIF (#45102)
7e9d7b177d7 is described below

commit 7e9d7b177d7e2a3e55763f90a4008850fc2e33f4
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 20 13:58:47 2024 +0530

    AIP-72: Adding missing supervisor handler for RTIF (#45102)
---
 .../src/airflow/sdk/execution_time/supervisor.py   |  3 +
 task_sdk/tests/execution_time/test_supervisor.py   |  9 +++
 task_sdk/tests/execution_time/test_task_runner.py  | 65 +++++++---------------
 3 files changed, 33 insertions(+), 44 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 932e6ead379..2cdcbd84237 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -68,6 +68,7 @@ from airflow.sdk.execution_time.comms import (
     GetXCom,
     PutVariable,
     RescheduleTask,
+    SetRenderedFields,
     SetXCom,
     StartupDetails,
     TaskState,
@@ -733,6 +734,8 @@ class WatchedSubprocess:
             self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.value, msg.map_index)
         elif isinstance(msg, PutVariable):
             self.client.variables.set(msg.key, msg.value, msg.description)
+        elif isinstance(msg, SetRenderedFields):
+            self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 73ac6dea630..03d2bab9396 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -46,6 +46,7 @@ from airflow.sdk.execution_time.comms import (
     GetXCom,
     PutVariable,
     RescheduleTask,
+    SetRenderedFields,
     SetXCom,
     TaskState,
     VariableResult,
@@ -882,6 +883,14 @@ class TestHandleRequest:
                 "",
                 id="patch_task_instance_to_skipped",
             ),
+            pytest.param(
+                SetRenderedFields(rendered_fields={"field1": 
"rendered_value1", "field2": "rendered_value2"}),
+                b"",
+                "task_instances.set_rtif",
+                (TI_ID, {"field1": "rendered_value1", "field2": 
"rendered_value2"}),
+                {"ok": True},
+                id="set_rtif",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 96ac89db5cd..2c08a9b97cd 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -248,43 +248,6 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
         )
 
 
-def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
-    """Test running a DAG with templated task."""
-    from airflow.providers.standard.operators.bash import BashOperator
-
-    task = BashOperator(
-        task_id="templated_task",
-        bash_command="echo 'Logical date is {{ logical_date }}'",
-    )
-
-    what = StartupDetails(
-        ti=TaskInstance(
-            id=uuid7(), task_id="templated_task", 
dag_id="basic_templated_dag", run_id="c", try_number=1
-        ),
-        file="",
-        requests_fd=0,
-        ti_context=make_ti_context(),
-    )
-    mocked_parse(what, "basic_templated_dag", task)
-
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        mock_supervisor_comms.get_message.return_value = what
-        startup()
-
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SetRenderedFields(
-                rendered_fields={
-                    "bash_command": "echo 'Logical date is {{ logical_date 
}}'",
-                    "cwd": None,
-                    "env": None,
-                }
-            ),
-            log=mock.ANY,
-        )
-
-
 @pytest.mark.parametrize(
     ["task_params", "expected_rendered_fields"],
     [
@@ -311,8 +274,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
         ),
     ],
 )
-def test_startup_dag_with_templated_fields(
-    mocked_parse, task_params, expected_rendered_fields, make_ti_context
+def test_startup_and_run_dag_with_templated_fields(
+    mocked_parse, task_params, expected_rendered_fields, make_ti_context, 
time_machine
 ):
     """Test startup of a DAG with various templated fields."""
 
@@ -324,6 +287,10 @@ def test_startup_dag_with_templated_fields(
             for key, value in task_params.items():
                 setattr(self, key, value)
 
+        def execute(self, context):
+            for key in self.template_fields:
+                print(key, getattr(self, key))
+
     task = CustomOperator(task_id="templated_task")
 
     what = StartupDetails(
@@ -332,18 +299,28 @@ def test_startup_dag_with_templated_fields(
         requests_fd=0,
         ti_context=make_ti_context(),
     )
-    mocked_parse(what, "basic_dag", task)
+    ti = mocked_parse(what, "basic_dag", task)
 
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
     with mock.patch(
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
     ) as mock_supervisor_comms:
         mock_supervisor_comms.get_message.return_value = what
 
         startup()
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
-            log=mock.ANY,
-        )
+        run(ti, log=mock.MagicMock())
+        expected_calls = [
+            mock.call.send_request(
+                
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+                log=mock.ANY,
+            ),
+            mock.call.send_request(
+                msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
+                log=mock.ANY,
+            ),
+        ]
+        mock_supervisor_comms.assert_has_calls(expected_calls)
 
 
 @pytest.mark.parametrize(

Reply via email to