This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c31cb5f19ec Get `LatestOnlyOperator` working with Task SDK (#48945)
c31cb5f19ec is described below
commit c31cb5f19ec77cbd1c40899c187ca83757c3fe20
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Apr 8 22:09:14 2025 +0530
Get `LatestOnlyOperator` working with Task SDK (#48945)
closes https://github.com/apache/airflow/issues/48897
---
devel-common/src/tests_common/pytest_plugin.py | 52 +++++++++++++++++++---
.../providers/standard/operators/latest_only.py | 36 ++++++++++++++-
.../operators/test_latest_only_operator.py | 22 +++++++++
task-sdk/src/airflow/sdk/types.py | 4 +-
4 files changed, 103 insertions(+), 11 deletions(-)
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 2f0a416f47f..33cb0fe6ea5 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -46,10 +46,12 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
+ from airflow.sdk import Context
from airflow.sdk.api.datamodels._generated import IntermediateTIState,
TerminalTIState
from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+ from airflow.sdk.types import DagRunProtocol
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1996,6 +1998,15 @@ class RunTaskCallable(Protocol):
@property
def error(self) -> BaseException | None: ...
+ @property
+ def ti(self) -> RuntimeTaskInstance: ...
+
+ @property
+ def dagrun(self) -> DagRunProtocol: ...
+
+ @property
+ def context(self) -> Context: ...
+
xcom: _XComHelperProtocol
def __call__(
@@ -2039,6 +2050,7 @@ def create_runtime_ti(mocked_parse):
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
+ from airflow.timetables.base import TimeRestriction
from airflow.utils import timezone
def _create_task_instance(
@@ -2058,6 +2070,7 @@ def create_runtime_ti(mocked_parse):
max_tries: int | None = None,
) -> RuntimeTaskInstance:
from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+ from airflow.utils.types import DagRunType
if not ti_id:
ti_id = uuid7()
@@ -2067,13 +2080,22 @@ def create_runtime_ti(mocked_parse):
task.dag = dag # type: ignore[assignment]
task = dag.task_dict[task.task_id]
+ data_interval_start = None
+ data_interval_end = None
+
if task.dag.timetable:
-        data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
-            run_after=logical_date  # type: ignore
-        )
- else:
- data_interval_start = None
- data_interval_end = None
+ if run_type == DagRunType.MANUAL:
+            data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
+                run_after=logical_date  # type: ignore
+            )
+ else:
+ drinfo = task.dag.timetable.next_dagrun_info(
+ last_automated_data_interval=None,
+                restriction=TimeRestriction(earliest=None, latest=None, catchup=False),
+            )
+ if drinfo:
+ data_interval = drinfo.data_interval
+                data_interval_start, data_interval_end = data_interval.start, data_interval.end
dag_id = task.dag.dag_id
task_retries = task.retries or 0
@@ -2252,6 +2274,9 @@ def run_task(create_runtime_ti, mock_supervisor_comms,
spy_agency) -> RunTaskCal
self._state = None
self._msg = None
self._error = None
+ self._ti = None
+ self._dagrun = None
+ self._context = None
@property
def state(self) -> IntermediateTIState | TerminalTIState:
@@ -2268,6 +2293,18 @@ def run_task(create_runtime_ti, mock_supervisor_comms,
spy_agency) -> RunTaskCal
"""Get the error message if there was any."""
return self._error
+ @property
+ def ti(self) -> RuntimeTaskInstance:
+ return self._ti
+
+ @property
+ def dagrun(self) -> DagRunProtocol:
+ return self._dagrun
+
+ @property
+ def context(self) -> Context:
+ return self._context
+
def __call__(
self,
task: TaskSDKBaseOperator,
@@ -2315,6 +2352,9 @@ def run_task(create_runtime_ti, mock_supervisor_comms,
spy_agency) -> RunTaskCal
self._state = state
self._msg = msg
self._error = error
+ self._ti = ti
+ self._dagrun = context.get("dag_run")
+ self._context = context
return state, msg, error
diff --git
a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index 930d5947563..d7f4c636ebf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -25,10 +25,12 @@ from typing import TYPE_CHECKING
import pendulum
from airflow.providers.standard.operators.branch import BaseBranchOperator
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
from airflow.models import DAG, DagRun
+ from airflow.timetables.base import DagRunInfo
try:
from airflow.sdk.definitions.context import Context
@@ -46,6 +48,10 @@ class LatestOnlyOperator(BaseBranchOperator):
Note that downstream tasks are never skipped if the given DAG_Run is
marked as externally triggered.
+
+    Note that when used with timetables that produce zero-length or point-in-time data intervals
+    (e.g., ``DeltaTriggerTimetable``), this operator assumes each run is the latest
+    and does not skip downstream tasks.
"""
ui_color = "#e9ffdb" # nyanza
@@ -58,8 +64,7 @@ class LatestOnlyOperator(BaseBranchOperator):
self.log.info("Manually triggered DAG_Run: allowing execution to
proceed.")
return
list(context["task"].get_direct_relative_ids(upstream=False))
- dag: DAG = context["dag"] # type: ignore[assignment]
-        next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
+ next_info = self._get_next_run_info(context, dag_run)
now = pendulum.now("UTC")
if next_info is None:
@@ -74,6 +79,15 @@ class LatestOnlyOperator(BaseBranchOperator):
now,
)
+ if left_window == right_window:
+            self.log.info(
+                "Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.",
+ left_window,
+ right_window,
+ self.dag.timetable.__class__,
+ )
+ return
list(context["task"].get_direct_relative_ids(upstream=False))
+
if not left_window < now <= right_window:
self.log.info("Not latest execution, skipping downstream.")
# we return an empty list, thus the parent BaseBranchOperator
@@ -82,3 +96,21 @@ class LatestOnlyOperator(BaseBranchOperator):
else:
self.log.info("Latest, allowing execution to proceed.")
return
list(context["task"].get_direct_relative_ids(upstream=False))
+
+    def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None:
+ dag: DAG = context["dag"] # type: ignore[assignment]
+
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.timetables.base import DataInterval, TimeRestriction
+
+            time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
+            current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end)
+
+ next_info = dag.timetable.next_dagrun_info(
+ last_automated_data_interval=current_interval,
+ restriction=time_restriction,
+ )
+
+ else:
+            next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
+ return next_info
diff --git
a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
index e5f0e842fe4..b976f41fa9a 100644
---
a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
+++
b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
@@ -36,6 +36,8 @@ from tests_common.test_utils.db import clear_db_runs,
clear_db_xcom
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import DAG
+ from airflow.timetables.trigger import DeltaTriggerTimetable
from airflow.utils.types import DagRunTriggeredByType
pytestmark = pytest.mark.db_test
@@ -310,3 +312,23 @@ class TestLatestOnlyOperator:
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Only applicable to Airflow 3.0+")
+ def test_zero_length_interval_treated_as_latest(self, run_task):
+        """Test that when the data_interval_start and data_interval_end are the same, the task is treated as latest."""
+ with DAG(
+ "test_dag",
+ schedule=DeltaTriggerTimetable(datetime.timedelta(hours=1)),
+ start_date=DEFAULT_DATE,
+ catchup=False,
+ ):
+ latest_task = LatestOnlyOperator(task_id="latest")
+ downstream_task = EmptyOperator(task_id="downstream")
+ latest_task >> downstream_task
+
+ run_task(latest_task, run_type=DagRunType.SCHEDULED)
+
+        assert run_task.dagrun.data_interval_start == run_task.dagrun.data_interval_end
+
+        # The task will raise DownstreamTasksSkipped exception if it is not the latest run
+ assert run_task.state == State.SUCCESS
diff --git a/task-sdk/src/airflow/sdk/types.py
b/task-sdk/src/airflow/sdk/types.py
index b3ecc7be91a..a7749ccd230 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -21,8 +21,6 @@ import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Protocol, Union
-import attrs
-
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
@@ -109,7 +107,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
def get_dagrun_state(dag_id: str, run_id: str) -> str: ...
-class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
+class OutletEventAccessorProtocol(Protocol):
"""Protocol for managing access to a specific outlet event accessor."""
key: BaseAssetUniqueKey