This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 075937c4a9b Refactor and extract shared request handler logic from supervisor _handle_request methods (#65624)
075937c4a9b is described below

commit 075937c4a9bbc1f3f459b4e879438e835ce55e92
Author: Paul <[email protected]>
AuthorDate: Thu May 21 07:13:52 2026 -0700

    Refactor and extract shared request handler logic from supervisor _handle_request methods (#65624)
---
 .../src/airflow/dag_processing/processor.py        |  98 ++++-------
 .../src/airflow/jobs/triggerer_job_runner.py       |  81 +++------
 .../airflow/sdk/execution_time/request_handlers.py | 186 +++++++++++++++++++++
 .../src/airflow/sdk/execution_time/supervisor.py   | 124 ++++----------
 4 files changed, 266 insertions(+), 223 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index d7c0a9d2b59..303b62d1411 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -69,6 +69,21 @@ from airflow.sdk.execution_time.comms import (
     XComSequenceIndexResult,
     XComSequenceSliceResult,
 )
+from airflow.sdk.execution_time.request_handlers import (
+    handle_delete_variable,
+    handle_get_prev_successful_dag_run,
+    handle_get_previous_dag_run,
+    handle_get_previous_ti,
+    handle_get_task_states,
+    handle_get_ti_count,
+    handle_get_variable_keys,
+    handle_get_xcom,
+    handle_get_xcom_count,
+    handle_get_xcom_sequence_item,
+    handle_get_xcom_sequence_slice,
+    handle_mask_secret,
+    handle_put_variable,
+)
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_error_email_notification
 from airflow.sdk.log import mask_secret
@@ -601,13 +616,11 @@ class DagFileProcessorProcess(WatchedSubprocess):
     def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
         from airflow.sdk.api.datamodels._generated import (
             ConnectionResponse,
-            TaskStatesResponse,
             VariableResponse,
-            XComSequenceIndexResponse,
         )
 
         resp: BaseModel | None = None
-        dump_opts = {}
+        dump_opts: dict[str, bool] = {}
         if isinstance(msg, DagFileParsingResult):
             self.parsing_result = msg
         elif isinstance(msg, GetConnection):
@@ -633,86 +646,31 @@ class DagFileProcessorProcess(WatchedSubprocess):
             else:
                 resp = var
         elif isinstance(msg, GetVariableKeys):
-            from airflow.sdk.execution_time.request_handlers import handle_get_variable_keys
-
             resp, dump_opts = handle_get_variable_keys(self.client, msg)
         elif isinstance(msg, PutVariable):
-            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp, dump_opts = handle_put_variable(self.client, msg)
         elif isinstance(msg, DeleteVariable):
-            resp = self.client.variables.delete(msg.key)
+            resp, dump_opts = handle_delete_variable(self.client, msg)
         elif isinstance(msg, GetPreviousDagRun):
-            resp = self.client.dag_runs.get_previous(
-                dag_id=msg.dag_id,
-                logical_date=msg.logical_date,
-                state=msg.state,
-            )
+            resp, dump_opts = handle_get_previous_dag_run(self.client, msg)
         elif isinstance(msg, GetPrevSuccessfulDagRun):
-            dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
-            dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
-            resp = dagrun_result
-            dump_opts = {"exclude_unset": True}
+            resp, dump_opts = handle_get_prev_successful_dag_run(self.client, self.id)
         elif isinstance(msg, GetXCom):
-            xcom = self.client.xcoms.get(
-                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
-            )
-            xcom_result = XComResult.from_xcom_response(xcom)
-            resp = xcom_result
+            resp, dump_opts = handle_get_xcom(self.client, msg)
         elif isinstance(msg, GetXComCount):
-            resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
+            resp, dump_opts = handle_get_xcom_count(self.client, msg)
         elif isinstance(msg, GetXComSequenceItem):
-            xcom = self.client.xcoms.get_sequence_item(
-                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
-            )
-            if isinstance(xcom, XComSequenceIndexResponse):
-                resp = XComSequenceIndexResult.from_response(xcom)
-            else:
-                resp = xcom
+            resp, dump_opts = handle_get_xcom_sequence_item(self.client, msg)
         elif isinstance(msg, GetXComSequenceSlice):
-            xcoms = self.client.xcoms.get_sequence_slice(
-                msg.dag_id,
-                msg.run_id,
-                msg.task_id,
-                msg.key,
-                msg.start,
-                msg.stop,
-                msg.step,
-                msg.include_prior_dates,
-            )
-            resp = XComSequenceSliceResult.from_response(xcoms)
+            resp, dump_opts = handle_get_xcom_sequence_slice(self.client, msg)
         elif isinstance(msg, MaskSecret):
-            # Use sdk masker in dag processor and triggerer because those use the task sdk machinery
-            mask_secret(msg.value, msg.name)
+            handle_mask_secret(msg)
         elif isinstance(msg, GetTICount):
-            resp = self.client.task_instances.get_count(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-                states=msg.states,
-            )
+            resp, dump_opts = handle_get_ti_count(self.client, msg)
         elif isinstance(msg, GetTaskStates):
-            task_states_map = self.client.task_instances.get_task_states(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-            )
-            if isinstance(task_states_map, TaskStatesResponse):
-                resp = TaskStatesResult.from_api_response(task_states_map)
-            else:
-                resp = task_states_map
+            resp, dump_opts = handle_get_task_states(self.client, msg)
         elif isinstance(msg, GetPreviousTI):
-            resp = self.client.task_instances.get_previous(
-                dag_id=msg.dag_id,
-                task_id=msg.task_id,
-                logical_date=msg.logical_date,
-                map_index=msg.map_index,
-                state=msg.state,
-            )
+            resp, dump_opts = handle_get_previous_ti(self.client, msg)
         else:
             log.error("Unhandled request", msg=msg)
             self.send_msg(
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 2b4db481c26..6f8f7baae84 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -90,10 +90,20 @@ from airflow.sdk.execution_time.comms import (
     _RequestFrame,
 )
 from airflow.sdk.execution_time.request_handlers import (
+    handle_delete_variable,
+    handle_delete_xcom,
     handle_get_connection,
+    handle_get_dag_run_state,
+    handle_get_dr_count,
+    handle_get_previous_ti,
+    handle_get_task_states,
+    handle_get_ti_count,
     handle_get_variable,
     handle_get_variable_keys,
+    handle_get_xcom,
     handle_mask_secret,
+    handle_put_variable,
+    handle_set_xcom,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
 from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
@@ -494,10 +504,6 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
         return client
 
     def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None:
-        from airflow.sdk.api.datamodels._generated import (
-            TaskStatesResponse,
-            XComResponse,
-        )
 
         resp: BaseModel | None = None
         dump_opts: dict[str, bool] = {}
@@ -536,78 +542,31 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
         elif isinstance(msg, GetConnection):
             resp, dump_opts = handle_get_connection(self.client, msg)
         elif isinstance(msg, DeleteVariable):
-            resp = self.client.variables.delete(msg.key)
+            resp, dump_opts = handle_delete_variable(self.client, msg)
         elif isinstance(msg, GetVariable):
             resp, dump_opts = handle_get_variable(self.client, msg)
         elif isinstance(msg, GetVariableKeys):
             resp, dump_opts = handle_get_variable_keys(self.client, msg)
         elif isinstance(msg, PutVariable):
-            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp, dump_opts = handle_put_variable(self.client, msg)
         elif isinstance(msg, DeleteXCom):
-            self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
+            resp, dump_opts = handle_delete_xcom(self.client, msg)
         elif isinstance(msg, GetXCom):
-            xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
-            if isinstance(xcom, XComResponse):
-                xcom_result = XComResult.from_xcom_response(xcom)
-                resp = xcom_result
-                dump_opts = {"exclude_unset": True}
-            else:
-                resp = xcom
+            resp, dump_opts = handle_get_xcom(self.client, msg)
         elif isinstance(msg, SetXCom):
-            self.client.xcoms.set(
-                msg.dag_id,
-                msg.run_id,
-                msg.task_id,
-                msg.key,
-                msg.value,
-                msg.map_index,
-                dag_result=msg.dag_result,
-                mapped_length=msg.mapped_length,
-            )
+            resp, dump_opts = handle_set_xcom(self.client, msg)
         elif isinstance(msg, GetDRCount):
-            dr_count = self.client.dag_runs.get_count(
-                dag_id=msg.dag_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-                states=msg.states,
-            )
-            resp = dr_count
+            resp, dump_opts = handle_get_dr_count(self.client, msg)
         elif isinstance(msg, GetDagRunState):
-            dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
-            resp = DagRunStateResult.from_api_response(dr_resp)
+            resp, dump_opts = handle_get_dag_run_state(self.client, msg)
 
         elif isinstance(msg, GetTICount):
-            resp = self.client.task_instances.get_count(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-                states=msg.states,
-            )
+            resp, dump_opts = handle_get_ti_count(self.client, msg)
 
         elif isinstance(msg, GetTaskStates):
-            run_id_task_state_map = self.client.task_instances.get_task_states(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-            )
-            if isinstance(run_id_task_state_map, TaskStatesResponse):
-                resp = TaskStatesResult.from_api_response(run_id_task_state_map)
-            else:
-                resp = run_id_task_state_map
+            resp, dump_opts = handle_get_task_states(self.client, msg)
         elif isinstance(msg, GetPreviousTI):
-            resp = self.client.task_instances.get_previous(
-                dag_id=msg.dag_id,
-                task_id=msg.task_id,
-                logical_date=msg.logical_date,
-                map_index=msg.map_index,
-                state=msg.state,
-            )
+            resp, dump_opts = handle_get_previous_ti(self.client, msg)
         elif isinstance(msg, UpdateHITLDetail):
             api_resp = self.client.hitl.update_response(
                 ti_id=msg.ti_id,
diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
index fbd0e1cee58..959be43fe93 100644
--- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
+++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
@@ -28,19 +28,45 @@ via ``send_msg``.
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from uuid import UUID
 
 from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
+    DagRunStateResponse,
+    TaskStatesResponse,
     VariableResponse,
+    XComResponse,
+    XComSequenceIndexResponse,
+    XComSequenceSliceResponse,
 )
 from airflow.sdk.execution_time.comms import (
     ConnectionResult,
+    DagRunStateResult,
+    DeleteVariable,
+    DeleteXCom,
     GetConnection,
+    GetDagRunState,
+    GetDRCount,
+    GetPreviousDagRun,
+    GetPreviousTI,
+    GetTaskStates,
+    GetTICount,
     GetVariable,
     GetVariableKeys,
+    GetXCom,
+    GetXComCount,
+    GetXComSequenceItem,
+    GetXComSequenceSlice,
     MaskSecret,
+    PrevSuccessfulDagRunResult,
+    PutVariable,
+    SetXCom,
+    TaskStatesResult,
     VariableKeysResult,
     VariableResult,
+    XComResult,
+    XComSequenceIndexResult,
+    XComSequenceSliceResult,
 )
 from airflow.sdk.log import mask_secret
 
@@ -86,3 +112,163 @@ def handle_get_variable_keys(
 def handle_mask_secret(msg: MaskSecret) -> None:
     """Register a value with the secrets masker."""
     mask_secret(msg.value, msg.name)
+
+
+def handle_put_variable(client: Client, msg: PutVariable) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Store a variable value."""
+    client.variables.set(msg.key, msg.value, msg.description)
+    return None, {}
+
+
+def handle_delete_variable(client: Client, msg: DeleteVariable) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Delete a variable value."""
+    resp = client.variables.delete(msg.key)
+    return resp, {}
+
+
+def handle_get_ti_count(client: Client, msg: GetTICount) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch task instance counts."""
+    resp = client.task_instances.get_count(
+        dag_id=msg.dag_id,
+        map_index=msg.map_index,
+        task_ids=msg.task_ids,
+        task_group_id=msg.task_group_id,
+        logical_dates=msg.logical_dates,
+        run_ids=msg.run_ids,
+        states=msg.states,
+    )
+    return resp, {}
+
+
+def handle_get_task_states(client: Client, msg: GetTaskStates) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch task states and normalize them for supervisor response handling."""
+    task_states_map = client.task_instances.get_task_states(
+        dag_id=msg.dag_id,
+        map_index=msg.map_index,
+        task_ids=msg.task_ids,
+        task_group_id=msg.task_group_id,
+        logical_dates=msg.logical_dates,
+        run_ids=msg.run_ids,
+    )
+    if isinstance(task_states_map, TaskStatesResponse):
+        return TaskStatesResult.from_api_response(task_states_map), {}
+    return task_states_map, {}
+
+
+def handle_get_previous_ti(client: Client, msg: GetPreviousTI) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch the previous task instance."""
+    resp = client.task_instances.get_previous(
+        dag_id=msg.dag_id,
+        task_id=msg.task_id,
+        logical_date=msg.logical_date,
+        map_index=msg.map_index,
+        state=msg.state,
+    )
+    return resp, {}
+
+
+def handle_set_xcom(client: Client, msg: SetXCom) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Store an XCom value."""
+    client.xcoms.set(
+        msg.dag_id,
+        msg.run_id,
+        msg.task_id,
+        msg.key,
+        msg.value,
+        msg.map_index,
+        dag_result=msg.dag_result,
+        mapped_length=msg.mapped_length,
+    )
+    return None, {}
+
+
+def handle_delete_xcom(client: Client, msg: DeleteXCom) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Delete an XCom value."""
+    client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
+    return None, {}
+
+
+def handle_get_dr_count(client: Client, msg: GetDRCount) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch dag run counts."""
+    resp = client.dag_runs.get_count(
+        dag_id=msg.dag_id,
+        logical_dates=msg.logical_dates,
+        run_ids=msg.run_ids,
+        states=msg.states,
+    )
+    return resp, {}
+
+
+def handle_get_dag_run_state(client: Client, msg: GetDagRunState) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch dag run state."""
+    dr_resp = client.dag_runs.get_state(msg.dag_id, msg.run_id)
+    if isinstance(dr_resp, DagRunStateResponse):
+        return DagRunStateResult.from_api_response(dr_resp), {}
+    return dr_resp, {}
+
+
+def handle_get_previous_dag_run(
+    client: Client, msg: GetPreviousDagRun
+) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch the previous dag run."""
+    resp = client.dag_runs.get_previous(
+        dag_id=msg.dag_id,
+        logical_date=msg.logical_date,
+        state=msg.state,
+    )
+    return resp, {}
+
+
+def handle_get_prev_successful_dag_run(
+    client: Client, subprocess_id: UUID
+) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch the previous successful dag run using the caller's current id."""
+    dagrun_resp = client.task_instances.get_previous_successful_dagrun(subprocess_id)
+    dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
+    return dagrun_result, {"exclude_unset": True}
+
+
+def handle_get_xcom_count(client: Client, msg: GetXComCount) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch XCom count metadata."""
+    resp = client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
+    return resp, {}
+
+
+def handle_get_xcom_sequence_item(
+    client: Client, msg: GetXComSequenceItem
+) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch an XCom sequence item and normalize it for supervisor response handling."""
+    xcom = client.xcoms.get_sequence_item(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset)
+    if isinstance(xcom, XComSequenceIndexResponse):
+        return XComSequenceIndexResult.from_response(xcom), {}
+    return xcom, {}
+
+
+def handle_get_xcom_sequence_slice(
+    client: Client, msg: GetXComSequenceSlice
+) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch an XCom sequence slice and normalize it for supervisor response handling."""
+    xcoms = client.xcoms.get_sequence_slice(
+        msg.dag_id,
+        msg.run_id,
+        msg.task_id,
+        msg.key,
+        msg.start,
+        msg.stop,
+        msg.step,
+        msg.include_prior_dates,
+    )
+    if isinstance(xcoms, XComSequenceSliceResponse):
+        return XComSequenceSliceResult.from_response(xcoms), {}
+    return xcoms, {}
+
+
+def handle_get_xcom(client: Client, msg: GetXCom) -> tuple[BaseModel | None, dict[str, bool]]:
+    """Fetch an XCom and normalize it for supervisor response handling."""
+    xcom = client.xcoms.get(
+        msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
+    )
+    if isinstance(xcom, XComResponse):
+        xcom_result = XComResult.from_xcom_response(xcom)
+        return xcom_result, {"exclude_unset": True}
+    return xcom, {}
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 3e705b9d211..f7af97ffd8c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -55,8 +55,6 @@ from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
     TaskInstance,
     TaskInstanceState,
-    TaskStatesResponse,
-    XComSequenceIndexResponse,
 )
 from airflow.sdk.configuration import conf
 from airflow.sdk.exceptions import ErrorType
@@ -72,7 +70,6 @@ from airflow.sdk.execution_time.comms import (
     CreateHITLDetailPayload,
     DagResult,
     DagRunResult,
-    DagRunStateResult,
     DeferTask,
     DeleteAssetStateByName,
     DeleteAssetStateByUri,
@@ -110,7 +107,6 @@ from airflow.sdk.execution_time.comms import (
     InactiveAssetsResult,
     MaskSecret,
     OKResponse,
-    PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
     ResendLoggingFD,
@@ -128,21 +124,32 @@ from airflow.sdk.execution_time.comms import (
     TaskBreadcrumbsResult,
     TaskState,
     TaskStateResult,
-    TaskStatesResult,
     ToSupervisor,
     TriggerDagRun,
     ValidateInletsAndOutlets,
-    XComResult,
-    XComSequenceIndexResult,
-    XComSequenceSliceResult,
     _RequestFrame,
     _ResponseFrame,
 )
 from airflow.sdk.execution_time.request_handlers import (
+    handle_delete_variable,
+    handle_delete_xcom,
     handle_get_connection,
+    handle_get_dag_run_state,
+    handle_get_dr_count,
+    handle_get_prev_successful_dag_run,
+    handle_get_previous_dag_run,
+    handle_get_previous_ti,
+    handle_get_task_states,
+    handle_get_ti_count,
     handle_get_variable,
     handle_get_variable_keys,
+    handle_get_xcom,
+    handle_get_xcom_count,
+    handle_get_xcom_sequence_item,
+    handle_get_xcom_sequence_slice,
     handle_mask_secret,
+    handle_put_variable,
+    handle_set_xcom,
 )
 
 try:
@@ -1550,33 +1557,11 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, GetVariableKeys):
             resp, dump_opts = handle_get_variable_keys(self.client, msg)
         elif isinstance(msg, GetXCom):
-            xcom = self.client.xcoms.get(
-                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
-            )
-            xcom_result = XComResult.from_xcom_response(xcom)
-            resp = xcom_result
-        elif isinstance(msg, GetXComCount):
-            resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
+            resp, dump_opts = handle_get_xcom(self.client, msg)
         elif isinstance(msg, GetXComSequenceItem):
-            xcom = self.client.xcoms.get_sequence_item(
-                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
-            )
-            if isinstance(xcom, XComSequenceIndexResponse):
-                resp = XComSequenceIndexResult.from_response(xcom)
-            else:
-                resp = xcom
+            resp, dump_opts = handle_get_xcom_sequence_item(self.client, msg)
         elif isinstance(msg, GetXComSequenceSlice):
-            xcoms = self.client.xcoms.get_sequence_slice(
-                msg.dag_id,
-                msg.run_id,
-                msg.task_id,
-                msg.key,
-                msg.start,
-                msg.stop,
-                msg.step,
-                msg.include_prior_dates,
-            )
-            resp = XComSequenceSliceResult.from_response(xcoms)
+            resp, dump_opts = handle_get_xcom_sequence_slice(self.client, msg)
         elif isinstance(msg, DeferTask):
             self._rendered_map_index = msg.rendered_map_index
             self._send_terminal_state_msg(msg)
@@ -1585,20 +1570,11 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, SkipDownstreamTasks):
             self.client.task_instances.skip_downstream_tasks(self.id, msg)
         elif isinstance(msg, SetXCom):
-            self.client.xcoms.set(
-                msg.dag_id,
-                msg.run_id,
-                msg.task_id,
-                msg.key,
-                msg.value,
-                msg.map_index,
-                dag_result=msg.dag_result,
-                mapped_length=msg.mapped_length,
-            )
+            resp, dump_opts = handle_set_xcom(self.client, msg)
         elif isinstance(msg, DeleteXCom):
-            self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
+            resp, dump_opts = handle_delete_xcom(self.client, msg)
         elif isinstance(msg, PutVariable):
-            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp, dump_opts = handle_put_variable(self.client, msg)
         elif isinstance(msg, SetRenderedFields):
             self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
         elif isinstance(msg, SetRenderedMapIndex):
@@ -1645,10 +1621,9 @@ class ActivitySubprocess(WatchedSubprocess):
             resp = asset_event_result
             dump_opts = {"exclude_unset": True}
         elif isinstance(msg, GetPrevSuccessfulDagRun):
-            dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
-            dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
-            resp = dagrun_result
-            dump_opts = {"exclude_unset": True}
+            resp, dump_opts = handle_get_prev_successful_dag_run(self.client, self.id)
+        elif isinstance(msg, GetXComCount):
+            resp, dump_opts = handle_get_xcom_count(self.client, msg)
         elif isinstance(msg, TriggerDagRun):
             resp = self.client.dag_runs.trigger(
                 msg.dag_id, msg.run_id, msg.conf, msg.logical_date, msg.run_after, msg.reset_dag_run, msg.note
@@ -1656,60 +1631,25 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, GetDagRun):
             dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id)
             resp = DagRunResult.from_api_response(dr_resp)
-        elif isinstance(msg, GetDagRunState):
-            dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
-            resp = DagRunStateResult.from_api_response(dr_resp)
         elif isinstance(msg, GetTaskRescheduleStartDate):
             resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
         elif isinstance(msg, GetTICount):
-            resp = self.client.task_instances.get_count(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-                states=msg.states,
-            )
+            resp, dump_opts = handle_get_ti_count(self.client, msg)
         elif isinstance(msg, GetTaskStates):
-            task_states_map = self.client.task_instances.get_task_states(
-                dag_id=msg.dag_id,
-                map_index=msg.map_index,
-                task_ids=msg.task_ids,
-                task_group_id=msg.task_group_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-            )
-            if isinstance(task_states_map, TaskStatesResponse):
-                resp = TaskStatesResult.from_api_response(task_states_map)
-            else:
-                resp = task_states_map
+            resp, dump_opts = handle_get_task_states(self.client, msg)
         elif isinstance(msg, GetTaskBreadcrumbs):
             api_resp = self.client.task_instances.get_task_breakcrumbs(dag_id=msg.dag_id, run_id=msg.run_id)
             resp = TaskBreadcrumbsResult.from_api_response(api_resp)
         elif isinstance(msg, GetDRCount):
-            resp = self.client.dag_runs.get_count(
-                dag_id=msg.dag_id,
-                logical_dates=msg.logical_dates,
-                run_ids=msg.run_ids,
-                states=msg.states,
-            )
+            resp, dump_opts = handle_get_dr_count(self.client, msg)
+        elif isinstance(msg, GetDagRunState):
+            resp, dump_opts = handle_get_dag_run_state(self.client, msg)
         elif isinstance(msg, GetPreviousDagRun):
-            resp = self.client.dag_runs.get_previous(
-                dag_id=msg.dag_id,
-                logical_date=msg.logical_date,
-                state=msg.state,
-            )
+            resp, dump_opts = handle_get_previous_dag_run(self.client, msg)
         elif isinstance(msg, GetPreviousTI):
-            resp = self.client.task_instances.get_previous(
-                dag_id=msg.dag_id,
-                task_id=msg.task_id,
-                logical_date=msg.logical_date,
-                map_index=msg.map_index,
-                state=msg.state,
-            )
+            resp, dump_opts = handle_get_previous_ti(self.client, msg)
         elif isinstance(msg, DeleteVariable):
-            resp = self.client.variables.delete(msg.key)
+            resp, dump_opts = handle_delete_variable(self.client, msg)
         elif isinstance(msg, ValidateInletsAndOutlets):
             inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
             resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)

Reply via email to