This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b1492edef7 AIP-72: Adding Endpoint to set rendered task instance fields (#44692)
8b1492edef7 is described below

commit 8b1492edef7272b7d83b7644ee12a87806cef32b
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Dec 9 11:35:07 2024 +0530

    AIP-72: Adding Endpoint to set rendered task instance fields (#44692)
---
 .../execution_api/datamodels/taskinstance.py       |  6 +-
 .../execution_api/routes/task_instances.py         | 33 ++++++++-
 task_sdk/src/airflow/sdk/api/client.py             |  8 +++
 task_sdk/tests/api/test_client.py                  | 42 ++++++++++++
 .../execution_api/routes/test_task_instances.py    | 80 +++++++++++++++++++++-
 5 files changed, 165 insertions(+), 4 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index ae05cc140c4..2b8830c6de1 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -21,7 +21,7 @@ import uuid
 from datetime import timedelta
 from typing import Annotated, Any, Literal, Union
 
-from pydantic import Discriminator, Field, Tag, WithJsonSchema
+from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema
 
 from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.core_api.base import BaseModel
@@ -135,3 +135,7 @@ class TaskInstance(BaseModel):
     run_id: str
     try_number: int
     map_index: int | None = None
+
+
+"""Schema for setting RTIF for a task instance."""
+RTIFPayload = RootModel[dict[str, str]]
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 8d15a063f4f..90bbe1c1d3e 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -29,13 +29,14 @@ from sqlalchemy.sql import select
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    RTIFPayload,
     TIDeferredStatePayload,
     TIEnterRunningPayload,
     TIHeartbeatInfo,
     TIStateUpdate,
     TITerminalStatePayload,
 )
-from airflow.models.taskinstance import TaskInstance as TI
+from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
 from airflow.models.trigger import Trigger
 from airflow.utils import timezone
 from airflow.utils.state import State
@@ -219,3 +220,33 @@ def ti_heartbeat(
     # Update the last heartbeat time!
     session.execute(update(TI).where(TI.id == ti_id_str).values(last_heartbeat_at=timezone.utcnow()))
     log.debug("Task with %s state heartbeated", previous_state)
+
+
+@router.put(
+    "/{task_instance_id}/rtif",
+    status_code=status.HTTP_201_CREATED,
+    # TODO: Add description to the operation
+    # TODO: Add Operation ID to control the function name in the OpenAPI spec
+    # TODO: Do we need to use create_openapi_http_exception_doc here?
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+        status.HTTP_422_UNPROCESSABLE_ENTITY: {
+            "description": "Invalid payload for the setting rendered task instance fields"
+        },
+    },
+)
+def ti_put_rtif(
+    task_instance_id: UUID,
+    put_rtif_payload: RTIFPayload,
+    session: SessionDep,
+):
+    """Add an RTIF entry for a task instance, sent by the worker."""
+    ti_id_str = str(task_instance_id)
+    task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
+    if not task_instance:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+        )
+    _update_rtif(task_instance, put_rtif_payload.model_dump(), session)
+
+    return {"message": "Rendered task instance fields successfully set"}
diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 5de5c7a8d90..568eb3c90bd 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -134,6 +134,14 @@ class TaskInstanceOperations:
         # Create a deferred state payload from msg
         self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
 
+    def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
+        """Set Rendered Task Instance Fields via the API server."""
+        self.client.put(f"task-instances/{id}/rtif", json=body)
+        # Any error from the server will anyway be propagated down to the supervisor,
+        # so we choose to send a generic response to the supervisor over the server response to
+        # decouple from the server response string
+        return {"ok": True}
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py
index 900d1c90de5..ca8a6af3dd3 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -21,6 +21,7 @@ import json
 
 import httpx
 import pytest
+import uuid6
 
 from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError
 from airflow.sdk.api.datamodels._generated import VariableResponse, XComResponse
@@ -84,6 +85,47 @@ def make_client(transport: httpx.MockTransport) -> Client:
     return Client(base_url="test://server", token="", transport=transport)
 
 
