This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 075937c4a9b Refactor and extract shared request handler logic from
supervisor _handle_request methods (#65624)
075937c4a9b is described below
commit 075937c4a9bbc1f3f459b4e879438e835ce55e92
Author: Paul <[email protected]>
AuthorDate: Thu May 21 07:13:52 2026 -0700
Refactor and extract shared request handler logic from supervisor
_handle_request methods (#65624)
---
.../src/airflow/dag_processing/processor.py | 98 ++++-------
.../src/airflow/jobs/triggerer_job_runner.py | 81 +++------
.../airflow/sdk/execution_time/request_handlers.py | 186 +++++++++++++++++++++
.../src/airflow/sdk/execution_time/supervisor.py | 124 ++++----------
4 files changed, 266 insertions(+), 223 deletions(-)
diff --git a/airflow-core/src/airflow/dag_processing/processor.py
b/airflow-core/src/airflow/dag_processing/processor.py
index d7c0a9d2b59..303b62d1411 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -69,6 +69,21 @@ from airflow.sdk.execution_time.comms import (
XComSequenceIndexResult,
XComSequenceSliceResult,
)
+from airflow.sdk.execution_time.request_handlers import (
+ handle_delete_variable,
+ handle_get_prev_successful_dag_run,
+ handle_get_previous_dag_run,
+ handle_get_previous_ti,
+ handle_get_task_states,
+ handle_get_ti_count,
+ handle_get_variable_keys,
+ handle_get_xcom,
+ handle_get_xcom_count,
+ handle_get_xcom_sequence_item,
+ handle_get_xcom_sequence_slice,
+ handle_mask_secret,
+ handle_put_variable,
+)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance,
_send_error_email_notification
from airflow.sdk.log import mask_secret
@@ -601,13 +616,11 @@ class DagFileProcessorProcess(WatchedSubprocess):
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger,
req_id: int) -> None:
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
- TaskStatesResponse,
VariableResponse,
- XComSequenceIndexResponse,
)
resp: BaseModel | None = None
- dump_opts = {}
+ dump_opts: dict[str, bool] = {}
if isinstance(msg, DagFileParsingResult):
self.parsing_result = msg
elif isinstance(msg, GetConnection):
@@ -633,86 +646,31 @@ class DagFileProcessorProcess(WatchedSubprocess):
else:
resp = var
elif isinstance(msg, GetVariableKeys):
- from airflow.sdk.execution_time.request_handlers import
handle_get_variable_keys
-
resp, dump_opts = handle_get_variable_keys(self.client, msg)
elif isinstance(msg, PutVariable):
- self.client.variables.set(msg.key, msg.value, msg.description)
+ resp, dump_opts = handle_put_variable(self.client, msg)
elif isinstance(msg, DeleteVariable):
- resp = self.client.variables.delete(msg.key)
+ resp, dump_opts = handle_delete_variable(self.client, msg)
elif isinstance(msg, GetPreviousDagRun):
- resp = self.client.dag_runs.get_previous(
- dag_id=msg.dag_id,
- logical_date=msg.logical_date,
- state=msg.state,
- )
+ resp, dump_opts = handle_get_previous_dag_run(self.client, msg)
elif isinstance(msg, GetPrevSuccessfulDagRun):
- dagrun_resp =
self.client.task_instances.get_previous_successful_dagrun(self.id)
- dagrun_result =
PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
- resp = dagrun_result
- dump_opts = {"exclude_unset": True}
+ resp, dump_opts = handle_get_prev_successful_dag_run(self.client,
self.id)
elif isinstance(msg, GetXCom):
- xcom = self.client.xcoms.get(
- msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index,
msg.include_prior_dates
- )
- xcom_result = XComResult.from_xcom_response(xcom)
- resp = xcom_result
+ resp, dump_opts = handle_get_xcom(self.client, msg)
elif isinstance(msg, GetXComCount):
- resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id,
msg.key)
+ resp, dump_opts = handle_get_xcom_count(self.client, msg)
elif isinstance(msg, GetXComSequenceItem):
- xcom = self.client.xcoms.get_sequence_item(
- msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
- )
- if isinstance(xcom, XComSequenceIndexResponse):
- resp = XComSequenceIndexResult.from_response(xcom)
- else:
- resp = xcom
+ resp, dump_opts = handle_get_xcom_sequence_item(self.client, msg)
elif isinstance(msg, GetXComSequenceSlice):
- xcoms = self.client.xcoms.get_sequence_slice(
- msg.dag_id,
- msg.run_id,
- msg.task_id,
- msg.key,
- msg.start,
- msg.stop,
- msg.step,
- msg.include_prior_dates,
- )
- resp = XComSequenceSliceResult.from_response(xcoms)
+ resp, dump_opts = handle_get_xcom_sequence_slice(self.client, msg)
elif isinstance(msg, MaskSecret):
- # Use sdk masker in dag processor and triggerer because those use
the task sdk machinery
- mask_secret(msg.value, msg.name)
+ handle_mask_secret(msg)
elif isinstance(msg, GetTICount):
- resp = self.client.task_instances.get_count(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- states=msg.states,
- )
+ resp, dump_opts = handle_get_ti_count(self.client, msg)
elif isinstance(msg, GetTaskStates):
- task_states_map = self.client.task_instances.get_task_states(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- )
- if isinstance(task_states_map, TaskStatesResponse):
- resp = TaskStatesResult.from_api_response(task_states_map)
- else:
- resp = task_states_map
+ resp, dump_opts = handle_get_task_states(self.client, msg)
elif isinstance(msg, GetPreviousTI):
- resp = self.client.task_instances.get_previous(
- dag_id=msg.dag_id,
- task_id=msg.task_id,
- logical_date=msg.logical_date,
- map_index=msg.map_index,
- state=msg.state,
- )
+ resp, dump_opts = handle_get_previous_ti(self.client, msg)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 2b4db481c26..6f8f7baae84 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -90,10 +90,20 @@ from airflow.sdk.execution_time.comms import (
_RequestFrame,
)
from airflow.sdk.execution_time.request_handlers import (
+ handle_delete_variable,
+ handle_delete_xcom,
handle_get_connection,
+ handle_get_dag_run_state,
+ handle_get_dr_count,
+ handle_get_previous_ti,
+ handle_get_task_states,
+ handle_get_ti_count,
handle_get_variable,
handle_get_variable_keys,
+ handle_get_xcom,
handle_mask_secret,
+ handle_put_variable,
+ handle_set_xcom,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess,
make_buffered_socket_reader
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
@@ -494,10 +504,6 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
return client
def _handle_request(self, msg: ToTriggerSupervisor, log:
FilteringBoundLogger, req_id: int) -> None:
- from airflow.sdk.api.datamodels._generated import (
- TaskStatesResponse,
- XComResponse,
- )
resp: BaseModel | None = None
dump_opts: dict[str, bool] = {}
@@ -536,78 +542,31 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
elif isinstance(msg, GetConnection):
resp, dump_opts = handle_get_connection(self.client, msg)
elif isinstance(msg, DeleteVariable):
- resp = self.client.variables.delete(msg.key)
+ resp, dump_opts = handle_delete_variable(self.client, msg)
elif isinstance(msg, GetVariable):
resp, dump_opts = handle_get_variable(self.client, msg)
elif isinstance(msg, GetVariableKeys):
resp, dump_opts = handle_get_variable_keys(self.client, msg)
elif isinstance(msg, PutVariable):
- self.client.variables.set(msg.key, msg.value, msg.description)
+ resp, dump_opts = handle_put_variable(self.client, msg)
elif isinstance(msg, DeleteXCom):
- self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.map_index)
+ resp, dump_opts = handle_delete_xcom(self.client, msg)
elif isinstance(msg, GetXCom):
- xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.map_index)
- if isinstance(xcom, XComResponse):
- xcom_result = XComResult.from_xcom_response(xcom)
- resp = xcom_result
- dump_opts = {"exclude_unset": True}
- else:
- resp = xcom
+ resp, dump_opts = handle_get_xcom(self.client, msg)
elif isinstance(msg, SetXCom):
- self.client.xcoms.set(
- msg.dag_id,
- msg.run_id,
- msg.task_id,
- msg.key,
- msg.value,
- msg.map_index,
- dag_result=msg.dag_result,
- mapped_length=msg.mapped_length,
- )
+ resp, dump_opts = handle_set_xcom(self.client, msg)
elif isinstance(msg, GetDRCount):
- dr_count = self.client.dag_runs.get_count(
- dag_id=msg.dag_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- states=msg.states,
- )
- resp = dr_count
+ resp, dump_opts = handle_get_dr_count(self.client, msg)
elif isinstance(msg, GetDagRunState):
- dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
- resp = DagRunStateResult.from_api_response(dr_resp)
+ resp, dump_opts = handle_get_dag_run_state(self.client, msg)
elif isinstance(msg, GetTICount):
- resp = self.client.task_instances.get_count(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- states=msg.states,
- )
+ resp, dump_opts = handle_get_ti_count(self.client, msg)
elif isinstance(msg, GetTaskStates):
- run_id_task_state_map = self.client.task_instances.get_task_states(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- )
- if isinstance(run_id_task_state_map, TaskStatesResponse):
- resp =
TaskStatesResult.from_api_response(run_id_task_state_map)
- else:
- resp = run_id_task_state_map
+ resp, dump_opts = handle_get_task_states(self.client, msg)
elif isinstance(msg, GetPreviousTI):
- resp = self.client.task_instances.get_previous(
- dag_id=msg.dag_id,
- task_id=msg.task_id,
- logical_date=msg.logical_date,
- map_index=msg.map_index,
- state=msg.state,
- )
+ resp, dump_opts = handle_get_previous_ti(self.client, msg)
elif isinstance(msg, UpdateHITLDetail):
api_resp = self.client.hitl.update_response(
ti_id=msg.ti_id,
diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
index fbd0e1cee58..959be43fe93 100644
--- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
+++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py
@@ -28,19 +28,45 @@ via ``send_msg``.
from __future__ import annotations
from typing import TYPE_CHECKING
+from uuid import UUID
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
+ DagRunStateResponse,
+ TaskStatesResponse,
VariableResponse,
+ XComResponse,
+ XComSequenceIndexResponse,
+ XComSequenceSliceResponse,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
+ DagRunStateResult,
+ DeleteVariable,
+ DeleteXCom,
GetConnection,
+ GetDagRunState,
+ GetDRCount,
+ GetPreviousDagRun,
+ GetPreviousTI,
+ GetTaskStates,
+ GetTICount,
GetVariable,
GetVariableKeys,
+ GetXCom,
+ GetXComCount,
+ GetXComSequenceItem,
+ GetXComSequenceSlice,
MaskSecret,
+ PrevSuccessfulDagRunResult,
+ PutVariable,
+ SetXCom,
+ TaskStatesResult,
VariableKeysResult,
VariableResult,
+ XComResult,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
)
from airflow.sdk.log import mask_secret
@@ -86,3 +112,163 @@ def handle_get_variable_keys(
def handle_mask_secret(msg: MaskSecret) -> None:
"""Register a value with the secrets masker."""
mask_secret(msg.value, msg.name)
+
+
+def handle_put_variable(client: Client, msg: PutVariable) -> tuple[BaseModel |
None, dict[str, bool]]:
+ """Store a variable value."""
+ client.variables.set(msg.key, msg.value, msg.description)
+ return None, {}
+
+
+def handle_delete_variable(client: Client, msg: DeleteVariable) ->
tuple[BaseModel | None, dict[str, bool]]:
+ """Delete a variable value."""
+ resp = client.variables.delete(msg.key)
+ return resp, {}
+
+
+def handle_get_ti_count(client: Client, msg: GetTICount) -> tuple[BaseModel |
None, dict[str, bool]]:
+ """Fetch task instance counts."""
+ resp = client.task_instances.get_count(
+ dag_id=msg.dag_id,
+ map_index=msg.map_index,
+ task_ids=msg.task_ids,
+ task_group_id=msg.task_group_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ states=msg.states,
+ )
+ return resp, {}
+
+
+def handle_get_task_states(client: Client, msg: GetTaskStates) ->
tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch task states and normalize them for supervisor response
handling."""
+ task_states_map = client.task_instances.get_task_states(
+ dag_id=msg.dag_id,
+ map_index=msg.map_index,
+ task_ids=msg.task_ids,
+ task_group_id=msg.task_group_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ )
+ if isinstance(task_states_map, TaskStatesResponse):
+ return TaskStatesResult.from_api_response(task_states_map), {}
+ return task_states_map, {}
+
+
+def handle_get_previous_ti(client: Client, msg: GetPreviousTI) ->
tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch the previous task instance."""
+ resp = client.task_instances.get_previous(
+ dag_id=msg.dag_id,
+ task_id=msg.task_id,
+ logical_date=msg.logical_date,
+ map_index=msg.map_index,
+ state=msg.state,
+ )
+ return resp, {}
+
+
+def handle_set_xcom(client: Client, msg: SetXCom) -> tuple[BaseModel | None,
dict[str, bool]]:
+ """Store an XCom value."""
+ client.xcoms.set(
+ msg.dag_id,
+ msg.run_id,
+ msg.task_id,
+ msg.key,
+ msg.value,
+ msg.map_index,
+ dag_result=msg.dag_result,
+ mapped_length=msg.mapped_length,
+ )
+ return None, {}
+
+
+def handle_delete_xcom(client: Client, msg: DeleteXCom) -> tuple[BaseModel |
None, dict[str, bool]]:
+ """Delete an XCom value."""
+ client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key,
msg.map_index)
+ return None, {}
+
+
+def handle_get_dr_count(client: Client, msg: GetDRCount) -> tuple[BaseModel |
None, dict[str, bool]]:
+ """Fetch dag run counts."""
+ resp = client.dag_runs.get_count(
+ dag_id=msg.dag_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ states=msg.states,
+ )
+ return resp, {}
+
+
+def handle_get_dag_run_state(client: Client, msg: GetDagRunState) ->
tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch dag run state."""
+ dr_resp = client.dag_runs.get_state(msg.dag_id, msg.run_id)
+ if isinstance(dr_resp, DagRunStateResponse):
+ return DagRunStateResult.from_api_response(dr_resp), {}
+ return dr_resp, {}
+
+
+def handle_get_previous_dag_run(
+ client: Client, msg: GetPreviousDagRun
+) -> tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch the previous dag run."""
+ resp = client.dag_runs.get_previous(
+ dag_id=msg.dag_id,
+ logical_date=msg.logical_date,
+ state=msg.state,
+ )
+ return resp, {}
+
+
+def handle_get_prev_successful_dag_run(
+ client: Client, subprocess_id: UUID
+) -> tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch the previous successful dag run using the caller's current id."""
+ dagrun_resp =
client.task_instances.get_previous_successful_dagrun(subprocess_id)
+ dagrun_result =
PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
+ return dagrun_result, {"exclude_unset": True}
+
+
+def handle_get_xcom_count(client: Client, msg: GetXComCount) ->
tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch XCom count metadata."""
+ resp = client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
+ return resp, {}
+
+
+def handle_get_xcom_sequence_item(
+ client: Client, msg: GetXComSequenceItem
+) -> tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch an XCom sequence item and normalize it for supervisor response
handling."""
+ xcom = client.xcoms.get_sequence_item(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.offset)
+ if isinstance(xcom, XComSequenceIndexResponse):
+ return XComSequenceIndexResult.from_response(xcom), {}
+ return xcom, {}
+
+
+def handle_get_xcom_sequence_slice(
+ client: Client, msg: GetXComSequenceSlice
+) -> tuple[BaseModel | None, dict[str, bool]]:
+ """Fetch an XCom sequence slice and normalize it for supervisor response
handling."""
+ xcoms = client.xcoms.get_sequence_slice(
+ msg.dag_id,
+ msg.run_id,
+ msg.task_id,
+ msg.key,
+ msg.start,
+ msg.stop,
+ msg.step,
+ msg.include_prior_dates,
+ )
+ if isinstance(xcoms, XComSequenceSliceResponse):
+ return XComSequenceSliceResult.from_response(xcoms), {}
+ return xcoms, {}
+
+
+def handle_get_xcom(client: Client, msg: GetXCom) -> tuple[BaseModel | None,
dict[str, bool]]:
+ """Fetch an XCom and normalize it for supervisor response handling."""
+ xcom = client.xcoms.get(
+ msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index,
msg.include_prior_dates
+ )
+ if isinstance(xcom, XComResponse):
+ xcom_result = XComResult.from_xcom_response(xcom)
+ return xcom_result, {"exclude_unset": True}
+ return xcom, {}
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 3e705b9d211..f7af97ffd8c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -55,8 +55,6 @@ from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskInstance,
TaskInstanceState,
- TaskStatesResponse,
- XComSequenceIndexResponse,
)
from airflow.sdk.configuration import conf
from airflow.sdk.exceptions import ErrorType
@@ -72,7 +70,6 @@ from airflow.sdk.execution_time.comms import (
CreateHITLDetailPayload,
DagResult,
DagRunResult,
- DagRunStateResult,
DeferTask,
DeleteAssetStateByName,
DeleteAssetStateByUri,
@@ -110,7 +107,6 @@ from airflow.sdk.execution_time.comms import (
InactiveAssetsResult,
MaskSecret,
OKResponse,
- PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
ResendLoggingFD,
@@ -128,21 +124,32 @@ from airflow.sdk.execution_time.comms import (
TaskBreadcrumbsResult,
TaskState,
TaskStateResult,
- TaskStatesResult,
ToSupervisor,
TriggerDagRun,
ValidateInletsAndOutlets,
- XComResult,
- XComSequenceIndexResult,
- XComSequenceSliceResult,
_RequestFrame,
_ResponseFrame,
)
from airflow.sdk.execution_time.request_handlers import (
+ handle_delete_variable,
+ handle_delete_xcom,
handle_get_connection,
+ handle_get_dag_run_state,
+ handle_get_dr_count,
+ handle_get_prev_successful_dag_run,
+ handle_get_previous_dag_run,
+ handle_get_previous_ti,
+ handle_get_task_states,
+ handle_get_ti_count,
handle_get_variable,
handle_get_variable_keys,
+ handle_get_xcom,
+ handle_get_xcom_count,
+ handle_get_xcom_sequence_item,
+ handle_get_xcom_sequence_slice,
handle_mask_secret,
+ handle_put_variable,
+ handle_set_xcom,
)
try:
@@ -1550,33 +1557,11 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, GetVariableKeys):
resp, dump_opts = handle_get_variable_keys(self.client, msg)
elif isinstance(msg, GetXCom):
- xcom = self.client.xcoms.get(
- msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index,
msg.include_prior_dates
- )
- xcom_result = XComResult.from_xcom_response(xcom)
- resp = xcom_result
- elif isinstance(msg, GetXComCount):
- resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id,
msg.key)
+ resp, dump_opts = handle_get_xcom(self.client, msg)
elif isinstance(msg, GetXComSequenceItem):
- xcom = self.client.xcoms.get_sequence_item(
- msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
- )
- if isinstance(xcom, XComSequenceIndexResponse):
- resp = XComSequenceIndexResult.from_response(xcom)
- else:
- resp = xcom
+ resp, dump_opts = handle_get_xcom_sequence_item(self.client, msg)
elif isinstance(msg, GetXComSequenceSlice):
- xcoms = self.client.xcoms.get_sequence_slice(
- msg.dag_id,
- msg.run_id,
- msg.task_id,
- msg.key,
- msg.start,
- msg.stop,
- msg.step,
- msg.include_prior_dates,
- )
- resp = XComSequenceSliceResult.from_response(xcoms)
+ resp, dump_opts = handle_get_xcom_sequence_slice(self.client, msg)
elif isinstance(msg, DeferTask):
self._rendered_map_index = msg.rendered_map_index
self._send_terminal_state_msg(msg)
@@ -1585,20 +1570,11 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, SkipDownstreamTasks):
self.client.task_instances.skip_downstream_tasks(self.id, msg)
elif isinstance(msg, SetXCom):
- self.client.xcoms.set(
- msg.dag_id,
- msg.run_id,
- msg.task_id,
- msg.key,
- msg.value,
- msg.map_index,
- dag_result=msg.dag_result,
- mapped_length=msg.mapped_length,
- )
+ resp, dump_opts = handle_set_xcom(self.client, msg)
elif isinstance(msg, DeleteXCom):
- self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.map_index)
+ resp, dump_opts = handle_delete_xcom(self.client, msg)
elif isinstance(msg, PutVariable):
- self.client.variables.set(msg.key, msg.value, msg.description)
+ resp, dump_opts = handle_put_variable(self.client, msg)
elif isinstance(msg, SetRenderedFields):
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
elif isinstance(msg, SetRenderedMapIndex):
@@ -1645,10 +1621,9 @@ class ActivitySubprocess(WatchedSubprocess):
resp = asset_event_result
dump_opts = {"exclude_unset": True}
elif isinstance(msg, GetPrevSuccessfulDagRun):
- dagrun_resp =
self.client.task_instances.get_previous_successful_dagrun(self.id)
- dagrun_result =
PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
- resp = dagrun_result
- dump_opts = {"exclude_unset": True}
+ resp, dump_opts = handle_get_prev_successful_dag_run(self.client,
self.id)
+ elif isinstance(msg, GetXComCount):
+ resp, dump_opts = handle_get_xcom_count(self.client, msg)
elif isinstance(msg, TriggerDagRun):
resp = self.client.dag_runs.trigger(
msg.dag_id, msg.run_id, msg.conf, msg.logical_date,
msg.run_after, msg.reset_dag_run, msg.note
@@ -1656,60 +1631,25 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, GetDagRun):
dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id)
resp = DagRunResult.from_api_response(dr_resp)
- elif isinstance(msg, GetDagRunState):
- dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
- resp = DagRunStateResult.from_api_response(dr_resp)
elif isinstance(msg, GetTaskRescheduleStartDate):
resp =
self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
elif isinstance(msg, GetTICount):
- resp = self.client.task_instances.get_count(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- states=msg.states,
- )
+ resp, dump_opts = handle_get_ti_count(self.client, msg)
elif isinstance(msg, GetTaskStates):
- task_states_map = self.client.task_instances.get_task_states(
- dag_id=msg.dag_id,
- map_index=msg.map_index,
- task_ids=msg.task_ids,
- task_group_id=msg.task_group_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- )
- if isinstance(task_states_map, TaskStatesResponse):
- resp = TaskStatesResult.from_api_response(task_states_map)
- else:
- resp = task_states_map
+ resp, dump_opts = handle_get_task_states(self.client, msg)
elif isinstance(msg, GetTaskBreadcrumbs):
api_resp =
self.client.task_instances.get_task_breakcrumbs(dag_id=msg.dag_id,
run_id=msg.run_id)
resp = TaskBreadcrumbsResult.from_api_response(api_resp)
elif isinstance(msg, GetDRCount):
- resp = self.client.dag_runs.get_count(
- dag_id=msg.dag_id,
- logical_dates=msg.logical_dates,
- run_ids=msg.run_ids,
- states=msg.states,
- )
+ resp, dump_opts = handle_get_dr_count(self.client, msg)
+ elif isinstance(msg, GetDagRunState):
+ resp, dump_opts = handle_get_dag_run_state(self.client, msg)
elif isinstance(msg, GetPreviousDagRun):
- resp = self.client.dag_runs.get_previous(
- dag_id=msg.dag_id,
- logical_date=msg.logical_date,
- state=msg.state,
- )
+ resp, dump_opts = handle_get_previous_dag_run(self.client, msg)
elif isinstance(msg, GetPreviousTI):
- resp = self.client.task_instances.get_previous(
- dag_id=msg.dag_id,
- task_id=msg.task_id,
- logical_date=msg.logical_date,
- map_index=msg.map_index,
- state=msg.state,
- )
+ resp, dump_opts = handle_get_previous_ti(self.client, msg)
elif isinstance(msg, DeleteVariable):
- resp = self.client.variables.delete(msg.key)
+ resp, dump_opts = handle_delete_variable(self.client, msg)
elif isinstance(msg, ValidateInletsAndOutlets):
inactive_assets_resp =
self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
resp =
InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)