pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2906986909
##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##########
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
return ns
+class PatchTaskGroupBody(StrictBaseModel):
+ """Request body for patching the state of all task instances in a task
group."""
+
+ new_state: TaskInstanceState
+ include_future: bool = False
+ include_past: bool = False
+
+ @field_validator("new_state", mode="before")
+ @classmethod
+ def validate_new_state(cls, ns: str | None) -> str:
+ """Validate new_state."""
+ valid_states = [
+ vs.name.lower()
+ for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED,
TaskInstanceState.SKIPPED)
+ ]
+ if ns is None:
+ raise ValueError("'new_state' should not be empty")
+ ns = ns.lower()
+ if ns not in valid_states:
+ raise ValueError(f"'{ns}' is not one of {valid_states}")
+ return ns
+
Review Comment:
This is a complete duplicate of the existing `validate_new_state`. Make a
common base body class for `PatchTaskGroupBody` and `PatchTaskInstanceBody`.
The same goes for the other shared attributes.
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
ti.task_instance_note.user_id = user.get_id()
+def _get_task_group_task_ids(
+ dag: SerializedDAG,
+ group_id: str,
+) -> list[str]:
+ """
+ Get task ids that belong to a task group.
+
+ :param dag: The serialized DAG containing the task group.
+ :param group_id: The ID of the task group.
+ :return: List of task IDs in the group.
+ :raises HTTPException: If the task group is not found or has no tasks.
+ """
+ if not hasattr(dag, "task_group"):
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"DAG '{dag.dag_id}' does not have task groups",
+ )
+
+ task_groups = dag.task_group.get_task_group_dict()
+ task_group = task_groups.get(group_id)
+ if not task_group:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+ )
+
+ task_ids = [task.task_id for task in task_group.iter_tasks()]
+ if not task_ids:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+ )
Review Comment:
Nit: Not sure we should 404 here. If there is nothing to patch, just return
`[]` and proceed.
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction:
Literal["upstream", "downstream"]
)
+@task_instances_router.patch(
+ "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+ dependencies=[
+ Depends(action_logging()),
+ Depends(requires_access_dag(method="PUT",
access_entity=DagAccessEntity.TASK_INSTANCE)),
+ ],
+ operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+ dag_id: str,
+ dag_run_id: str,
+ group_id: str,
+ dag_bag: DagBagDep,
+ body: PatchTaskGroupBody,
+ session: SessionDep,
+ user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+ """Update the state of all task instances in a task group."""
+ _patch_task_group_state(
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ group_id=group_id,
+ body=body,
+ dag_bag=dag_bag,
+ user=user,
+ session=session,
+ )
+
+ # Collect all TIs for the task group to build the response
+ dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+ task_ids = _get_task_group_task_ids(dag, group_id)
Review Comment:
`_get_task_group_task_ids` is called multiple times: here and inside
`_patch_task_group_state` too.
The same goes for `get_latest_version_of_dag`.
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -268,34 +366,39 @@ def _perform_update(
results: BulkActionResponse,
update_mask: list[str] | None = Query(None),
) -> None:
- dag, tis, data = _patch_ti_validate_request(
- dag_id=dag_id,
- dag_run_id=dag_run_id,
- task_id=task_id,
- dag_bag=self.dag_bag,
- body=entity,
- session=self.session,
- update_mask=update_mask,
- )
-
- for key, _ in data.items():
- if key == "new_state":
- _patch_task_instance_state(
- task_id=task_id,
- dag_run_id=dag_run_id,
- dag=dag,
- task_instance_body=entity,
- session=self.session,
- data=data,
- )
- elif key == "note":
- _patch_task_instance_note(
- task_instance_body=entity,
- tis=tis,
- user=self.user,
- )
+ try:
+ dag, tis, data = _patch_ti_validate_request(
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ task_id=task_id,
+ dag_bag=self.dag_bag,
+ body=entity,
+ session=self.session,
+ update_mask=update_mask,
+ )
+
+ for key, _ in data.items():
+ if key == "new_state":
+ _patch_task_instance_state(
+ task_id=task_id,
+ dag_run_id=dag_run_id,
+ dag=dag,
+ task_instance_body=entity,
+ data=data,
+ session=self.session,
+ )
+ elif key == "note":
+ _patch_task_instance_note(
+ task_instance_body=entity,
+ tis=tis,
+ user=self.user,
+ )
- results.success.append(f"{dag_id}.{dag_run_id}.{task_id}[{map_index}]")
+
results.success.append(f"{dag_id}.{dag_run_id}.{task_id}[{map_index}]")
+ except HTTPException as e:
+ results.errors.append({"error": f"{e.detail}", "status_code":
e.status_code})
+ except ValidationError as e:
+ results.errors.append({"error": f"{e.errors()}"})
Review Comment:
This will avoid such an update, which we do not actually use in our endpoint.
##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##########
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
return ns
+class PatchTaskGroupBody(StrictBaseModel):
+ """Request body for patching the state of all task instances in a task
group."""
+
+ new_state: TaskInstanceState
+ include_future: bool = False
+ include_past: bool = False
Review Comment:
Why are `include_upstream` and `include_downstream` removed?
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction:
Literal["upstream", "downstream"]
)
+@task_instances_router.patch(
+ "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+ dependencies=[
+ Depends(action_logging()),
+ Depends(requires_access_dag(method="PUT",
access_entity=DagAccessEntity.TASK_INSTANCE)),
+ ],
+ operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+ dag_id: str,
+ dag_run_id: str,
+ group_id: str,
+ dag_bag: DagBagDep,
+ body: PatchTaskGroupBody,
+ session: SessionDep,
+ user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+ """Update the state of all task instances in a task group."""
+ _patch_task_group_state(
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ group_id=group_id,
+ body=body,
+ dag_bag=dag_bag,
+ user=user,
+ session=session,
+ )
+
+ # Collect all TIs for the task group to build the response
+ dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+ task_ids = _get_task_group_task_ids(dag, group_id)
+ tis = (
+ session.scalars(
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id,
TI.task_id.in_(task_ids))
+ .join(TI.dag_run)
+ .options(joinedload(TI.rendered_task_instance_fields))
+ .options(joinedload(TI.dag_version))
+
.options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+ )
+ .unique()
+ .all()
Review Comment:
This query could probably be avoided. We are fetching from the DB
multiple times.
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
ti.task_instance_note.user_id = user.get_id()
+def _get_task_group_task_ids(
+ dag: SerializedDAG,
+ group_id: str,
+) -> list[str]:
+ """
+ Get task ids that belong to a task group.
+
+ :param dag: The serialized DAG containing the task group.
+ :param group_id: The ID of the task group.
+ :return: List of task IDs in the group.
+ :raises HTTPException: If the task group is not found or has no tasks.
+ """
+ if not hasattr(dag, "task_group"):
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"DAG '{dag.dag_id}' does not have task groups",
+ )
+
+ task_groups = dag.task_group.get_task_group_dict()
+ task_group = task_groups.get(group_id)
+ if not task_group:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+ )
+
+ task_ids = [task.task_id for task in task_group.iter_tasks()]
+ if not task_ids:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+ )
+
+ return task_ids
+
+
+def _patch_task_group_state(
+ *,
+ dag_id: str,
+ dag_run_id: str,
+ group_id: str,
+ body: PatchTaskGroupBody,
+ dag_bag: DagBagDep,
+ user: GetUserDep,
+ session: Session,
+) -> None:
+ """
+ Set the state of all task instances in a task group.
+
+ Uses BulkTaskInstanceService to update each task in the group.
+
+ :param dag_id: The DAG ID.
+ :param dag_run_id: The run_id of the DAG run.
+ :param group_id: The ID of the task group.
+ :param body: The request body with the new state and options.
+ :param dag_bag: The DAG bag for DAG resolution.
+ :param user: The authenticated user.
+ :param session: The database session.
+ """
+ dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+ task_ids = _get_task_group_task_ids(dag, group_id)
+
+ entities = [
+ BulkTaskInstanceBody(
+ task_id=task_id,
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ new_state=body.new_state,
+ include_future=body.include_future,
+ include_past=body.include_past,
+ )
+ for task_id in task_ids
+ ]
+
+ action = BulkUpdateAction(
+ action=BulkAction.UPDATE,
+ entities=entities,
+ update_mask=["new_state"],
+ action_on_non_existence=BulkActionNotOnExistence.SKIP,
+ )
+ results = BulkActionResponse()
+
+ service = BulkTaskInstanceService(
+ session=session,
+ request=BulkBody(actions=[]), # unused, but required by base class
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ dag_bag=dag_bag,
+ user=user,
+ )
+ service.handle_bulk_update(action, results)
Review Comment:
Maybe not going through the bulk update service is actually better.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]