Lee-W commented on code in PR #68702:
URL: https://github.com/apache/airflow/pull/68702#discussion_r3453613908


##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -150,6 +155,91 @@ def perform_clear_dag_run(
     return dag_run_cleared
 
 
+_TI_CHUNK_SIZE = 500
+
+
+def clear_partition_fields(
+    *,
+    dag: SerializedDAG,
+    body: ClearPartitionsBody,
+    dag_id: str,
+    session: Session,
+) -> tuple[int, int]:
+    """
+    Reset partition_key and partition_date to None on matching runs.
+
+    Returns (dag_runs_cleared, task_instances_cleared).
+    Mirrors ``airflow partitions clear`` column-reset behavior.
+    """
+    stmt = select(DagRun).where(DagRun.dag_id == dag_id)
+    if body.run_id is not None:
+        stmt = stmt.where(DagRun.run_id == body.run_id)
+    elif body.partition_key is not None:
+        stmt = stmt.where(DagRun.partition_key == body.partition_key)
+    else:
+        stmt = stmt.where(or_(DagRun.partition_key.is_not(None), 
DagRun.partition_date.is_not(None)))
+        if body.partition_date_start is not None:
+            lower = 
dag.timetable.resolve_day_bound(body.partition_date_start.date())
+            stmt = stmt.where(DagRun.partition_date >= lower)
+        if body.partition_date_end is not None:
+            upper = 
dag.timetable.resolve_day_bound(body.partition_date_end.date() + 
timedelta(days=1))
+            stmt = stmt.where(DagRun.partition_date < upper)
+    stmt = stmt.order_by(DagRun.partition_date, DagRun.run_id)
+
+    dag_runs_cleared = 0
+    # Buffers for batched TI fetching — mirrors _flush_buffer in 
partition_command.py
+    ti_buffer_run_ids: list[str] = []
+    ti_carry: list[TaskInstance] = []
+    tis_cleared_total = 0
+
+    def _flush_ti_buffer(*, drain: bool = False) -> int:
+        flushed = 0
+        if ti_buffer_run_ids:
+            chunk_tis = list(
+                
session.scalars(select(TaskInstance).where(TaskInstance.run_id.in_(ti_buffer_run_ids)))
+            )
+            ti_buffer_run_ids.clear()
+            ti_carry.extend(chunk_tis)
+        while len(ti_carry) >= _TI_CHUNK_SIZE:
+            slice_tis = ti_carry[:_TI_CHUNK_SIZE]
+            del ti_carry[:_TI_CHUNK_SIZE]
+            clear_task_instances(slice_tis, session=session)
+            flushed += len(slice_tis)
+        if drain and ti_carry:
+            clear_task_instances(ti_carry, session=session)
+            flushed += len(ti_carry)
+        return flushed
+
+    # For dry_run TI count
+    tis_dry_total = 0
+
+    for run in session.scalars(stmt).yield_per(100):
+        fields_already_cleared = run.partition_key is None and 
run.partition_date is None
+        if fields_already_cleared and not body.clear_task_instances:
+            continue
+        if not fields_already_cleared:
+            if not body.dry_run:
+                run.partition_key = None
+                run.partition_date = None
+            dag_runs_cleared += 1
+        if body.clear_task_instances:
+            if body.dry_run:
+                run_tis = 
list(session.scalars(select(TaskInstance).where(TaskInstance.run_id == 
run.run_id)))
+                tis_dry_total += len(run_tis)

Review Comment:
   Replaced it with a single call. Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to