yuseok89 commented on code in PR #66554:
URL: https://github.com/apache/airflow/pull/66554#discussion_r3234885331


##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -86,3 +123,435 @@ async def wait(self) -> AsyncGenerator[str, None]:
             await asyncio.sleep(self.interval)
             yield await self._serialize_response(dag_run := await 
self._get_dag_run())
             yield "\n"
+
+
def _format_dag_run_key(dag_id: str, dag_run_id: str) -> str:
    """Build the dotted ``"<dag_id>.<dag_run_id>"`` key that names a single Dag run."""
    return ".".join((dag_id, dag_run_id))
+
+
def _authorize_dag_run(
    *,
    session: Session,
    user,
    dag_id: str,
    method: AuthMethod,
    cache: dict[str, bool],
) -> bool:
    """
    Return whether ``user`` may perform ``method`` on Dag runs of ``dag_id``.

    The decision is memoised in ``cache`` so a bulk request that touches many
    runs of the same Dag pays for only one team lookup and one
    ``is_authorized_dag`` call per (Dag, method) pair.

    :param session: ORM session used to look up the Dag's team name.
    :param user: the requesting user, forwarded to the auth manager.
    :param dag_id: the Dag whose runs are being accessed.
    :param method: the auth method being checked (e.g. ``"PUT"``).
    :param cache: caller-owned memo; entries are keyed by method *and* Dag id.
    :return: the auth manager's decision for this Dag and method.
    """
    # ``method`` is part of the key: keying on ``dag_id`` alone would let a cache
    # reused across methods hand back, say, a GET decision for a PUT check.
    cache_key = f"{method}:{dag_id}"
    if cache_key not in cache:
        team_name = DagModel.get_team_name(dag_id, session=session)
        cache[cache_key] = get_auth_manager().is_authorized_dag(
            method=method,
            access_entity=DagAccessEntity.RUN,
            details=DagDetails(id=dag_id, team_name=team_name),
            user=user,
        )
    return cache[cache_key]
+
+
def _apply_state_change(
    dag_run: DagRun,
    new_state: DAGRunPatchStates,
    dag: SerializedDAG,
    session: Session,
) -> None:
    """
    Move ``dag_run`` to ``new_state`` and notify listeners where applicable.

    SUCCESS and FAILED fire ``on_dag_run_success`` / ``on_dag_run_failed``;
    any exception raised by a listener is logged rather than propagated.
    QUEUED fires no listener, and any other state is left untouched.
    """
    if new_state == DAGRunPatchStates.QUEUED:
        # Deliberately no listener call for QUEUED: the scheduler emits the
        # RUNNING notification instead.
        set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        return

    if new_state == DAGRunPatchStates.SUCCESS:
        set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        hook_name = "on_dag_run_success"
    elif new_state == DAGRunPatchStates.FAILED:
        set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        hook_name = "on_dag_run_failed"
    else:
        return

    # A misbehaving listener must not turn the state change above into an error.
    try:
        getattr(get_listener_manager().hook, hook_name)(dag_run=dag_run, msg="")
    except Exception:
        log.exception("error calling listener")
+
+
def _apply_note(dag_run: DagRun, note: str | None, user_id: str) -> None:
    """
    Record ``note``, authored by ``user_id``, on ``dag_run``.

    If the run already has a ``dag_run_note`` row, its ``content`` and
    ``user_id`` are overwritten in place; otherwise the ``(note, user_id)``
    tuple is assigned to ``dag_run.note``.
    """
    existing = dag_run.dag_run_note
    if existing is not None:
        existing.content = note
        existing.user_id = user_id
        return
    dag_run.note = (note, user_id)
+
+
def _validate_no_wildcard_in_resolved(
    *,
    dag_id: str,
    dag_run_id: str,
    results: BulkActionResponse,
) -> bool:
    """
    Reject an entity whose resolved identifiers are still the ``~`` wildcard.

    :return: ``True`` when neither ``dag_id`` nor ``dag_run_id`` is ``"~"``;
        otherwise a 400 error is appended to ``results`` and ``False`` is returned.
    """
    if "~" not in (dag_id, dag_run_id):
        return True
    message = (
        "When the path uses the ``~`` wildcard, ``dag_id`` and ``dag_run_id`` must be "
        "specified in the body for each entity."
    )
    results.errors.append({"error": message, "status_code": status.HTTP_400_BAD_REQUEST})
    return False
+
+
def _validate_path_dag_id_match(
    *,
    path_dag_id: str,
    entity_dag_id: str | None,
    dag_run_id: str,
    results: BulkActionResponse,
) -> bool:
    """
    Check that an entity's explicit ``dag_id`` agrees with the ``dag_id`` in the URL path.

    A ``~`` path, or an entity that leaves ``dag_id`` unset, always passes. On a
    mismatch, a 400 error naming the offending ``dag_id``/``dag_run_id`` is
    appended to ``results``.

    :return: ``True`` if the entity may proceed, ``False`` otherwise.
    """
    consistent = path_dag_id == "~" or entity_dag_id is None or entity_dag_id == path_dag_id
    if consistent:
        return True
    results.errors.append(
        {
            "error": (
                f"Entity dag_id '{entity_dag_id}' does not match path dag_id '{path_dag_id}'. "
                "Use ``~`` in the path for cross-DAG bulk operations."
            ),
            "status_code": status.HTTP_400_BAD_REQUEST,
            "dag_id": entity_dag_id,
            "dag_run_id": dag_run_id,
        }
    )
    return False
