This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 2b34a6544401673c0200f3fe57573e437d5f56d4 Author: Amogh Desai <[email protected]> AuthorDate: Mon Aug 11 13:17:55 2025 +0530 Restore ``get_previous_dagrun`` functionality for task context (#53655) Co-authored-by: Kaxil Naik <[email protected]> (cherry picked from commit 35d23c3222472ef354acc34fd3a76237bced727b) --- .../execution_api/datamodels/taskinstance.py | 8 +- .../api_fastapi/execution_api/routes/dag_runs.py | 45 ++++++- .../api_fastapi/execution_api/versions/__init__.py | 4 + .../execution_api/versions/v2025_08_10.py | 39 ++++++ .../execution_api/versions/head/test_dag_runs.py | 139 +++++++++++++++++++++ .../versions/head/test_task_instances.py | 4 +- .../tests/unit/dag_processing/test_processor.py | 2 + devel-common/src/tests_common/pytest_plugin.py | 27 ++-- .../unit/openlineage/plugins/test_listener.py | 29 +++-- task-sdk/src/airflow/sdk/api/client.py | 18 +++ .../src/airflow/sdk/api/datamodels/_generated.py | 3 +- task-sdk/src/airflow/sdk/execution_time/comms.py | 17 +++ .../src/airflow/sdk/execution_time/supervisor.py | 7 ++ .../src/airflow/sdk/execution_time/task_runner.py | 27 ++++ task-sdk/src/airflow/sdk/types.py | 2 + task-sdk/tests/conftest.py | 2 + task-sdk/tests/task_sdk/api/test_client.py | 82 ++++++++++++ .../tests/task_sdk/execution_time/test_comms.py | 1 + .../task_sdk/execution_time/test_supervisor.py | 70 +++++++++++ .../task_sdk/execution_time/test_task_runner.py | 63 ++++++++++ 20 files changed, 555 insertions(+), 34 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 2d7968bbc62..f1e42ef1e0b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -35,7 +35,12 @@ from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel from 
airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse -from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState +from airflow.utils.state import ( + DagRunState, + IntermediateTIState, + TaskInstanceState as TIState, + TerminalTIState, +) from airflow.utils.types import DagRunType AwareDatetimeAdapter = TypeAdapter(AwareDatetime) @@ -292,6 +297,7 @@ class DagRun(StrictBaseModel): end_date: UtcDateTime | None clear_number: int = 0 run_type: DagRunType + state: DagRunState conf: Annotated[dict[str, Any], Field(default_factory=dict)] consumed_asset_events: list[AssetEventDagRunReference] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index db27f7bd93e..d0dcef4faeb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -27,10 +27,12 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload +from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag -from airflow.models.dagrun import DagRun +from airflow.models.dagrun import DagRun as DagRunModel +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType router = APIRouter() @@ -140,7 +142,9 @@ def get_dagrun_state( session: SessionDep, ) -> DagRunStateResponse: """Get a DAG Run 
State.""" - dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)) + dag_run = session.scalar( + select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id) + ) if dag_run is None: raise HTTPException( status.HTTP_404_NOT_FOUND, @@ -162,16 +166,45 @@ def get_dr_count( states: Annotated[list[str] | None, Query()] = None, ) -> int: """Get the count of DAG runs matching the given criteria.""" - query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id) + query = select(func.count()).select_from(DagRunModel).where(DagRunModel.dag_id == dag_id) if logical_dates: - query = query.where(DagRun.logical_date.in_(logical_dates)) + query = query.where(DagRunModel.logical_date.in_(logical_dates)) if run_ids: - query = query.where(DagRun.run_id.in_(run_ids)) + query = query.where(DagRunModel.run_id.in_(run_ids)) if states: - query = query.where(DagRun.state.in_(states)) + query = query.where(DagRunModel.state.in_(states)) count = session.scalar(query) return count or 0 + + +@router.get("/{dag_id}/previous", status_code=status.HTTP_200_OK) +def get_previous_dagrun( + dag_id: str, + logical_date: UtcDateTime, + session: SessionDep, + state: Annotated[DagRunState | None, Query()] = None, +) -> DagRun | None: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + query = ( + select(DagRunModel) + .where( + DagRunModel.dag_id == dag_id, + DagRunModel.logical_date < logical_date, + ) + .order_by(DagRunModel.logical_date.desc()) + .limit(1) + ) + + if state: + query = query.where(DagRunModel.state == state) + + dag_run = session.scalar(query) + + if not dag_run: + return None + + return DagRun.model_validate(dag_run) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 5462f102974..fee781a8ecf 100644 --- 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -21,9 +21,13 @@ from cadwyn import HeadVersion, Version, VersionBundle from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes +from airflow.api_fastapi.execution_api.versions.v2025_08_10 import ( + AddDagRunStateFieldAndPreviousEndpoint, +) bundle = VersionBundle( HeadVersion(), + Version("2025-08-10", AddDagRunStateFieldAndPreviousEndpoint), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py new file mode 100644 index 00000000000..188eaec2d79 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TIRunContext + + +class AddDagRunStateFieldAndPreviousEndpoint(VersionChange): + """Add the `state` field to DagRun model and `/dag-runs/{dag_id}/previous` endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(DagRun).field("state").didnt_exist, + endpoint("/dag-runs/{dag_id}/previous", ["GET"]).didnt_exist, + ) + + @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type] + def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc] + """Remove the `state` field from the dag_run object when converting to the previous version.""" + if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): + response.body["dag_run"].pop("state", None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index f9f8d489d3d..ac414f53ee8 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -314,3 +314,142 @@ class TestGetDagRunCount: ) assert response.status_code == 200 assert response.json() == 2 + + +class TestGetPreviousDagRun: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_previous_dag_run_basic(self, client, session, dag_maker): + """Test getting the previous DAG run without state filtering.""" + dag_id = "test_get_previous_basic" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create multiple DAG runs + dag_maker.create_dagrun( + run_id="run1", 
logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run3", logical_date=timezone.datetime(2025, 1, 10), state=DagRunState.SUCCESS + ) + session.commit() + + # Query for previous DAG run before 2025-01-10 + response = client.get( + f"/execution/dag-runs/{dag_id}/previous", + params={ + "logical_date": timezone.datetime(2025, 1, 10).isoformat(), + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["dag_id"] == dag_id + assert result["run_id"] == "run2" # Most recent before 2025-01-10 + assert result["state"] == "failed" + + def test_get_previous_dag_run_with_state_filter(self, client, session, dag_maker): + """Test getting the previous DAG run with state filtering.""" + dag_id = "test_get_previous_with_state" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create multiple DAG runs with different states + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run3", logical_date=timezone.datetime(2025, 1, 8), state=DagRunState.SUCCESS + ) + session.commit() + + # Query for previous successful DAG run before 2025-01-10 + response = client.get( + f"/execution/dag-runs/{dag_id}/previous", + params={"logical_date": timezone.datetime(2025, 1, 10).isoformat(), "state": "success"}, + ) + + assert response.status_code == 200 + result = response.json() + assert result["dag_id"] == dag_id + assert result["run_id"] == "run3" # Most recent successful run before 2025-01-10 + assert result["state"] == "success" + + def test_get_previous_dag_run_no_previous_found(self, client, session, dag_maker): + 
"""Test getting previous DAG run when none exists returns null.""" + dag_id = "test_get_previous_none" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create only one DAG run - no previous should exist + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + + response = client.get(f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-01T00:00:00Z") + + assert response.status_code == 200 + assert response.json() is None # Should return null + + def test_get_previous_dag_run_no_matching_state(self, client, session, dag_maker): + """Test getting previous DAG run with state filter that matches nothing returns null.""" + dag_id = "test_get_previous_no_match" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create DAG runs with different states + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 2), state=DagRunState.FAILED + ) + + # Look for previous success but only failed runs exist + response = client.get( + f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-03T00:00:00Z&state=success" + ) + + assert response.status_code == 200 + assert response.json() is None + + def test_get_previous_dag_run_dag_not_found(self, client, session): + """Test getting previous DAG run for non-existent DAG returns 404.""" + response = client.get( + "/execution/dag-runs/nonexistent_dag/previous?logical_date=2025-01-01T00:00:00Z" + ) + + assert response.status_code == 200 + assert response.json() is None + + def test_get_previous_dag_run_invalid_state_parameter(self, client, session, dag_maker): + """Test that invalid state parameter returns 422 validation error.""" + dag_id = "test_get_previous_invalid_state" + + with 
dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + session.commit() + + response = client.get( + f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-02T00:00:00Z&state=invalid_state" + ) + + assert response.status_code == 422 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 3c107f61863..4f1e28a11f6 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -35,7 +35,7 @@ from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import Asset, TaskGroup, task, task_group from airflow.utils import timezone -from airflow.utils.state import State, TaskInstanceState, TerminalTIState +from airflow.utils.state import DagRunState, State, TaskInstanceState, TerminalTIState from tests_common.test_utils.db import ( clear_db_assets, @@ -155,6 +155,7 @@ class TestTIRunState: ti = create_task_instance( task_id="test_ti_run_state_to_running", state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, session=session, start_date=instant, dag_id=str(uuid4()), @@ -184,6 +185,7 @@ class TestTIRunState: "data_interval_end": instant_str, "run_after": instant_str, "start_date": instant_str, + "state": "running", "end_date": None, "run_type": "manual", "conf": {}, diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 1a3925c077d..dc85bfb1b21 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ 
b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -57,6 +57,7 @@ from airflow.models import DagBag, DagRun from airflow.models.baseoperator import BaseOperator from airflow.sdk import DAG from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import DagRunState from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session @@ -957,6 +958,7 @@ class TestExecuteTaskCallbacks: logical_date=timezone.utcnow(), start_date=timezone.utcnow(), run_type="manual", + state=DagRunState.RUNNING, ) dag_run.run_after = timezone.utcnow() diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 2288fefc268..e161771b84a 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2134,7 +2134,7 @@ def create_runtime_ti(mocked_parse): should_retry: bool | None = None, max_tries: int | None = None, ) -> RuntimeTaskInstance: - from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext from airflow.utils.types import DagRunType if not ti_id: @@ -2167,17 +2167,20 @@ def create_runtime_ti(mocked_parse): run_after = data_interval_end or logical_date or timezone.utcnow() ti_context = TIRunContext( - dag_run=DagRun( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, # type: ignore - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, # type: ignore - run_type=run_type, # type: ignore - run_after=run_after, # type: ignore - conf=conf, - consumed_asset_events=[], + dag_run=DagRun.model_validate( + { + "dag_id": dag_id, + "run_id": run_id, + "logical_date": logical_date, # type: ignore + "data_interval_start": data_interval_start, + "data_interval_end": data_interval_end, + "start_date": start_date, # type: ignore + 
"run_type": run_type, # type: ignore + "run_after": run_after, # type: ignore + "conf": conf, + "consumed_asset_events": [], + **({"state": DagRunState.RUNNING} if "state" in DagRun.model_fields else {}), + } ), task_reschedule_count=task_reschedule_count, max_tries=task_retries if max_tries is None else max_tries, diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index eff8dc9efcb..555725f0246 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -853,6 +853,7 @@ class TestOpenLineageListenerAirflow3: task_instance.dag_run.clear_number = 0 task_instance.dag_run.logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) task_instance.dag_run.run_after = timezone.datetime(2020, 1, 1, 1, 1, 1) + task_instance.dag_run.state = DagRunState.RUNNING task_instance.task = None task_instance.dag = None task_instance.task_id = "task_id" @@ -862,6 +863,7 @@ class TestOpenLineageListenerAirflow3: # RuntimeTaskInstance is used when on worker from airflow.sdk.api.datamodels._generated import ( DagRun as SdkDagRun, + DagRunState as SdkDagRunState, DagRunType, TaskInstance as SdkTaskInstance, TIRunContext, @@ -887,19 +889,20 @@ class TestOpenLineageListenerAirflow3: **sdk_task_instance.model_dump(exclude_unset=True), task=task, _ti_context_from_server=TIRunContext( - dag_run=SdkDagRun( - dag_id="dag_id", - run_id="dag_run_run_id", - logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), - data_interval_start=None, - data_interval_end=None, - start_date=timezone.datetime(2023, 1, 1, 13, 1, 1), - end_date=timezone.datetime(2023, 1, 3, 13, 1, 1), - clear_number=0, - run_type=DagRunType.MANUAL, - run_after=timezone.datetime(2023, 1, 3, 13, 1, 1), - conf=None, - consumed_asset_events=[], + dag_run=SdkDagRun.model_validate( + { + "dag_id": "dag_id_from_dagrun_not_ti", + "run_id": 
"dag_run_run_id_from_dagrun_not_ti", + "logical_date": timezone.datetime(2020, 1, 1, 1, 1, 1), + "start_date": timezone.datetime(2023, 1, 1, 13, 1, 1), + "end_date": timezone.datetime(2023, 1, 3, 13, 1, 1), + "run_type": DagRunType.MANUAL, + "run_after": timezone.datetime(2023, 1, 3, 13, 1, 1), + "consumed_asset_events": [], + **( + {"state": SdkDagRunState.RUNNING} if "state" in SdkDagRun.model_fields else {} + ), + } ), task_reschedule_count=0, max_tries=1, diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 3413e98189c..f36c46f45de 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -69,6 +69,7 @@ from airflow.sdk.execution_time.comms import ( DRCount, ErrorResponse, OKResponse, + PreviousDagRunResult, SkipDownstreamTasks, TaskRescheduleStartDate, TICount, @@ -620,6 +621,23 @@ class DagRunOperations: resp = self.client.get("dag-runs/count", params=params) return DRCount(count=resp.json()) + def get_previous( + self, + dag_id: str, + logical_date: datetime, + state: str | None = None, + ) -> PreviousDagRunResult: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + params = { + "logical_date": logical_date.isoformat(), + } + + if state: + params["state"] = state + + resp = self.client.get(f"dag-runs/{dag_id}/previous", params=params) + return PreviousDagRunResult(dag_run=resp.json()) + class BearerAuth(httpx.Auth): def __init__(self, token: str): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 1dabd8c9022..9cdc08379ac 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from uuid import UUID from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2025-05-20" +API_VERSION: Final[str] = "2025-08-10" 
class AssetAliasReferenceAssetEventDagRun(BaseModel): @@ -494,6 +494,7 @@ class DagRun(BaseModel): end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None clear_number: Annotated[int | None, Field(title="Clear Number")] = 0 run_type: DagRunType + state: DagRunState conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None consumed_asset_events: Annotated[list[AssetEventDagRunReference], Field(title="Consumed Asset Events")] diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 2c6dfea4e60..1f9eb3141a3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -70,6 +70,7 @@ from airflow.sdk.api.datamodels._generated import ( AssetResponse, BundleInfo, ConnectionResponse, + DagRun, DagRunStateResponse, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, @@ -492,6 +493,13 @@ class DagRunStateResult(DagRunStateResponse): return cls(**dr_state_response.model_dump(exclude_defaults=True), type="DagRunStateResult") +class PreviousDagRunResult(BaseModel): + """Response containing previous DAG run information.""" + + dag_run: DagRun | None = None + type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult" + + class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse): type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult" @@ -579,6 +587,7 @@ ToTask = Annotated[ XComSequenceSliceResult, InactiveAssetsResult, OKResponse, + PreviousDagRunResult, ], Field(discriminator="type"), ] @@ -775,6 +784,13 @@ class GetDagRunState(BaseModel): type: Literal["GetDagRunState"] = "GetDagRunState" +class GetPreviousDagRun(BaseModel): + dag_id: str + logical_date: AwareDatetime + state: str | None = None + type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun" + + class GetAssetByName(BaseModel): name: str type: Literal["GetAssetByName"] = "GetAssetByName" @@ -853,6 +869,7 @@ ToSupervisor = Annotated[ 
GetDagRunState, GetDRCount, GetPrevSuccessfulDagRun, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTICount, GetTaskStates, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 57224f66a58..ccac2d8d2ad 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -83,6 +83,7 @@ from airflow.sdk.execution_time.comms import ( GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTaskStates, @@ -1227,6 +1228,12 @@ class ActivitySubprocess(WatchedSubprocess): run_ids=msg.run_ids, states=msg.states, ) + elif isinstance(msg, GetPreviousDagRun): + resp = self.client.dag_runs.get_previous( + dag_id=msg.dag_id, + logical_date=msg.logical_date, + state=msg.state, + ) elif isinstance(msg, DeleteVariable): resp = self.client.variables.delete(msg.key) elif isinstance(msg, ValidateInletsAndOutlets): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 62aa7d37b7f..a6bf2d11e1b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,6 +44,7 @@ from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, + DagRun, TaskInstance, TaskInstanceState, TIRunContext, @@ -65,10 +66,12 @@ from airflow.sdk.execution_time.comms import ( ErrorResponse, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTaskStates, GetTICount, InactiveAssetsResult, + PreviousDagRunResult, RescheduleTask, ResendLoggingFD, RetryTask, @@ -438,6 +441,30 @@ class RuntimeTaskInstance(TaskInstance): return response.start_date + def get_previous_dagrun(self, state: str | None = None) -> 
DagRun | None: + """Return the previous DAG run before the given logical date, optionally filtered by state.""" + context = self.get_template_context() + dag_run = context.get("dag_run") + + log = structlog.get_logger(logger_name="task") + + log.debug("Getting previous DAG run", dag_run=dag_run) + + if dag_run is None: + return None + + if dag_run.logical_date is None: + return None + + response = SUPERVISOR_COMMS.send( + msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) + ) + + if TYPE_CHECKING: + assert isinstance(response, PreviousDagRunResult) + + return response.dag_run + @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 8bd0ea0db8d..abe1f1f84a8 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -85,6 +85,8 @@ class RuntimeTaskInstanceProtocol(Protocol): def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ... + def get_previous_dagrun(self, state: str | None = None) -> DagRunProtocol | None: ... 
+ @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index d1866f52391..c3195a7b9bb 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -198,6 +198,7 @@ class MakeTIContextDictCallable(Protocol): def make_ti_context() -> MakeTIContextCallable: """Factory for creating TIRunContext objects.""" from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.utils.state import DagRunState def _make_context( dag_id: str = "test_dag", @@ -226,6 +227,7 @@ def make_ti_context() -> MakeTIContextCallable: start_date=start_date, # type: ignore run_type=run_type, # type: ignore run_after=run_after, # type: ignore + state=DagRunState.RUNNING, conf=conf, # type: ignore consumed_asset_events=list(consumed_asset_events), ), diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index caa515de09a..4ffa8479835 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,6 +19,7 @@ from __future__ import annotations import json import pickle +from datetime import datetime from unittest import mock import httpx @@ -41,6 +42,7 @@ from airflow.sdk.execution_time.comms import ( DeferTask, ErrorResponse, OKResponse, + PreviousDagRunResult, RescheduleTask, TaskRescheduleStartDate, ) @@ -1139,6 +1141,86 @@ class TestDagRunOperations: result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1", "run2"]) assert result.count == 2 + def test_get_previous_basic(self): + """Test basic get_previous functionality with dag_id and logical_date.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return complete DagRun data + return httpx.Response( + status_code=200, + 
json={ + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": "2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + assert result.dag_run.run_id == "prev_run" + assert result.dag_run.state == "success" + + def test_get_previous_with_state_filter(self): + """Test get_previous functionality with state filtering.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + assert request.url.params["state"] == "success" + # Return complete DagRun data + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "run_id": "prev_success_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": "2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date, state="success") + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + assert result.dag_run.run_id == "prev_success_run" + assert result.dag_run.state == "success" + + def test_get_previous_not_found(self): + """Test get_previous when no previous DAG run exists returns None.""" + logical_date = 
datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return None (null) when no previous DAG run found + return httpx.Response(status_code=200, content="null") + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run is None + class TestTaskRescheduleOperations: def test_get_start_date(self): diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 48c7ad74a15..d0be736b4ce 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -56,6 +56,7 @@ class TestCommsDecoder: "run_after": "2024-12-01T01:00:00Z", "end_date": None, "run_type": "manual", + "state": "success", "conf": None, "consumed_asset_events": [], }, diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 600ca4ae549..511cccc2a33 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -51,7 +51,9 @@ from airflow.sdk.api.datamodels._generated import ( AssetEventResponse, AssetProfile, AssetResponse, + DagRun, DagRunState, + DagRunType, TaskInstance, TaskInstanceState, ) @@ -75,6 +77,7 @@ from airflow.sdk.execution_time.comms import ( GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTaskStates, @@ -85,6 +88,7 @@ from airflow.sdk.execution_time.comms import ( GetXComSequenceSlice, InactiveAssetsResult, OKResponse, + 
PreviousDagRunResult, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -1822,6 +1826,72 @@ class TestHandleRequest: None, id="get_dr_count", ), + pytest.param( + GetPreviousDagRun( + dag_id="test_dag", + logical_date=timezone.parse("2024-01-15T12:00:00Z"), + ), + { + "dag_run": { + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": timezone.parse("2024-01-14T12:00:00Z"), + "run_type": "scheduled", + "start_date": timezone.parse("2024-01-15T12:00:00Z"), + "run_after": timezone.parse("2024-01-15T12:00:00Z"), + "consumed_asset_events": [], + "state": "success", + "data_interval_start": None, + "data_interval_end": None, + "end_date": None, + "clear_number": 0, + "conf": None, + }, + "type": "PreviousDagRunResult", + }, + "dag_runs.get_previous", + (), + { + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": None, + }, + PreviousDagRunResult( + dag_run=DagRun( + dag_id="test_dag", + run_id="prev_run", + logical_date=timezone.parse("2024-01-14T12:00:00Z"), + run_type=DagRunType.SCHEDULED, + start_date=timezone.parse("2024-01-15T12:00:00Z"), + run_after=timezone.parse("2024-01-15T12:00:00Z"), + consumed_asset_events=[], + state=DagRunState.SUCCESS, + ) + ), + None, + id="get_previous_dagrun", + ), + pytest.param( + GetPreviousDagRun( + dag_id="test_dag", + logical_date=timezone.parse("2024-01-15T12:00:00Z"), + state="success", + ), + { + "dag_run": None, + "type": "PreviousDagRunResult", + }, + "dag_runs.get_previous", + (), + { + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": "success", + }, + PreviousDagRunResult(dag_run=None), + None, + id="get_previous_dagrun_with_state", + ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), { diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 4b445404d8a..b65c093856e 100644 --- 
a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -50,6 +50,7 @@ from airflow.sdk import DAG, BaseOperator, Connection, dag as dag_decorator, get from airflow.sdk.api.datamodels._generated import ( AssetProfile, AssetResponse, + DagRun, DagRunState, TaskInstance, TaskInstanceState, @@ -71,12 +72,14 @@ from airflow.sdk.execution_time.comms import ( GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetTaskStates, GetTICount, GetVariable, GetXCom, GetXComSequenceSlice, OKResponse, + PreviousDagRunResult, PrevSuccessfulDagRunResult, SetRenderedFields, SetXCom, @@ -1802,6 +1805,66 @@ class TestRuntimeTaskInstance: ) assert states == {"run1": {"task1": "running"}} + def test_get_previous_dagrun_basic(self, create_runtime_ti, mock_supervisor_comms): + """Test that get_previous_dagrun sends the correct request without state filter.""" + + task = BaseOperator(task_id="hello") + dag_id = "test_dag" + runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2)) + + dag_run_data = DagRun( + dag_id=dag_id, + run_id="prev_run", + logical_date=timezone.datetime(2025, 1, 1), + start_date=timezone.datetime(2025, 1, 1), + run_after=timezone.datetime(2025, 1, 1), + run_type="scheduled", + state="success", + consumed_asset_events=[], + ) + + mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data) + + dr = runtime_ti.get_previous_dagrun() + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetPreviousDagRun(dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state=None), + ) + assert dr.dag_id == "test_dag" + assert dr.run_id == "prev_run" + assert dr.state == "success" + + def test_get_previous_dagrun_with_state(self, create_runtime_ti, mock_supervisor_comms): + """Test that get_previous_dagrun sends the correct request with state filter.""" + + task = BaseOperator(task_id="hello") + dag_id = "test_dag" + 
runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2)) + + dag_run_data = DagRun( + dag_id=dag_id, + run_id="prev_success_run", + logical_date=timezone.datetime(2025, 1, 1), + start_date=timezone.datetime(2025, 1, 1), + run_after=timezone.datetime(2025, 1, 1), + run_type="scheduled", + state="success", + consumed_asset_events=[], + ) + + mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data) + + dr = runtime_ti.get_previous_dagrun(state="success") + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetPreviousDagRun( + dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state="success" + ), + ) + assert dr.dag_id == "test_dag" + assert dr.run_id == "prev_success_run" + assert dr.state == "success" + class TestXComAfterTaskExecution: @pytest.mark.parametrize(
