This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ab76581  Add DAG run endpoint for marking a DAG run as success or failed (#17839)
ab76581 is described below

commit ab7658147445161fa3f7f2b139fbf9c223877f77
Author: bbenshalom <benny1...@gmail.com>
AuthorDate: Thu Sep 2 22:32:45 2021 +0300

    Add DAG run endpoint for marking a DAG run as success or failed (#17839)
    
    
    Co-authored-by: bbenshalom <bbensha...@outbrain.com>
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    Co-authored-by: Ephraim Anierobi <splendidzig...@gmail.com>
---
 .../api_connexion/endpoints/dag_run_endpoint.py    | 31 +++++++-
 airflow/api_connexion/openapi/v1.yaml              | 39 ++++++++++
 airflow/api_connexion/schemas/dag_run_schema.py    | 10 ++-
 .../endpoints/test_dag_run_endpoint.py             | 88 ++++++++++++++++++++++
 4 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index feb6d24..e816aac 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -14,6 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Optional
+
 from flask import current_app, g, request
 from marshmallow import ValidationError
 from sqlalchemy import or_
@@ -27,11 +29,12 @@ from airflow.api_connexion.schemas.dag_run_schema import (
     dagrun_collection_schema,
     dagrun_schema,
     dagruns_batch_form_schema,
+    set_dagrun_state_form_schema,
 )
 from airflow.models import DagModel, DagRun
 from airflow.security import permissions
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 
 
@@ -271,3 +274,29 @@ def post_dag_run(dag_id, session):
         )
 
     raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: 
'{run_id}' already exists")
+
+
+@security.requires_access(
+    [
+        (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+    ]
+)
+@provide_session
+def update_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict:
+    """Set an existing DAG run's state to success or failed and return the serialized run."""
+    dag_run: Optional[DagRun] = (
+        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id).one_or_none()
+    )
+    if dag_run is None:  # unknown run: NotFound (404) is raised before the request body is validated
+        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
+        raise NotFound(error_message)
+    try:
+        post_body = set_dagrun_state_form_schema.load(request.json)  # schema's OneOf allows success/failed only
+    except ValidationError as err:
+        raise BadRequest(detail=str(err))
+
+    state = post_body['state']  # NOTE(review): schema field isn't required=True; a body without 'state' likely raises KeyError (500) here - confirm
+    dag_run.set_state(state=DagRunState(state))
+    session.merge(dag_run)
+    return dagrun_schema.dump(dag_run)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 2ce6804..8dae5f2 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -674,6 +674,35 @@ paths:
         '404':
           $ref: '#/components/responses/NotFound'
 
+    patch:
+      summary: Modify a DAG run
+      description: Modify a DAG run
+      x-openapi-router-controller: 
airflow.api_connexion.endpoints.dag_run_endpoint
+      operationId: update_dag_run_state  # handler function in the router controller module above
+      tags: [ UpdateDagRunState ]  # NOTE(review): new standalone tag; confirm it shouldn't share the tag of the other DAG run operations
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/UpdateDagRunState'
+
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DAGRun'
+        '400':  # e.g. a state outside the UpdateDagRunState enum
+          $ref: '#/components/responses/BadRequest'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':  # DAG run not found for the given dag_id / dag_run_id
+          $ref: '#/components/responses/NotFound'
+
   /eventLogs:
     get:
       summary: List log entries
@@ -1996,6 +2025,16 @@ components:
       required:
         - dag_id
 
+    UpdateDagRunState:
+      type: object  # NOTE(review): no 'required: [state]', so an empty object passes spec validation - confirm intended
+      properties:
+        state:
+          description: The state to set this DagRun
+          type: string
+          enum:  # keep in sync with SetDagRunStateFormSchema's OneOf validator
+            - success
+            - failed
+
     DAGRunCollection:
       type: object
       description: Collection of DAG runs.
diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py
index 62ac172..fa51f10 100644
--- a/airflow/api_connexion/schemas/dag_run_schema.py
+++ b/airflow/api_connexion/schemas/dag_run_schema.py
@@ -18,7 +18,7 @@
 import json
 from typing import List, NamedTuple
 
-from marshmallow import fields, post_dump, pre_load
+from marshmallow import fields, post_dump, pre_load, validate
 from marshmallow.schema import Schema
 from marshmallow.validate import Range
 from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -29,6 +29,7 @@ from airflow.api_connexion.parameters import validate_istimezone
 from airflow.api_connexion.schemas.enum_schemas import DagStateField
 from airflow.models.dagrun import DagRun
 from airflow.utils import timezone
+from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
 
@@ -104,6 +105,12 @@ class DAGRunSchema(SQLAlchemySchema):
         return data
 
 
+class SetDagRunStateFormSchema(Schema):
+    """Request body for setting a DAG run's state; only "success" or "failed" pass validation."""
+
+    state = DagStateField(validate=validate.OneOf([DagRunState.SUCCESS.value, 
DagRunState.FAILED.value]))  # NOTE(review): not required=True, so a body without 'state' presumably loads without it - confirm
+
+
 class DAGRunCollection(NamedTuple):
     """List of DAGRuns with metadata"""
 
@@ -141,4 +148,5 @@ class DagRunsBatchFormSchema(Schema):
 
 dagrun_schema = DAGRunSchema()
 dagrun_collection_schema = DAGRunCollectionSchema()
+set_dagrun_state_form_schema = SetDagRunStateFormSchema()
 dagruns_batch_form_schema = DagRunsBatchFormSchema()
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 71b5274..47d89cb 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -18,10 +18,12 @@ from datetime import timedelta
 from unittest import mock
 
 import pytest
+from freezegun import freeze_time
 from parameterized import parameterized
 
 from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
 from airflow.models import DAG, DagModel, DagRun
+from airflow.operators.dummy import DummyOperator
 from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.session import create_session, provide_session
@@ -1154,3 +1156,89 @@ class TestPostDagRun(TestDagRunEndpoint):
             environ_overrides={'REMOTE_USER': "test_view_dags"},
         )
         assert response.status_code == 403
