This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new ab76581 Add DAG run endpoint for marking a dagrun success or failed(#17839)
ab76581 is described below

commit ab7658147445161fa3f7f2b139fbf9c223877f77
Author: bbenshalom <benny1...@gmail.com>
AuthorDate: Thu Sep 2 22:32:45 2021 +0300

    Add DAG run endpoint for marking a dagrun success or failed(#17839)

    Co-authored-by: bbenshalom <bbensha...@outbrain.com>
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    Co-authored-by: Ephraim Anierobi <splendidzig...@gmail.com>
---
 .../api_connexion/endpoints/dag_run_endpoint.py | 31 +++++++-
 airflow/api_connexion/openapi/v1.yaml | 39 ++++++++++
 airflow/api_connexion/schemas/dag_run_schema.py | 10 ++-
 .../endpoints/test_dag_run_endpoint.py | 88 ++++++++++++++++++++++
 4 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index feb6d24..e816aac 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -14,6 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Optional + from flask import current_app, g, request from marshmallow import ValidationError from sqlalchemy import or_ @@ -27,11 +29,12 @@ from airflow.api_connexion.schemas.dag_run_schema import ( dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema, + set_dagrun_state_form_schema, ) from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType @@ -271,3 +274,29 @@ def post_dag_run(dag_id, session): ) raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") + + +@security.requires_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + ] +) +@provide_session +def update_dag_run_state(dag_id: str, dag_run_id: str, session) -> dict: + """Set a state of a dag run.""" + dag_run: Optional[DagRun] = ( + session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + ) + if dag_run is None: + error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + raise NotFound(error_message) + try: + post_body = set_dagrun_state_form_schema.load(request.json) + except ValidationError as err: + raise BadRequest(detail=str(err)) + + state = post_body['state'] + dag_run.set_state(state=DagRunState(state)) + session.merge(dag_run) + return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 2ce6804..8dae5f2 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -674,6 +674,35 @@ paths: '404': $ref: '#/components/responses/NotFound' + patch: + summary: Modify a DAG run + description: Modify a DAG run + x-openapi-router-controller: 
airflow.api_connexion.endpoints.dag_run_endpoint + operationId: update_dag_run_state + tags: [ UpdateDagRunState ] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateDagRunState' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRun' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /eventLogs: get: summary: List log entries @@ -1996,6 +2025,16 @@ components: required: - dag_id + UpdateDagRunState: + type: object + properties: + state: + description: The state to set this DagRun + type: string + enum: + - success + - failed + DAGRunCollection: type: object description: Collection of DAG runs. diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index 62ac172..fa51f10 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -18,7 +18,7 @@ import json from typing import List, NamedTuple -from marshmallow import fields, post_dump, pre_load +from marshmallow import fields, post_dump, pre_load, validate from marshmallow.schema import Schema from marshmallow.validate import Range from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -29,6 +29,7 @@ from airflow.api_connexion.parameters import validate_istimezone from airflow.api_connexion.schemas.enum_schemas import DagStateField from airflow.models.dagrun import DagRun from airflow.utils import timezone +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -104,6 +105,12 @@ class DAGRunSchema(SQLAlchemySchema): return data +class SetDagRunStateFormSchema(Schema): + """Schema for handling the request of setting state of DAG run""" + + state = 
DagStateField(validate=validate.OneOf([DagRunState.SUCCESS.value, DagRunState.FAILED.value])) + + class DAGRunCollection(NamedTuple): """List of DAGRuns with metadata""" @@ -141,4 +148,5 @@ class DagRunsBatchFormSchema(Schema): dagrun_schema = DAGRunSchema() dagrun_collection_schema = DAGRunCollectionSchema() +set_dagrun_state_form_schema = SetDagRunStateFormSchema() dagruns_batch_form_schema = DagRunsBatchFormSchema() diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 71b5274..47d89cb 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -18,10 +18,12 @@ from datetime import timedelta from unittest import mock import pytest +from freezegun import freeze_time from parameterized import parameterized from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DAG, DagModel, DagRun +from airflow.operators.dummy import DummyOperator from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session @@ -1154,3 +1156,89 @@ class TestPostDagRun(TestDagRunEndpoint): environ_overrides={'REMOTE_USER': "test_view_dags"}, ) assert response.status_code == 403 + + +class TestPostSetDagRunState(TestDagRunEndpoint): + @pytest.mark.parametrize("state", ["failed", "success"]) + @freeze_time(TestDagRunEndpoint.default_time) + def test_should_respond_200(self, state, dag_maker): + dag_id = "TEST_DAG_ID" + dag_run_id = 'TEST_DAG_RUN_ID' + with dag_maker(dag_id): + DummyOperator(task_id='task_id') + dag_maker.create_dagrun(run_id=dag_run_id) + + request_json = {"state": state} + + response = self.client.patch( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", + json=request_json, + environ_overrides={"REMOTE_USER": "test"}, + ) + + assert response.status_code == 200 + assert response.json == { + 'conf': {}, + 'dag_id': dag_id, + 
'dag_run_id': dag_run_id, + 'end_date': self.default_time, + 'execution_date': dag_maker.start_date.isoformat(), + 'external_trigger': False, + 'logical_date': dag_maker.start_date.isoformat(), + 'start_date': dag_maker.start_date.isoformat(), + 'state': state, + } + + @pytest.mark.parametrize('invalid_state', ["running", "queued"]) + @freeze_time(TestDagRunEndpoint.default_time) + def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker): + dag_id = "TEST_DAG_ID" + dag_run_id = 'TEST_DAG_RUN_ID' + with dag_maker(dag_id): + DummyOperator(task_id='task_id') + dag_maker.create_dagrun(run_id=dag_run_id) + + request_json = {"state": invalid_state} + + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", + json=request_json, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + 'detail': f"'{invalid_state}' is not one of ['success', 'failed'] - 'state'", + 'status': 400, + 'title': 'Bad Request', + 'type': EXCEPTIONS_LINK_MAP[400], + } + + def test_should_raises_401_unauthenticated(self, session): + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", + json={ + "state": 'success', + }, + ) + + assert_401(response) + + def test_should_raise_403_forbidden(self): + response = self.client.patch( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", + json={ + "state": 'success', + }, + environ_overrides={'REMOTE_USER': "test_no_permissions"}, + ) + assert response.status_code == 403 + + def test_should_respond_404(self): + response = self.client.patch( + "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", + json={ + "state": 'success', + }, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 404