This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8b1492edef7 AIP-72: Adding Endpoint to set rendered task instance
fields (#44692)
8b1492edef7 is described below
commit 8b1492edef7272b7d83b7644ee12a87806cef32b
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Dec 9 11:35:07 2024 +0530
AIP-72: Adding Endpoint to set rendered task instance fields (#44692)
---
.../execution_api/datamodels/taskinstance.py | 6 +-
.../execution_api/routes/task_instances.py | 33 ++++++++-
task_sdk/src/airflow/sdk/api/client.py | 8 +++
task_sdk/tests/api/test_client.py | 42 ++++++++++++
.../execution_api/routes/test_task_instances.py | 80 +++++++++++++++++++++-
5 files changed, 165 insertions(+), 4 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index ae05cc140c4..2b8830c6de1 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -21,7 +21,7 @@ import uuid
from datetime import timedelta
from typing import Annotated, Any, Literal, Union
-from pydantic import Discriminator, Field, Tag, WithJsonSchema
+from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
@@ -135,3 +135,7 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
+
+
+"""Schema for setting RTIF for a task instance."""
+RTIFPayload = RootModel[dict[str, str]]
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 8d15a063f4f..90bbe1c1d3e 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -29,13 +29,14 @@ from sqlalchemy.sql import select
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ RTIFPayload,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
-from airflow.models.taskinstance import TaskInstance as TI
+from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
@@ -219,3 +220,33 @@ def ti_heartbeat(
# Update the last heartbeat time!
session.execute(update(TI).where(TI.id ==
ti_id_str).values(last_heartbeat_at=timezone.utcnow()))
log.debug("Task with %s state heartbeated", previous_state)
+
+
[email protected](
+ "/{task_instance_id}/rtif",
+ status_code=status.HTTP_201_CREATED,
+ # TODO: Add description to the operation
+ # TODO: Add Operation ID to control the function name in the OpenAPI spec
+ # TODO: Do we need to use create_openapi_http_exception_doc here?
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+ status.HTTP_422_UNPROCESSABLE_ENTITY: {
+ "description": "Invalid payload for the setting rendered task
instance fields"
+ },
+ },
+)
+def ti_put_rtif(
+ task_instance_id: UUID,
+ put_rtif_payload: RTIFPayload,
+ session: SessionDep,
+):
+ """Add an RTIF entry for a task instance, sent by the worker."""
+ ti_id_str = str(task_instance_id)
+ task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
+ if not task_instance:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ )
+ _update_rtif(task_instance, put_rtif_payload.model_dump(), session)
+
+ return {"message": "Rendered task instance fields successfully set"}
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index 5de5c7a8d90..568eb3c90bd 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -134,6 +134,14 @@ class TaskInstanceOperations:
# Create a deferred state payload from msg
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
+ def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
+ """Set Rendered Task Instance Fields via the API server."""
+ self.client.put(f"task-instances/{id}/rtif", json=body)
+ # Any error from the server will anyway be propagated down to the
supervisor,
+ # so we choose to send a generic response to the supervisor over the
server response to
+ # decouple from the server response string
+ return {"ok": True}
+
class ConnectionOperations:
__slots__ = ("client",)
diff --git a/task_sdk/tests/api/test_client.py
b/task_sdk/tests/api/test_client.py
index 900d1c90de5..ca8a6af3dd3 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -21,6 +21,7 @@ import json
import httpx
import pytest
+import uuid6
from airflow.sdk.api.client import Client, RemoteValidationError,
ServerResponseError
from airflow.sdk.api.datamodels._generated import VariableResponse,
XComResponse
@@ -84,6 +85,47 @@ def make_client(transport: httpx.MockTransport) -> Client:
return Client(base_url="test://server", token="", transport=transport)
+class TestTaskInstanceOperations:
+ """
+ Test that the TaskInstanceOperations class works as expected. While the
operations are simple, it
+ still catches the basic functionality of the client for task instances
including endpoint and
+ response parsing.
+ """
+
+ # TODO: Add tests for different ti endpoints
+
+ @pytest.mark.parametrize(
+ "rendered_fields",
+ [
+ pytest.param({"field1": "rendered_value1", "field2":
"rendered_value2"}, id="simple-rendering"),
+ pytest.param(
+ {
+ "field1": "ClassWithCustomAttributes({'nested1':
ClassWithCustomAttributes("
+ "{'att1': 'test', 'att2': 'test2'), "
+ "'nested2': ClassWithCustomAttributes("
+ "{'att3': 'test3', 'att4': 'test4')"
+ },
+ id="complex-rendering",
+ ),
+ ],
+ )
+ def test_taskinstance_set_rtif_success(self, rendered_fields):
+ TI_ID = uuid6.uuid7()
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{TI_ID}/rtif":
+ return httpx.Response(
+ status_code=201,
+ json={"message": "Rendered task instance fields
successfully set"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.set_rtif(id=TI_ID, body=rendered_fields)
+
+ assert result == {"ok": True}
+
+
class TestVariableOperations:
"""
Test that the VariableOperations class works as expected. While the
operations are simple, it
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 251ef3388cd..a0011e3ad89 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -20,15 +20,16 @@ from __future__ import annotations
from unittest import mock
import pytest
+import uuid6
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
-from airflow.models import Trigger
+from airflow.models import RenderedTaskInstanceFields, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState
-from tests_common.test_utils.db import clear_db_runs
+from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
pytestmark = pytest.mark.db_test
@@ -410,3 +411,78 @@ class TestTIHealthEndpoint:
# If successful, ensure last_heartbeat_at is updated
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)
+
+
+class TestTIPutRTIF:
+ def setup_method(self):
+ clear_db_runs()
+ clear_rendered_ti_fields()
+
+ def teardown_method(self):
+ clear_db_runs()
+ clear_rendered_ti_fields()
+
+ def test_ti_put_rtif_success(self, client, session, create_task_instance):
+ ti = create_task_instance(
+ task_id="test_ti_put_rtif_success",
+ state=State.RUNNING,
+ session=session,
+ )
+ session.commit()
+
+ payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
+
+ response = client.put(f"/execution/task-instances/{ti.id}/rtif",
json=payload)
+ assert response.status_code == 201
+ assert response.json() == {"message": "Rendered task instance fields
successfully set"}
+
+ session.expire_all()
+
+ rtifs = session.query(RenderedTaskInstanceFields).all()
+ assert len(rtifs) == 1
+
+ assert rtifs[0].dag_id == "dag"
+ assert rtifs[0].run_id == "test"
+ assert rtifs[0].task_id == "test_ti_put_rtif_success"
+ assert rtifs[0].map_index == -1
+ assert rtifs[0].rendered_fields == payload
+
+ def test_ti_put_rtif_missing_ti(self, client, session,
create_task_instance):
+ create_task_instance(
+ task_id="test_ti_put_rtif_missing_ti",
+ state=State.RUNNING,
+ session=session,
+ )
+ session.commit()
+
+ payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
+
+ random_id = uuid6.uuid7()
+ response = client.put(f"/execution/task-instances/{random_id}/rtif",
json=payload)
+ assert response.status_code == 404
+ assert response.json()["detail"] == "Not Found"
+
+ def test_ti_put_rtif_extra_fields(self, client, session,
create_task_instance):
+ ti = create_task_instance(
+ task_id="test_ti_put_rtif_missing_ti",
+ state=State.RUNNING,
+ session=session,
+ )
+ session.commit()
+
+ payload = {
+ "field1": "rendered_value1",
+ "field2": "rendered_value2",
+ "invalid_key": {"field3": "rendered_value3"},
+ }
+
+ response = client.put(f"/execution/task-instances/{ti.id}/rtif",
json=payload)
+ assert response.status_code == 422
+ assert response.json()["detail"] == [
+ {
+ "input": {"field3": "rendered_value3"},
+ "loc": ["body", "invalid_key"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ }
+ ]