+
+
+class TestPostSetDagRunState(TestDagRunEndpoint):  # tests for PATCH api/v1/dags/{dag_id}/dagRuns/{dag_run_id}
+    @pytest.mark.parametrize("state", ["failed", "success"])
+    @freeze_time(TestDagRunEndpoint.default_time)
+    def test_should_respond_200(self, state, dag_maker):
+        dag_id = "TEST_DAG_ID"
+        dag_run_id = 'TEST_DAG_RUN_ID'
+        with dag_maker(dag_id):
+            DummyOperator(task_id='task_id')
+        dag_maker.create_dagrun(run_id=dag_run_id)
+
+        request_json = {"state": state}
+
+        response = self.client.patch(
+            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}",
+            json=request_json,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+
+        assert response.status_code == 200
+        assert response.json == {
+            'conf': {},
+            'dag_id': dag_id,
+            'dag_run_id': dag_run_id,
+            'end_date': self.default_time,  # NOTE(review): presumably set_state stamps end_date from the frozen clock - confirm
+            'execution_date': dag_maker.start_date.isoformat(),
+            'external_trigger': False,
+            'logical_date': dag_maker.start_date.isoformat(),
+            'start_date': dag_maker.start_date.isoformat(),
+            'state': state,
+        }
+
+    @pytest.mark.parametrize('invalid_state', ["running", "queued"])
+    @freeze_time(TestDagRunEndpoint.default_time)
+    def test_should_response_400_for_non_existing_dag_run_state(self, 
invalid_state, dag_maker):
+        dag_id = "TEST_DAG_ID"
+        dag_run_id = 'TEST_DAG_RUN_ID'
+        with dag_maker(dag_id):
+            DummyOperator(task_id='task_id')
+        dag_maker.create_dagrun(run_id=dag_run_id)
+
+        request_json = {"state": invalid_state}
+
+        response = self.client.patch(
+            "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1",  # NOTE(review): not the run created above; the 400 presumably comes from spec validation before the handler's 404 lookup - confirm
+            json=request_json,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 400
+        assert response.json == {
+            'detail': f"'{invalid_state}' is not one of ['success', 'failed'] 
- 'state'",
+            'status': 400,
+            'title': 'Bad Request',
+            'type': EXCEPTIONS_LINK_MAP[400],
+        }
+
+    def test_should_raises_401_unauthenticated(self, session):  # NOTE(review): 'session' fixture is unused here
+        response = self.client.patch(
+            "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1",
+            json={
+                "state": 'success',
+            },
+        )
+
+        assert_401(response)
+
+    def test_should_raise_403_forbidden(self):
+        response = self.client.patch(
+            "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1",
+            json={
+                "state": 'success',
+            },
+            environ_overrides={'REMOTE_USER': "test_no_permissions"},
+        )
+        assert response.status_code == 403
+
+    def test_should_respond_404(self):
+        response = self.client.patch(
+            "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1",  # DAG id intended not to exist, so the handler's lookup misses
+            json={
+                "state": 'success',
+            },
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 404

Reply via email to