pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2907088878


##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
                 ti.task_instance_note.user_id = user.get_id()
 
 
+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
+
+    return task_ids
+
+
+def _patch_task_group_state(
+    *,
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    body: PatchTaskGroupBody,
+    dag_bag: DagBagDep,
+    user: GetUserDep,
+    session: Session,
+) -> None:
+    """
+    Set the state of all task instances in a task group.
+
+    Uses BulkTaskInstanceService to update each task in the group.
+
+    :param dag_id: The DAG ID.
+    :param dag_run_id: The run_id of the DAG run.
+    :param group_id: The ID of the task group.
+    :param body: The request body with the new state and options.
+    :param dag_bag: The DAG bag for DAG resolution.
+    :param user: The authenticated user.
+    :param session: The database session.
+    """
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+
+    entities = [
+        BulkTaskInstanceBody(
+            task_id=task_id,
+            dag_id=dag_id,
+            dag_run_id=dag_run_id,
+            new_state=body.new_state,
+            include_future=body.include_future,
+            include_past=body.include_past,
+        )
+        for task_id in task_ids
+    ]
+
+    action = BulkUpdateAction(
+        action=BulkAction.UPDATE,
+        entities=entities,
+        update_mask=["new_state"],
+        action_on_non_existence=BulkActionNotOnExistence.SKIP,
+    )
+    results = BulkActionResponse()
+
+    service = BulkTaskInstanceService(
+        session=session,
+        request=BulkBody(actions=[]),  # unused, but required by base class
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        dag_bag=dag_bag,
+        user=user,
+    )
+    service.handle_bulk_update(action, results)

Review Comment:
   It's probably better not to go through the bulk update service: the 
interface wasn't meant for this, and it forces you into awkward workarounds 
just to plug into it (e.g. passing a dummy `BulkBody(actions=[])`).
   
   The code for updating a single TI is probably more reusable and better 
suited to this use case.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to