+
+
+class BulkDagRunService(BulkService[BulkDagRunBody]):
+    """Service for handling bulk operations on Dag runs."""
+
+    def __init__(
+        self,
+        session: Session,
+        request: BulkBody[BulkDagRunBody],
+        dag_id: str,
+        dag_bag: DagBagDep,
+        user: GetUserDep,
+    ):
+        super().__init__(session, request)
+        self.dag_id = dag_id
+        self.dag_bag = dag_bag
+        self.user = user
+
+    def _resolve_identifiers(self, entity: str | BulkDagRunBody) -> tuple[str, 
str]:
+        """Return ``(dag_id, dag_run_id)`` for an entity, falling back to the 
path's ``dag_id``."""
+        if isinstance(entity, str):
+            return self.dag_id, entity
+        dag_id = entity.dag_id or self.dag_id
+        return dag_id, entity.dag_run_id
+
+    def _check_dag_authorization(
+        self,
+        dag_id: str,
+        method: AuthMethod,
+        action_name: str,
+        results: BulkActionResponse,
+        cache: dict[str, bool],
+    ) -> bool:
+        if not _authorize_dag_run(
+            session=self.session,
+            user=self.user,
+            dag_id=dag_id,
+            method=method,
+            cache=cache,
+        ):
+            results.errors.append(
+                {
+                    "error": f"User is not authorized to {action_name} Dag 
runs for DAG '{dag_id}'",
+                    "status_code": status.HTTP_403_FORBIDDEN,
+                }
+            )
+            return False
+        return True
+
+    def _fetch_dag_runs(
+        self,
+        keys: set[tuple[str, str]],
+    ) -> tuple[dict[tuple[str, str], DagRun], set[tuple[str, str]]]:
+        if not keys:
+            return {}, set()
+        keys_list = list(keys)
+        dag_runs = self.session.scalars(
+            select(DagRun)
+            .options(joinedload(DagRun.dag_model))
+            .where(
+                DagRun.dag_id.in_({k[0] for k in keys_list}),
+                DagRun.run_id.in_({k[1] for k in keys_list}),
+            )
+        ).all()
+        found = {(dr.dag_id, dr.run_id): dr for dr in dag_runs if (dr.dag_id, 
dr.run_id) in keys}
+        not_found = keys - set(found.keys())
+        return found, not_found
+
+    def handle_bulk_create(
+        self, action: BulkCreateAction[BulkDagRunBody], results: 
BulkActionResponse
+    ) -> None:
+        results.errors.append(
+            {
+                "error": "Dag runs bulk create is not supported via this 
endpoint; use the trigger Dag run endpoint instead.",
+                "status_code": status.HTTP_405_METHOD_NOT_ALLOWED,
+            }
+        )
+
+    def handle_bulk_update(
+        self, action: BulkUpdateAction[BulkDagRunBody], results: 
BulkActionResponse
+    ) -> None:
+        """Bulk update Dag runs (state and/or note)."""
+        update_mask = action.update_mask
+        auth_cache: dict[str, bool] = {}
+        keys: set[tuple[str, str]] = set()
+        entity_map: dict[tuple[str, str], BulkDagRunBody] = {}
+
+        for entity in action.entities:
+            if isinstance(entity, str):
+                results.errors.append(
+                    {
+                        "error": "Bulk update requires entities as objects, 
not strings.",
+                        "status_code": status.HTTP_400_BAD_REQUEST,
+                    }
+                )
+                continue
+            dag_id, dag_run_id = self._resolve_identifiers(entity)
+            if not _validate_no_wildcard_in_resolved(dag_id=dag_id, 
dag_run_id=dag_run_id, results=results):
+                continue
+            if not _validate_path_dag_id_match(
+                path_dag_id=self.dag_id,
+                entity_dag_id=entity.dag_id,
+                dag_run_id=dag_run_id,
+                results=results,
+            ):
+                continue
+            if not self._check_dag_authorization(dag_id, "PUT", 
action.action.value, results, auth_cache):
+                continue
+            keys.add((dag_id, dag_run_id))
+            entity_map[(dag_id, dag_run_id)] = entity
+
+        try:
+            found, not_found = self._fetch_dag_runs(keys)
+
+            if action.action_on_non_existence == BulkActionNotOnExistence.FAIL 
and not_found:
+                missing = [{"dag_id": d, "dag_run_id": r} for d, r in 
not_found]
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail=f"The Dag runs with these identifiers were not 
found: {missing}",
+                )
+
+            for key, dag_run in found.items():
+                entity = entity_map[key]
+                fields_to_update = entity.model_fields_set
+                if update_mask:
+                    fields_to_update = 
fields_to_update.intersection(update_mask)
+                fields_to_update = fields_to_update - {"dag_id", "dag_run_id"}
+                if not fields_to_update:
+                    continue
+
+                try:
+                    with self.session.begin_nested():
+                        dag = get_dag_for_run(self.dag_bag, dag_run, 
session=self.session)

Review Comment:
   Done.
   Added `_cached_dag_for_run` helper that memoises by `dag_id`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to