This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c31cb5f19ec Get `LatestOnlyOperator` working with Task SDK (#48945)
c31cb5f19ec is described below

commit c31cb5f19ec77cbd1c40899c187ca83757c3fe20
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Apr 8 22:09:14 2025 +0530

    Get `LatestOnlyOperator` working with Task SDK (#48945)
    
    closes https://github.com/apache/airflow/issues/48897
---
 devel-common/src/tests_common/pytest_plugin.py     | 52 +++++++++++++++++++---
 .../providers/standard/operators/latest_only.py    | 36 ++++++++++++++-
 .../operators/test_latest_only_operator.py         | 22 +++++++++
 task-sdk/src/airflow/sdk/types.py                  |  4 +-
 4 files changed, 103 insertions(+), 11 deletions(-)

diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 2f0a416f47f..33cb0fe6ea5 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -46,10 +46,12 @@ if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun, DagRunType
     from airflow.models.taskinstance import TaskInstance
     from airflow.providers.standard.operators.empty import EmptyOperator
+    from airflow.sdk import Context
     from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState
     from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator
     from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
     from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+    from airflow.sdk.types import DagRunProtocol
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Self
     from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1996,6 +1998,15 @@ class RunTaskCallable(Protocol):
     @property
     def error(self) -> BaseException | None: ...
 
+    @property
+    def ti(self) -> RuntimeTaskInstance: ...
+
+    @property
+    def dagrun(self) -> DagRunProtocol: ...
+
+    @property
+    def context(self) -> Context: ...
+
     xcom: _XComHelperProtocol
 
     def __call__(
@@ -2039,6 +2050,7 @@ def create_runtime_ti(mocked_parse):
     from airflow.sdk.api.datamodels._generated import TaskInstance
     from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
+    from airflow.timetables.base import TimeRestriction
     from airflow.utils import timezone
 
     def _create_task_instance(
@@ -2058,6 +2070,7 @@ def create_runtime_ti(mocked_parse):
         max_tries: int | None = None,
     ) -> RuntimeTaskInstance:
         from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+        from airflow.utils.types import DagRunType
 
         if not ti_id:
             ti_id = uuid7()
@@ -2067,13 +2080,22 @@ def create_runtime_ti(mocked_parse):
             task.dag = dag  # type: ignore[assignment]
             task = dag.task_dict[task.task_id]
 
+        data_interval_start = None
+        data_interval_end = None
+
         if task.dag.timetable:
-            data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
-                run_after=logical_date  # type: ignore
-            )
-        else:
-            data_interval_start = None
-            data_interval_end = None
+            if run_type == DagRunType.MANUAL:
+                data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
+                    run_after=logical_date  # type: ignore
+                )
+            else:
+                drinfo = task.dag.timetable.next_dagrun_info(
+                    last_automated_data_interval=None,
+                    restriction=TimeRestriction(earliest=None, latest=None, catchup=False),
+                )
+                if drinfo:
+                    data_interval = drinfo.data_interval
+                    data_interval_start, data_interval_end = data_interval.start, data_interval.end
 
         dag_id = task.dag.dag_id
         task_retries = task.retries or 0
@@ -2252,6 +2274,9 @@ def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCal
             self._state = None
             self._msg = None
             self._error = None
+            self._ti = None
+            self._dagrun = None
+            self._context = None
 
         @property
         def state(self) -> IntermediateTIState | TerminalTIState:
@@ -2268,6 +2293,18 @@ def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCal
             """Get the error message if there was any."""
             return self._error
 
+        @property
+        def ti(self) -> RuntimeTaskInstance:
+            return self._ti
+
+        @property
+        def dagrun(self) -> DagRunProtocol:
+            return self._dagrun
+
+        @property
+        def context(self) -> Context:
+            return self._context
+
         def __call__(
             self,
             task: TaskSDKBaseOperator,
@@ -2315,6 +2352,9 @@ def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCal
             self._state = state
             self._msg = msg
             self._error = error
+            self._ti = ti
+            self._dagrun = context.get("dag_run")
+            self._context = context
 
             return state, msg, error
 
diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index 930d5947563..d7f4c636ebf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -25,10 +25,12 @@ from typing import TYPE_CHECKING
 import pendulum
 
 from airflow.providers.standard.operators.branch import BaseBranchOperator
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
     from airflow.models import DAG, DagRun
+    from airflow.timetables.base import DagRunInfo
 
     try:
         from airflow.sdk.definitions.context import Context
@@ -46,6 +48,10 @@ class LatestOnlyOperator(BaseBranchOperator):
 
     Note that downstream tasks are never skipped if the given DAG_Run is
     marked as externally triggered.
+
+    Note that when used with timetables that produce zero-length or point-in-time data intervals
+    (e.g., ``DeltaTriggerTimetable``), this operator assumes each run is the latest
+    and does not skip downstream tasks.
     """
 
     ui_color = "#e9ffdb"  # nyanza
@@ -58,8 +64,7 @@ class LatestOnlyOperator(BaseBranchOperator):
             self.log.info("Manually triggered DAG_Run: allowing execution to proceed.")
             return list(context["task"].get_direct_relative_ids(upstream=False))
 
-        dag: DAG = context["dag"]  # type: ignore[assignment]
-        next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
+        next_info = self._get_next_run_info(context, dag_run)
         now = pendulum.now("UTC")
 
         if next_info is None:
@@ -74,6 +79,15 @@ class LatestOnlyOperator(BaseBranchOperator):
             now,
         )
 
+        if left_window == right_window:
+            self.log.info(
+                "Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.",
+                left_window,
+                right_window,
+                self.dag.timetable.__class__,
+            )
+            return 
list(context["task"].get_direct_relative_ids(upstream=False))
+
         if not left_window < now <= right_window:
             self.log.info("Not latest execution, skipping downstream.")
             # we return an empty list, thus the parent BaseBranchOperator
@@ -82,3 +96,21 @@ class LatestOnlyOperator(BaseBranchOperator):
         else:
             self.log.info("Latest, allowing execution to proceed.")
             return 
list(context["task"].get_direct_relative_ids(upstream=False))
+
+    def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None:
+        dag: DAG = context["dag"]  # type: ignore[assignment]
+
+        if AIRFLOW_V_3_0_PLUS:
+            from airflow.timetables.base import DataInterval, TimeRestriction
+
+            time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
+            current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end)
+
+            next_info = dag.timetable.next_dagrun_info(
+                last_automated_data_interval=current_interval,
+                restriction=time_restriction,
+            )
+
+        else:
+            next_info = 
dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
+        return next_info
diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
index e5f0e842fe4..b976f41fa9a 100644
--- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
+++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
@@ -36,6 +36,8 @@ from tests_common.test_utils.db import clear_db_runs, clear_db_xcom
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import DAG
+    from airflow.timetables.trigger import DeltaTriggerTimetable
     from airflow.utils.types import DagRunTriggeredByType
 
 pytestmark = pytest.mark.db_test
@@ -310,3 +312,23 @@ class TestLatestOnlyOperator:
             timezone.datetime(2016, 1, 1, 12): "success",
             timezone.datetime(2016, 1, 2): "success",
         }
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Only applicable to Airflow 3.0+")
+    def test_zero_length_interval_treated_as_latest(self, run_task):
+        """Test that when the data_interval_start and data_interval_end are the same, the task is treated as latest."""
+        with DAG(
+            "test_dag",
+            schedule=DeltaTriggerTimetable(datetime.timedelta(hours=1)),
+            start_date=DEFAULT_DATE,
+            catchup=False,
+        ):
+            latest_task = LatestOnlyOperator(task_id="latest")
+            downstream_task = EmptyOperator(task_id="downstream")
+            latest_task >> downstream_task
+
+        run_task(latest_task, run_type=DagRunType.SCHEDULED)
+
+        assert run_task.dagrun.data_interval_start == run_task.dagrun.data_interval_end
+
+        # The task will raise DownstreamTasksSkipped exception if it is not the latest run
+        assert run_task.state == State.SUCCESS
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index b3ecc7be91a..a7749ccd230 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -21,8 +21,6 @@ import uuid
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Protocol, Union
 
-import attrs
-
 from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
@@ -109,7 +107,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
     def get_dagrun_state(dag_id: str, run_id: str) -> str: ...
 
 
-class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
+class OutletEventAccessorProtocol(Protocol):
     """Protocol for managing access to a specific outlet event accessor."""
 
     key: BaseAssetUniqueKey

Reply via email to