+class TestTaskInstanceOperations:
+    """
+    Test that the TestVariableOperations class works as expected. While the operations are simple, it
+    still catches the basic functionality of the client for task instances including endpoint and
+    response parsing.
+    """
+
+    # TODO: Add tests for different ti endpoints
+
+    @pytest.mark.parametrize(
+        "rendered_fields",
+        [
+            pytest.param({"field1": "rendered_value1", "field2": "rendered_value2"}, id="simple-rendering"),
+            pytest.param(
+                {
+                    "field1": "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
+                    "{'att1': 'test', 'att2': 'test2'), "
+                    "'nested2': ClassWithCustomAttributes("
+                    "{'att3': 'test3', 'att4': 'test4')"
+                },
+                id="complex-rendering",
+            ),
+        ],
+    )
+    def test_taskinstance_set_rtif_success(self, rendered_fields):
+        TI_ID = uuid6.uuid7()
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{TI_ID}/rtif":
+                return httpx.Response(
+                    status_code=201,
+                    json={"message": "Rendered task instance fields successfully set"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.set_rtif(id=TI_ID, body=rendered_fields)
+
+        assert result == {"ok": True}
+
+
 class TestVariableOperations:
     """
     Test that the VariableOperations class works as expected. While the operations are simple, it
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 251ef3388cd..a0011e3ad89 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -20,15 +20,16 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+import uuid6
 from sqlalchemy import select
 from sqlalchemy.exc import SQLAlchemyError
 
-from airflow.models import Trigger
+from airflow.models import RenderedTaskInstanceFields, Trigger
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState
 
-from tests_common.test_utils.db import clear_db_runs
+from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
 
 pytestmark = pytest.mark.db_test
 
@@ -410,3 +411,78 @@ class TestTIHealthEndpoint:
         # If successful, ensure last_heartbeat_at is updated
         session.refresh(ti)
         assert ti.last_heartbeat_at == time_now.add(minutes=10)
+
+
+class TestTIPutRTIF:
+    def setup_method(self):
+        clear_db_runs()
+        clear_rendered_ti_fields()
+
+    def teardown_method(self):
+        clear_db_runs()
+        clear_rendered_ti_fields()
+
+    def test_ti_put_rtif_success(self, client, session, create_task_instance):
+        ti = create_task_instance(
+            task_id="test_ti_put_rtif_success",
+            state=State.RUNNING,
+            session=session,
+        )
+        session.commit()
+
+        payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
+
+        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
+        assert response.status_code == 201
+        assert response.json() == {"message": "Rendered task instance fields successfully set"}
+
+        session.expire_all()
+
+        rtifs = session.query(RenderedTaskInstanceFields).all()
+        assert len(rtifs) == 1
+
+        assert rtifs[0].dag_id == "dag"
+        assert rtifs[0].run_id == "test"
+        assert rtifs[0].task_id == "test_ti_put_rtif_success"
+        assert rtifs[0].map_index == -1
+        assert rtifs[0].rendered_fields == payload
+
+    def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance):
+        create_task_instance(
+            task_id="test_ti_put_rtif_missing_ti",
+            state=State.RUNNING,
+            session=session,
+        )
+        session.commit()
+
+        payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
+
+        random_id = uuid6.uuid7()
+        response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload)
+        assert response.status_code == 404
+        assert response.json()["detail"] == "Not Found"
+
+    def test_ti_put_rtif_extra_fields(self, client, session, create_task_instance):
+        ti = create_task_instance(
+            task_id="test_ti_put_rtif_missing_ti",
+            state=State.RUNNING,
+            session=session,
+        )
+        session.commit()
+
+        payload = {
+            "field1": "rendered_value1",
+            "field2": "rendered_value2",
+            "invalid_key": {"field3": "rendered_value3"},
+        }
+
+        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
+        assert response.status_code == 422
+        assert response.json()["detail"] == [
+            {
+                "input": {"field3": "rendered_value3"},
+                "loc": ["body", "invalid_key"],
+                "msg": "Input should be a valid string",
+                "type": "string_type",
+            }
+        ]

Reply via email to