This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new 924d5573d58 Fix get_ti_count and get_task_states access in callbackrequests (#56822) (#56860)
924d5573d58 is described below
commit 924d5573d5842d90a1978ae89ff2a3cfd74723fb
Author: GPK <[email protected]>
AuthorDate: Mon Oct 20 07:23:57 2025 +0100
Fix get_ti_count and get_task_states access in callbackrequests (#56822) (#56860)
* Fix get_ti_count and get_task_states access in callbackrequests
* Add tests
---
.../src/airflow/dag_processing/processor.py | 30 +++++++
.../tests/unit/dag_processing/test_processor.py | 96 ++++++++++++++++++++++
2 files changed, 126 insertions(+)
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index c0a4675ba6c..527e5d8c0aa 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -43,6 +43,8 @@ from airflow.sdk.execution_time.comms import (
GetConnection,
GetPreviousDagRun,
GetPrevSuccessfulDagRun,
+ GetTaskStates,
+ GetTICount,
GetVariable,
GetXCom,
GetXComCount,
@@ -53,6 +55,7 @@ from airflow.sdk.execution_time.comms import (
PreviousDagRunResult,
PrevSuccessfulDagRunResult,
PutVariable,
+ TaskStatesResult,
VariableResult,
XComCountResponse,
XComResult,
@@ -112,6 +115,8 @@ ToManager = Annotated[
| GetConnection
| GetVariable
| PutVariable
+ | GetTaskStates
+ | GetTICount
| DeleteVariable
| GetPrevSuccessfulDagRun
| GetPreviousDagRun
@@ -127,6 +132,7 @@ ToDagProcessor = Annotated[
DagFileParseRequest
| ConnectionResult
| VariableResult
+ | TaskStatesResult
| PreviousDagRunResult
| PrevSuccessfulDagRunResult
| ErrorResponse
@@ -477,6 +483,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
+ TaskStatesResponse,
VariableResponse,
XComSequenceIndexResponse,
)
@@ -549,6 +556,29 @@ class DagFileProcessorProcess(WatchedSubprocess):
from airflow.sdk.log import mask_secret
mask_secret(msg.value, msg.name)
+ elif isinstance(msg, GetTICount):
+ resp = self.client.task_instances.get_count(
+ dag_id=msg.dag_id,
+ map_index=msg.map_index,
+ task_ids=msg.task_ids,
+ task_group_id=msg.task_group_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ states=msg.states,
+ )
+ elif isinstance(msg, GetTaskStates):
+ task_states_map = self.client.task_instances.get_task_states(
+ dag_id=msg.dag_id,
+ map_index=msg.map_index,
+ task_ids=msg.task_ids,
+ task_group_id=msg.task_group_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ )
+ if isinstance(task_states_map, TaskStatesResponse):
+ resp = TaskStatesResult.from_api_response(task_states_map)
+ else:
+ resp = task_states_map
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index f50883f9308..2269d4a2d61 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -64,8 +64,12 @@ from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import DagRunState
from airflow.sdk.execution_time import comms
from airflow.sdk.execution_time.comms import (
+ GetTaskStates,
+ GetTICount,
GetXCom,
GetXComSequenceSlice,
+ TaskStatesResult,
+ TICount,
XComResult,
XComSequenceSliceResult,
)
@@ -1021,6 +1025,98 @@ class TestExecuteDagCallbacks:
mock_supervisor_comms.send.assert_called_once_with(msg=expected_message)
+ @pytest.mark.parametrize(
+ "request_operation,operation_type,mock_response,operation_response",
+ [
+ (
+ lambda context: context["task_instance"].get_ti_count(dag_id="test_dag"),
+ GetTICount(dag_id="test_dag"),
+ TICount(count=2),
+ "Got response 2",
+ ),
+ (
+ lambda context: context["task_instance"].get_task_states(
+ dag_id="test_dag", task_ids=["test_task"]
+ ),
+ GetTaskStates(
+ dag_id="test_dag",
+ task_ids=["test_task"],
+ ),
+ TaskStatesResult(task_states={"test_run": {"task1": "running"}}),
+ "Got response {'test_run': {'task1': 'running'}}",
+ ),
+ ],
+ )
+ def test_dagfileprocessorprocess_request_handler_operations(
+ self,
+ spy_agency,
+ mock_supervisor_comms,
+ request_operation,
+ operation_type,
+ mock_response,
+ operation_response,
+ caplog,
+ ):
+ """Test that DagFileProcessorProcess Request Handler Operations"""
+
+ mock_supervisor_comms.send.return_value = mock_response
+
+ def callback_fn(context):
+ log = structlog.get_logger()
+ log.info("Callback started..")
+ log.info("Got response %s", request_operation(context))
+
+ with DAG(dag_id="test_dag", on_success_callback=callback_fn) as dag:
+ BaseOperator(task_id="test_task")
+
+ def fake_collect_dags(self, *args, **kwargs):
+ self.dags[dag.dag_id] = dag
+
+ spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+ dagbag = DagBag()
+ dagbag.collect_dags()
+
+ current_time = timezone.utcnow()
+ request = DagCallbackRequest(
+ filepath="test.py",
+ dag_id="test_dag",
+ run_id="test_run",
+ bundle_name="testing",
+ bundle_version=None,
+ context_from_server=DagRunContext(
+ dag_run=DRDataModel(
+ dag_id="test_dag",
+ run_id="test_run",
+ logical_date=current_time,
+ data_interval_start=current_time,
+ data_interval_end=current_time,
+ run_after=current_time,
+ start_date=current_time,
+ end_date=None,
+ run_type="manual",
+ state="success",
+ consumed_asset_events=[],
+ ),
+ last_ti=TIDataModel(
+ id=uuid.uuid4(),
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=-1,
+ try_number=1,
+ dag_version_id=uuid.uuid4(),
+ ),
+ ),
+ is_failure_callback=False,
+ msg="Test success message",
+ )
+
+ _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+ mock_supervisor_comms.send.assert_called_once_with(msg=operation_type)
+ assert operation_response in caplog.text
+
class TestExecuteTaskCallbacks:
"""Test the _execute_task_callbacks function"""