This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fdd353a03e4 AIP-72: Adding PUT Variable Endpoint for execution API (#44449)
fdd353a03e4 is described below

commit fdd353a03e4b058fff834b611880c2475abfac61
Author: Amogh Desai <amoghrajesh1...@gmail.com>
AuthorDate: Fri Nov 29 12:07:45 2024 +0530

    AIP-72: Adding PUT Variable Endpoint for execution API (#44449)
---
 .../execution_api/datamodels/variable.py           |  7 ++
 .../api_fastapi/execution_api/routes/variables.py  | 26 ++++++-
 .../execution_api/routes/test_variables.py         | 87 ++++++++++++++++++++++
 3 files changed, 118 insertions(+), 2 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py
index 548d5934766..ce542af0d84 100644
--- a/airflow/api_fastapi/execution_api/datamodels/variable.py
+++ b/airflow/api_fastapi/execution_api/datamodels/variable.py
@@ -27,3 +27,10 @@ class VariableResponse(BaseModel):
 
     key: str
     val: str | None = Field(alias="value")
+
+
+class VariablePostBody(BaseModel):
+    """Request body schema for creating variables."""
+
+    value: str | None = Field(serialization_alias="val")
+    description: str | None = Field(default=None)
diff --git a/airflow/api_fastapi/execution_api/routes/variables.py b/airflow/api_fastapi/execution_api/routes/variables.py
index e8e2012e8d1..0e454f7dae0 100644
--- a/airflow/api_fastapi/execution_api/routes/variables.py
+++ b/airflow/api_fastapi/execution_api/routes/variables.py
@@ -24,7 +24,7 @@ from fastapi import HTTPException, status
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api import deps
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
-from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
+from airflow.api_fastapi.execution_api.datamodels.variable import VariablePostBody, VariableResponse
 from airflow.models.variable import Variable
 
 # TODO: Add dependency on JWT token
@@ -67,7 +67,29 @@ def get_variable(variable_key: str, token: deps.TokenDep) -> VariableResponse:
     return VariableResponse(key=variable_key, value=variable_value)
 
 
-def has_variable_access(variable_key: str, token: TIToken) -> bool:
+@router.put(
+    "/{variable_key}",
+    status_code=status.HTTP_201_CREATED,
+    responses={
+        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
+    },
+)
+def put_variable(variable_key: str, body: VariablePostBody, token: deps.TokenDep):
+    """Set an Airflow Variable."""
+    if not has_variable_access(variable_key, token, write_access=True):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail={
+                "reason": "access_denied",
+                "message": f"Task does not have access to write variable {variable_key}",
+            },
+        )
+    Variable.set(key=variable_key, value=body.value, description=body.description)
+    return {"message": "Variable successfully set"}
+
+
+def has_variable_access(variable_key: str, token: TIToken, write_access: bool = False) -> bool:
     """Check if the task has access to the variable."""
     # TODO: Placeholder for actual implementation
 
diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py
index 67247e4adb9..9ae7f9a2739 100644
--- a/tests/api_fastapi/execution_api/routes/test_variables.py
+++ b/tests/api_fastapi/execution_api/routes/test_variables.py
@@ -23,9 +23,18 @@ import pytest
 
 from airflow.models.variable import Variable
 
+from tests_common.test_utils.db import clear_db_variables
+
 pytestmark = pytest.mark.db_test
 
 
+@pytest.fixture(autouse=True)
+def setup_method():
+    clear_db_variables()
+    yield
+    clear_db_variables()
+
+
 class TestGetVariable:
     def test_variable_get_from_db(self, client, session):
         Variable.set(key="var1", value="value", session=session)
@@ -75,3 +84,81 @@ class TestGetVariable:
                 "message": "Task does not have access to variable key1",
             }
         }
+
+
+class TestPostVariable:
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            pytest.param({"value": "{}", "description": "description"}, id="valid-payload"),
+            pytest.param({"value": "{}"}, id="missing-description"),
+        ],
+    )
+    def test_should_create_variable(self, client, payload, session):
+        key = "var_create"
+        response = client.put(
+            f"/execution/variables/{key}",
+            json=payload,
+        )
+        assert response.status_code == 201
+
+        var_from_db = session.query(Variable).where(Variable.key == "var_create").first()
+        assert var_from_db is not None
+        assert var_from_db.key == key
+        assert var_from_db.val == payload["value"]
+        if "description" in payload:
+            assert var_from_db.description == payload["description"]
+
+    @pytest.mark.parametrize(
+        "key, status_code, payload",
+        [
+            pytest.param("", 404, {"value": "{}", "description": "description"}, id="missing-key"),
+            pytest.param("var_create", 422, {"description": "description"}, id="missing-value"),
+        ],
+    )
+    def test_variable_missing_fields(self, client, key, status_code, payload, session):
+        response = client.put(
+            f"/execution/variables/{key}",
+            json=payload,
+        )
+        assert response.status_code == status_code
+        if response.status_code == 422:
+            assert response.json()["detail"][0]["type"] == "missing"
+            assert response.json()["detail"][0]["msg"] == "Field required"
+
+    def test_overwriting_existing_variable(self, client, session):
+        key = "var_create"
+        Variable.set(key=key, value="value", session=session)
+        session.commit()
+
+        payload = {"value": "new_value"}
+        response = client.put(
+            f"/execution/variables/{key}",
+            json=payload,
+        )
+        assert response.status_code == 201
+        # variable should have been updated to the new value
+        var_from_db = session.query(Variable).where(Variable.key == key).first()
+        assert var_from_db is not None
+        assert var_from_db.key == key
+        assert var_from_db.val == payload["value"]
+
+    def test_post_variable_access_denied(self, client):
+        with mock.patch(
+            "airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False
+        ):
+            key = "var_create"
+            payload = {"value": "{}"}
+            response = client.put(
+                f"/execution/variables/{key}",
+                json=payload,
+            )
+
+        # Assert response status code and detail for access denied
+        assert response.status_code == 403
+        assert response.json() == {
+            "detail": {
+                "reason": "access_denied",
+                "message": "Task does not have access to write variable var_create",
+            }
+        }

Reply via email to