This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-1-test by this push:
     new 924d5573d58 Fix get_ti_count and get_task_states access in callbackrequests (#56822) (#56860)
924d5573d58 is described below

commit 924d5573d5842d90a1978ae89ff2a3cfd74723fb
Author: GPK <[email protected]>
AuthorDate: Mon Oct 20 07:23:57 2025 +0100

    Fix get_ti_count and get_task_states access in callbackrequests (#56822) (#56860)
    
    * Fix get_ti_count and get_task_states access in callbackrequests
    
    * Add tests
---
 .../src/airflow/dag_processing/processor.py        | 30 +++++++
 .../tests/unit/dag_processing/test_processor.py    | 96 ++++++++++++++++++++++
 2 files changed, 126 insertions(+)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index c0a4675ba6c..527e5d8c0aa 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -43,6 +43,8 @@ from airflow.sdk.execution_time.comms import (
     GetConnection,
     GetPreviousDagRun,
     GetPrevSuccessfulDagRun,
+    GetTaskStates,
+    GetTICount,
     GetVariable,
     GetXCom,
     GetXComCount,
@@ -53,6 +55,7 @@ from airflow.sdk.execution_time.comms import (
     PreviousDagRunResult,
     PrevSuccessfulDagRunResult,
     PutVariable,
+    TaskStatesResult,
     VariableResult,
     XComCountResponse,
     XComResult,
@@ -112,6 +115,8 @@ ToManager = Annotated[
     | GetConnection
     | GetVariable
     | PutVariable
+    | GetTaskStates
+    | GetTICount
     | DeleteVariable
     | GetPrevSuccessfulDagRun
     | GetPreviousDagRun
@@ -127,6 +132,7 @@ ToDagProcessor = Annotated[
     DagFileParseRequest
     | ConnectionResult
     | VariableResult
+    | TaskStatesResult
     | PreviousDagRunResult
     | PrevSuccessfulDagRunResult
     | ErrorResponse
@@ -477,6 +483,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
     def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
         from airflow.sdk.api.datamodels._generated import (
             ConnectionResponse,
+            TaskStatesResponse,
             VariableResponse,
             XComSequenceIndexResponse,
         )
@@ -549,6 +556,29 @@ class DagFileProcessorProcess(WatchedSubprocess):
             from airflow.sdk.log import mask_secret
 
             mask_secret(msg.value, msg.name)
+        elif isinstance(msg, GetTICount):
+            resp = self.client.task_instances.get_count(
+                dag_id=msg.dag_id,
+                map_index=msg.map_index,
+                task_ids=msg.task_ids,
+                task_group_id=msg.task_group_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+                states=msg.states,
+            )
+        elif isinstance(msg, GetTaskStates):
+            task_states_map = self.client.task_instances.get_task_states(
+                dag_id=msg.dag_id,
+                map_index=msg.map_index,
+                task_ids=msg.task_ids,
+                task_group_id=msg.task_group_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+            )
+            if isinstance(task_states_map, TaskStatesResponse):
+                resp = TaskStatesResult.from_api_response(task_states_map)
+            else:
+                resp = task_states_map
         else:
             log.error("Unhandled request", msg=msg)
             self.send_msg(
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index f50883f9308..2269d4a2d61 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -64,8 +64,12 @@ from airflow.sdk.api.client import Client
 from airflow.sdk.api.datamodels._generated import DagRunState
 from airflow.sdk.execution_time import comms
 from airflow.sdk.execution_time.comms import (
+    GetTaskStates,
+    GetTICount,
     GetXCom,
     GetXComSequenceSlice,
+    TaskStatesResult,
+    TICount,
     XComResult,
     XComSequenceSliceResult,
 )
@@ -1021,6 +1025,98 @@ class TestExecuteDagCallbacks:
 
         mock_supervisor_comms.send.assert_called_once_with(msg=expected_message)
 
+    @pytest.mark.parametrize(
+        "request_operation,operation_type,mock_response,operation_response",
+        [
+            (
+                lambda context: context["task_instance"].get_ti_count(dag_id="test_dag"),
+                GetTICount(dag_id="test_dag"),
+                TICount(count=2),
+                "Got response 2",
+            ),
+            (
+                lambda context: context["task_instance"].get_task_states(
+                    dag_id="test_dag", task_ids=["test_task"]
+                ),
+                GetTaskStates(
+                    dag_id="test_dag",
+                    task_ids=["test_task"],
+                ),
+                TaskStatesResult(task_states={"test_run": {"task1": "running"}}),
+                "Got response {'test_run': {'task1': 'running'}}",
+            ),
+        ],
+    )
+    def test_dagfileprocessorprocess_request_handler_operations(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        request_operation,
+        operation_type,
+        mock_response,
+        operation_response,
+        caplog,
+    ):
+        """Test that DagFileProcessorProcess Request Handler Operations"""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        def callback_fn(context):
+            log = structlog.get_logger()
+            log.info("Callback started..")
+            log.info("Got response %s", request_operation(context))
+
+        with DAG(dag_id="test_dag", on_success_callback=callback_fn) as dag:
+            BaseOperator(task_id="test_task")
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        current_time = timezone.utcnow()
+        request = DagCallbackRequest(
+            filepath="test.py",
+            dag_id="test_dag",
+            run_id="test_run",
+            bundle_name="testing",
+            bundle_version=None,
+            context_from_server=DagRunContext(
+                dag_run=DRDataModel(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    logical_date=current_time,
+                    data_interval_start=current_time,
+                    data_interval_end=current_time,
+                    run_after=current_time,
+                    start_date=current_time,
+                    end_date=None,
+                    run_type="manual",
+                    state="success",
+                    consumed_asset_events=[],
+                ),
+                last_ti=TIDataModel(
+                    id=uuid.uuid4(),
+                    dag_id="test_dag",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    dag_version_id=uuid.uuid4(),
+                ),
+            ),
+            is_failure_callback=False,
+            msg="Test success message",
+        )
+
+        _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+        mock_supervisor_comms.send.assert_called_once_with(msg=operation_type)
+        assert operation_response in caplog.text
+
 
 class TestExecuteTaskCallbacks:
     """Test the _execute_task_callbacks function"""

Reply via email to