This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch worktree-modular-stargazing-lampson in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 2e24bbc0112ee356e50d43f72fc24a1c80a52982 Author: Jarek Potiuk <[email protected]> AuthorDate: Tue Mar 17 08:59:41 2026 +0100 TUI: add full-screen TUI for PR triage with inline diff panel --- .../src/airflow/jobs/scheduler_job_runner.py | 45 +- airflow-core/tests/unit/jobs/test_scheduler_job.py | 89 +- dev/breeze/doc/images/output_pr_auto-triage.svg | 2 +- dev/breeze/doc/images/output_pr_auto-triage.txt | 2 +- .../src/airflow_breeze/commands/pr_commands.py | 1145 +++++++++++--------- dev/breeze/src/airflow_breeze/utils/tui_display.py | 826 ++++++++++++++ .../providers/common/ai/utils/sql_validation.py | 2 +- .../src/airflow/sdk/execution_time/task_runner.py | 6 +- .../task_sdk/execution_time/test_task_runner.py | 4 +- 9 files changed, 1477 insertions(+), 644 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index f8c043f0c9f..b078cb183bc 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1838,6 +1838,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): if asset_triggered_dags: self._create_dag_runs_asset_triggered( dag_models=[d for d in asset_triggered_dags if d.dag_id not in partition_dag_ids], + triggered_date_by_dag=triggered_date_by_dag, session=session, ) @@ -2004,44 +2005,30 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): def _create_dag_runs_asset_triggered( self, - *, dag_models: Collection[DagModel], + triggered_date_by_dag: dict[str, datetime], session: Session, ) -> None: - """For Dags that are triggered by assets, create Dag runs.""" + """For DAGs that are triggered by assets, create dag runs.""" + triggered_dates: dict[str, DateTime] = { + dag_id: timezone.coerce_datetime(last_asset_event_time) + for dag_id, last_asset_event_time in triggered_date_by_dag.items() + } + for dag_model in dag_models: dag = self._get_current_dag(dag_id=dag_model.dag_id, session=session) if not dag: - self.log.error("Dag '%s' not found in serialized_dag table", dag_model.dag_id) + self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id) continue if not isinstance(dag.timetable, AssetTriggeredTimetable): self.log.error( - "Dag '%s' was asset-scheduled, but didn't have an AssetTriggeredTimetable!", + "DAG '%s' was asset-scheduled, but didn't have an AssetTriggeredTimetable!", dag_model.dag_id, ) continue - queued_adrqs = session.scalars( - with_row_locks( - select(AssetDagRunQueue) - .where(AssetDagRunQueue.target_dag_id == dag.dag_id) - .order_by(AssetDagRunQueue.created_at.desc()), - of=AssetDagRunQueue, - skip_locked=True, - key_share=False, - session=session, - ) - ).all() - # If another scheduler already locked these ADRQ rows, SKIP LOCKED makes this scheduler skip them. - if not queued_adrqs: - self.log.debug( - "Skipping asset-triggered DagRun creation for Dag '%s'; no queued assets remain.", - dag.dag_id, - ) - continue - - triggered_date: DateTime = timezone.coerce_datetime(queued_adrqs[0].created_at) + triggered_date = triggered_dates[dag.dag_id] cte = ( select(func.max(DagRun.run_after).label("previous_dag_run_run_after")) .where( @@ -2090,15 +2077,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): ) Stats.incr("asset.triggered_dagruns") dag_run.consumed_asset_events.extend(asset_events) - - # Delete only consumed ADRQ rows to avoid dropping newly queued events - # (e.g. DagRun triggered by asset A while a new event for asset B arrives). - adrq_pks = [(record.asset_id, record.target_dag_id) for record in queued_adrqs] - session.execute( - delete(AssetDagRunQueue).where( - tuple_(AssetDagRunQueue.asset_id, AssetDagRunQueue.target_dag_id).in_(adrq_pks) - ) - ) + session.execute(delete(AssetDagRunQueue).where(AssetDagRunQueue.target_dag_id == dag_run.dag_id)) def _lock_backfills(self, dag_runs: Collection[DagRun], session: Session) -> dict[int, Backfill]: """ diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 4511c333689..7c8d2ff4ffb 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -285,9 +285,9 @@ class TestSchedulerJob: def set_instance_attrs(self) -> Generator: # Speed up some tests by not running the tasks, just look at what we # enqueue! - self.null_exec: BaseExecutor = MockExecutor() + self.null_exec: MockExecutor | None = MockExecutor() yield - self.null_exec = None # type: ignore[assignment] + self.null_exec = None @pytest.fixture def mock_executors(self): @@ -4858,91 +4858,6 @@ class TestSchedulerJob: assert created_run.creating_job_id == scheduler_job.id - @pytest.mark.need_serialized_dag - def test_create_dag_runs_asset_triggered_skips_stale_triggered_date(self, session, dag_maker): - asset = Asset(uri="test://asset-for-stale-trigger-date", name="asset-for-stale-trigger-date") - with dag_maker(dag_id="asset-consumer-stale-trigger-date", schedule=[asset], session=session): - pass - dag_model = dag_maker.dag_model - asset_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset.uri)) - - queued_at = timezone.utcnow() - session.add(AssetDagRunQueue(target_dag_id=dag_model.dag_id, asset_id=asset_id, created_at=queued_at)) - session.flush() - - # Simulate another scheduler consuming ADRQ rows after we computed triggered_date_by_dag. - session.execute(delete(AssetDagRunQueue).where(AssetDagRunQueue.target_dag_id == dag_model.dag_id)) - session.flush() - - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) - self.job_runner._create_dag_runs_asset_triggered( - dag_models=[dag_model], - session=session, - ) - - # We do not create a new DagRun seems the ADRQ has already been consumed - assert session.scalars(select(DagRun).where(DagRun.dag_id == dag_model.dag_id)).one_or_none() is None - - @pytest.mark.need_serialized_dag - def test_create_dag_runs_asset_triggered_deletes_only_selected_adrq_rows( - self, session: Session, dag_maker - ): - asset_1 = Asset("ready-to-trigger-a-Dag-run") - asset_2 = Asset("should-still-exist-after-a-Dag-run-created") - with dag_maker(dag_id="asset-consumer-delete-selected", schedule=asset_1 | asset_2, session=session): - pass - dag_model = dag_maker.dag_model - asset_1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset_1.name)) - asset_2_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset_2.name)) - - session.add_all( - [ - # The ADRQ that should triggers the Dag run creation - AssetDagRunQueue( - asset_id=asset_1_id, target_dag_id=dag_model.dag_id, created_at=timezone.utcnow() - ), - # The ADRQ that arrives after the Dag run creation but before ADRQ clean up - # This situation is simluarted by _lock_only_selected_asset below - AssetDagRunQueue( - asset_id=asset_2_id, target_dag_id=dag_model.dag_id, created_at=timezone.utcnow() - ), - ] - ) - session.flush() - - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) - - def _lock_only_selected_asset(query, **_): - # Simulate SKIP LOCKED behavior where this scheduler can only consume one ADRQ row. - return query.where(AssetDagRunQueue.asset_id == asset_1_id) - - with patch("airflow.jobs.scheduler_job_runner.with_row_locks", side_effect=_lock_only_selected_asset): - self.job_runner._create_dag_runs_asset_triggered( - dag_models=[dag_model], - session=session, - ) - - dr = session.scalars(select(DagRun).where(DagRun.dag_id == dag_model.dag_id)).one_or_none() - - assert dr is not None - - adrq_1 = session.scalars( - select(AssetDagRunQueue).where( - AssetDagRunQueue.target_dag_id == dag_model.dag_id, - AssetDagRunQueue.asset_id == asset_1_id, - ) - ).one_or_none() - assert adrq_1 is None - adrq_2 = session.scalars( - select(AssetDagRunQueue).where( - AssetDagRunQueue.target_dag_id == dag_model.dag_id, - AssetDagRunQueue.asset_id == asset_2_id, - ) - ).one_or_none() - assert adrq_2 is not None - @pytest.mark.need_serialized_dag def test_create_dag_runs_asset_alias_with_asset_event_attached(self, session, dag_maker): """ diff --git a/dev/breeze/doc/images/output_pr_auto-triage.svg b/dev/breeze/doc/images/output_pr_auto-triage.svg index c3a56f415f0..b94fd60b346 100644 --- a/dev/breeze/doc/images/output_pr_auto-triage.svg +++ b/dev/breeze/doc/images/output_pr_auto-triage.svg @@ -426,7 +426,7 @@ </text><text class="breeze-pr-auto-triage-r5" x="0" y="1996.4" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-81)">│</text><text class="breeze-pr-auto-triage-r6" x="256.2" y="1996.4" textLength="1159" clip-path="url(#breeze-pr-auto-triage-line-81)">>claude/claude-sonnet-4-6< | claude/claude-opus-4-20250514 | claude/claude-sonnet-4-20250514 | </text><text class="breeze-pr-auto-triage-r5" x="1451.8" y="1996.4" textLength="12.2" clip-path="u [...] </text><text class="breeze-pr-auto-triage-r5" x="0" y="2020.8" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-82)">│</text><text class="breeze-pr-auto-triage-r6" x="256.2" y="2020.8" textLength="1110.2" clip-path="url(#breeze-pr-auto-triage-line-82)">claude/claude-haiku-4-5-20251001 | claude/sonnet | claude/opus | claude/haiku | codex/o3 | </text><text class="breeze-pr-auto-triage-r5" x="1451.8" y="2020.8" textLength="12.2" [...] </text><text class="breeze-pr-auto-triage-r5" x="0" y="2045.2" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-83)">│</text><text class="breeze-pr-auto-triage-r6" x="256.2" y="2045.2" textLength="366" clip-path="url(#breeze-pr-auto-triage-line-83)">codex/o4-mini | codex/gpt-4.1)</text><text class="breeze-pr-auto-triage-r5" x="1451.8" y="2045.2" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-83)">│</text><text class="breeze-pr-auto-triage-r1" x="1464" [...] -</text><text class="breeze-pr-auto-triage-r5" x="0" y="2069.6" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-84)">│</text><text class="breeze-pr-auto-triage-r4" x="24.4" y="2069.6" textLength="207.4" clip-path="url(#breeze-pr-auto-triage-line-84)">--llm-concurrency</text><text class="breeze-pr-auto-triage-r1" x="256.2" y="2069.6" textLength="524.6" clip-path="url(#breeze-pr-auto-triage-line-84)">Number of concurrent LLM assessment calls. </tex [...] +</text><text class="breeze-pr-auto-triage-r5" x="0" y="2069.6" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-84)">│</text><text class="breeze-pr-auto-triage-r4" x="24.4" y="2069.6" textLength="207.4" clip-path="url(#breeze-pr-auto-triage-line-84)">--llm-concurrency</text><text class="breeze-pr-auto-triage-r1" x="256.2" y="2069.6" textLength="524.6" clip-path="url(#breeze-pr-auto-triage-line-84)">Number of concurrent LLM assessment calls. </tex [...] </text><text class="breeze-pr-auto-triage-r5" x="0" y="2094" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-85)">│</text><text class="breeze-pr-auto-triage-r4" x="24.4" y="2094" textLength="207.4" clip-path="url(#breeze-pr-auto-triage-line-85)">--clear-llm-cache</text><text class="breeze-pr-auto-triage-r1" x="256.2" y="2094" textLength="658.8" clip-path="url(#breeze-pr-auto-triage-line-85)">Clear the LLM review and triage caches before [...] </text><text class="breeze-pr-auto-triage-r5" x="0" y="2118.4" textLength="1464" clip-path="url(#breeze-pr-auto-triage-line-86)">╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯</text><text class="breeze-pr-auto-triage-r1" x="1464" y="2118.4" textLength="12.2" clip-path="url(#breeze-pr-auto-triage-line-86)"> </text><text class="breeze-pr-auto-triage-r5" x="0" y="2142.8" textLength="24.4" clip-path="url(#breeze-pr-auto-triage-line-87)">╭─</text><text class="breeze-pr-auto-triage-r5" x="24.4" y="2142.8" textLength="195.2" clip-path="url(#breeze-pr-auto-triage-line-87)"> Action options </text><text class="breeze-pr-auto-triage-r5" x="219.6" y="2142.8" textLength="1220" clip-path="url(#breeze-pr-auto-triage-line-87)">──────────────────────────────────────────────────────────────── [...] diff --git a/dev/breeze/doc/images/output_pr_auto-triage.txt b/dev/breeze/doc/images/output_pr_auto-triage.txt index 66cedf08b8d..dc56649fa28 100644 --- a/dev/breeze/doc/images/output_pr_auto-triage.txt +++ b/dev/breeze/doc/images/output_pr_auto-triage.txt @@ -1 +1 @@ -8107cecf84483d0900f542688f0d0247 +81f90099329d48933ff6824ffbf3e2b9 diff --git a/dev/breeze/src/airflow_breeze/commands/pr_commands.py b/dev/breeze/src/airflow_breeze/commands/pr_commands.py index e49706c5865..95fd841b8bc 100644 --- a/dev/breeze/src/airflow_breeze/commands/pr_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/pr_commands.py @@ -43,6 +43,7 @@ from airflow_breeze.utils.confirm import ( Answer, ContinueAction, TriageAction, + _has_tty, prompt_space_continue, prompt_triage_action, user_confirm, @@ -142,45 +143,6 @@ def _save_assessment_cache(github_repository: str, pr_number: int, head_sha: str cache_file.write_text(json.dumps({"head_sha": head_sha, "assessment": assessment}, indent=2)) -def _get_log_cache_dir(github_repository: str) -> Path: - """Return the directory for storing cached CI log snippets.""" - from airflow_breeze.utils.path_utils import BUILD_CACHE_PATH - - safe_name = github_repository.replace("/", "_") - cache_dir = Path(BUILD_CACHE_PATH) / "log_cache" / safe_name - cache_dir.mkdir(parents=True, exist_ok=True) - return cache_dir - - -def _get_cached_log_snippets( - github_repository: str, head_sha: str, failed_check_names: list[str] -) -> dict[str, dict] | None: - """Load cached CI log snippets if they exist and match the commit hash and check names.""" - cache_file = _get_log_cache_dir(github_repository) / f"sha_{head_sha}.json" - if not cache_file.exists(): - return None - try: - data = json.loads(cache_file.read_text()) - if data.get("head_sha") != head_sha: - return None - cached_snippets = data.get("snippets", {}) - # Only return cache if it covers all requested check names - if all(name in cached_snippets for name in failed_check_names): - return cached_snippets - return None - except (json.JSONDecodeError, KeyError, OSError): - return None - - -def _save_log_cache(github_repository: str, head_sha: str, snippets: dict[str, Any]) -> None: - """Save CI log snippets to cache keyed by commit SHA.""" - cache_file = _get_log_cache_dir(github_repository) / f"sha_{head_sha}.json" - serializable = { - name: {"snippet": info.snippet, "job_url": info.job_url} for name, info in snippets.items() - } - cache_file.write_text(json.dumps({"head_sha": head_sha, "snippets": serializable}, indent=2)) - - _STATUS_CACHE_TTL_SECONDS = 4 * 3600 # 4 hours @@ -247,62 +209,15 @@ def _cached_fetch_recent_pr_failures( def _cached_fetch_main_canary_builds( token: str, github_repository: str, *, branch: str = "main", count: int = 4 ) -> list[dict]: - """Return cached canary build data with failed jobs pre-fetched. - - The cache stores builds together with their failed jobs so that - ``_display_canary_builds_status`` can render instantly without - additional API calls. - """ + """Return cached canary build data, fetching fresh data when the cache expires.""" cache_key = f"canary_builds_{branch}" cached = _get_cached_status(github_repository, cache_key) if cached is not None: - # Verify cached builds have failed-jobs data; if any failed build is - # missing the key, invalidate the cache and re-fetch. - needs_refetch = any(b.get("conclusion") == "failure" and "_failed_jobs" not in b for b in cached) - if not needs_refetch: - get_console().print("[dim]Using cached canary build data (expires after 4 h).[/]") - return cached - get_console().print("[dim]Cached canary builds missing failed-job details, re-fetching...[/]") - - get_console().print("[info]Fetching canary builds from GitHub...[/]") - builds = _fetch_main_canary_builds(token, github_repository, branch=branch, count=count) - - # Pre-fetch failed jobs for all failed builds in parallel so they are - # available in the cache and don't block display later. - failed_builds = [b for b in builds if b.get("conclusion") == "failure" and b.get("id")] - if failed_builds: - from concurrent.futures import ThreadPoolExecutor, as_completed - - get_console().print( - f"[info]Fetching failed jobs for {len(failed_builds)} " - f"failed {'builds' if len(failed_builds) != 1 else 'build'}...[/]" - ) - with ThreadPoolExecutor(max_workers=min(len(failed_builds), 4)) as executor: - futures = { - executor.submit(_fetch_failed_jobs_for_run, token, github_repository, b["id"]): b["id"] - for b in failed_builds - } - jobs_by_run: dict[int, list[dict]] = {} - done_count = 0 - for future in as_completed(futures): - run_id = futures[future] - done_count += 1 - try: - jobs_by_run[run_id] = future.result() - except Exception: - jobs_by_run[run_id] = [] - get_console().print( - f" [dim]({done_count}/{len(failed_builds)}) fetched jobs for run {run_id}[/]" - ) - - # Embed the failed-jobs list directly in each build dict so the - # display function can use it without making extra API calls. - for build in builds: - if build.get("id") in jobs_by_run: - build["_failed_jobs"] = jobs_by_run[build["id"]] - - _save_status_cache(github_repository, cache_key, builds) - return builds + get_console().print("[dim]Using cached canary build data (expires after 4 h).[/]") + return cached + result = _fetch_main_canary_builds(token, github_repository, branch=branch, count=count) + _save_status_cache(github_repository, cache_key, result) + return result def _cached_assess_pr( @@ -715,19 +630,6 @@ class StaleReviewInfo: author_pinged_reviewer: bool # whether the author mentioned the reviewer after the review -@dataclass -class BatchPrefetchResult: - """Result from a background prefetch of the next page of PRs.""" - - all_prs: list[PRData] - has_next_page: bool - next_cursor: str | None - candidate_prs: list[PRData] - accepted_prs: list[PRData] - triaged_classification: dict[str, set[int]] - reclassified_count: int - - @dataclass class ReviewComment: """A single line-level review comment proposed by the LLM.""" @@ -2707,8 +2609,8 @@ def _launch_background_log_fetching( pr_word = "PRs" if len(prs_with_failures) != 1 else "PR" get_console().print( - f"[info]Downloading CI failure logs for {len(prs_with_failures)} {pr_word} " - f"in background (concurrency: {llm_concurrency}).[/]" + f"[info]Launched CI log fetching for {len(prs_with_failures)} {pr_word} " + f"with failures in background (concurrency: {llm_concurrency}).[/]" ) return log_futures @@ -3154,18 +3056,34 @@ def _prompt_and_execute_flagged_pr( if log_snippets: _display_log_snippets_panel(log_snippets, pr=pr) else: - # Logs are still being fetched — wait automatically - get_console().print(" [dim]Downloading CI failure logs...[/]") + # Logs are still being fetched — offer user to wait or cancel + from airflow_breeze.utils.confirm import _read_char + + get_console().print(" [dim]CI failure logs are still being fetched in the background...[/]") + get_console().print( + " Press any key to [bold]wait[/] or [bold]\\[c]ancel[/] log retrieval for this PR: ", + end="", + ) try: - log_snippets = log_future.result(timeout=120) - if log_snippets: - _display_log_snippets_panel(log_snippets, pr=pr) - except TimeoutError: - get_console().print(" [warning]CI log retrieval timed out.[/]") - except Exception: - get_console().print(" [warning]CI log retrieval failed.[/]") + ch = _read_char() + except (KeyboardInterrupt, EOFError): + ch = "c" + get_console().print(ch if len(ch) == 1 else "") + + if ch.lower() != "c": + get_console().print(" [dim]Waiting for CI logs...[/]") + try: + log_snippets = log_future.result(timeout=120) + if log_snippets: + _display_log_snippets_panel(log_snippets, pr=pr) + except TimeoutError: + get_console().print(" [warning]CI log retrieval timed out.[/]") + except Exception: + get_console().print(" [warning]CI log retrieval failed.[/]") + else: + get_console().print(" [dim]Skipping CI log display for this PR.[/]") else: - # No background future — fetch inline as fallback (progress shown by the function) + # No background future — fetch inline as fallback log_snippets = _fetch_failed_job_log_snippets( ctx.token, ctx.github_repository, pr.head_sha, pr.failed_checks ) @@ -3374,6 +3292,398 @@ def _display_pr_overview_table( console_print() +def _run_tui_triage( + ctx: TriageContext, + all_prs: list[PRData], + *, + pending_approval: list[PRData], + det_flagged_prs: list[tuple[PRData, PRAssessment]], + llm_candidates: list[PRData], + passing_prs: list[PRData], + accepted_prs: list[PRData], + already_triaged_nums: set[int], + mode_desc: str = "", + selection_criteria: str = "", +) -> None: + """Run the full-screen TUI triage interface. + + Displays all PRs in a full-screen view. When the user selects a PR, + drops into the existing review flow for that PR, then returns to the TUI. + """ + import webbrowser + + from airflow_breeze.utils.confirm import _read_char + from airflow_breeze.utils.tui_display import ( + PRCategory, + PRListEntry, + TriageTUI, + TUIAction, + ) + + # Build categorization sets + pending_nums = {pr.number for pr in pending_approval} + flagged_nums = {pr.number for pr, _ in det_flagged_prs} + # LLM flagged will be populated as they arrive + llm_flagged_nums: set[int] = set() + passing_nums = {pr.number for pr in passing_prs} + + # Map PR number to assessment for flagged PRs + assessment_map: dict[int, PRAssessment] = {pr.number: asmt for pr, asmt in det_flagged_prs} + + # Build entries + entries: list[PRListEntry] = [] + pr_map: dict[int, PRData] = {} + for pr in all_prs: + if pr.author_association in _COLLABORATOR_ASSOCIATIONS: + continue + pr_map[pr.number] = pr + + if pr.number in flagged_nums: + cat = PRCategory.FLAGGED + elif pr.number in pending_nums: + cat = PRCategory.WORKFLOW_APPROVAL + elif pr.number in passing_nums: + cat = PRCategory.PASSING + elif pr.number in already_triaged_nums: + cat = PRCategory.ALREADY_TRIAGED + else: + cat = PRCategory.SKIPPED + entries.append(PRListEntry(pr, cat)) + + # Sort: actionable first, workflow approval near the end + _ORDER = { + PRCategory.FLAGGED: 0, + PRCategory.LLM_FLAGGED: 1, + PRCategory.PASSING: 2, + PRCategory.STALE_REVIEW: 3, + PRCategory.WORKFLOW_APPROVAL: 4, + PRCategory.ALREADY_TRIAGED: 5, + PRCategory.SKIPPED: 6, + } + entries.sort( + key=lambda e: ( + _ORDER.get(e.category, 99), + e.pr.author_login.lower(), + e.pr.number, + ) + ) + + tui = TriageTUI( + title="Auto-Triage", + mode_desc=mode_desc, + github_repository=ctx.github_repository, + selection_criteria=selection_criteria, + ) + tui.set_entries(entries) + + # Show diff panel by default for the first PR + if entries: + first_pr = entries[0].pr + diff_text = _fetch_pr_diff(ctx.token, ctx.github_repository, first_pr.number) + if diff_text: + tui.set_diff(first_pr.number, diff_text) + else: + tui.set_diff(first_pr.number, f"Could not fetch diff. Review at: {first_pr.url}/files") + + while not ctx.stats.quit_early: + # Collect LLM results if available + ctx.collect_llm_progress() + + # Update LLM-flagged entries + for entry in entries: + pr = entry.pr + if pr.number in ctx.llm_assessments and entry.category == PRCategory.SKIPPED: + entry.category = PRCategory.LLM_FLAGGED + llm_flagged_nums.add(pr.number) + assessment_map[pr.number] = ctx.llm_assessments[pr.number] + elif pr.number in {p.number for p in ctx.llm_passing} and entry.category == PRCategory.SKIPPED: + entry.category = PRCategory.PASSING + passing_nums.add(pr.number) + + entry, action = tui.run_interactive() + + # Auto-fetch diff when cursor moves to a different PR + if tui.cursor_changed() and tui.needs_diff_fetch(): + current_entry = tui.get_selected_entry() + if current_entry: + diff_text = _fetch_pr_diff(ctx.token, ctx.github_repository, current_entry.pr.number) + if diff_text: + tui.set_diff(current_entry.pr.number, diff_text) + else: + tui.set_diff( + current_entry.pr.number, + f"Could not fetch diff. Review at: {current_entry.pr.url}/files", + ) + + if action == TUIAction.QUIT: + ctx.stats.quit_early = True + break + + if action in ( + TUIAction.UP, + TUIAction.DOWN, + TUIAction.PAGE_UP, + TUIAction.PAGE_DOWN, + TUIAction.NEXT_PAGE, + TUIAction.PREV_PAGE, + TUIAction.NEXT_SECTION, + TUIAction.TOGGLE_SELECT, + TUIAction.SHOW_DIFF, + ): + continue + + if action == TUIAction.APPROVE_SELECTED: + selected_entries = tui.get_selected_entries() + if not selected_entries: + continue + # Batch approve all selected workflow PRs + get_console().clear() + get_console().print() + get_console().rule("[bold green]Batch workflow approval[/]", style="green") + get_console().print( + f"\n[info]Approving workflows for {len(selected_entries)} " + f"{'PRs' if len(selected_entries) != 1 else 'PR'}:[/]\n" + ) + for sel_entry in selected_entries: + sel_pr = sel_entry.pr + get_console().print(f" {_pr_link(sel_pr)} {sel_pr.title} [dim]by {sel_pr.author_login}[/]") + get_console().print() + + if not ctx.dry_run: + for sel_entry in selected_entries: + sel_pr = sel_entry.pr + pending_runs = _find_pending_workflow_runs( + ctx.token, ctx.github_repository, sel_pr.head_sha + ) + if pending_runs: + approved = _approve_workflow_runs(ctx.token, ctx.github_repository, pending_runs) + if approved: + get_console().print( + f" [success]Approved {approved} workflow " + f"{'runs' if approved != 1 else 'run'} for " + f"PR {_pr_link(sel_pr)}.[/]" + ) + ctx.stats.total_workflows_approved += 1 + sel_entry.action_taken = "approved" + else: + get_console().print( + f" [error]Failed to approve workflow runs for PR {_pr_link(sel_pr)}.[/]" + ) + else: + # Try rerunning completed runs + if sel_pr.head_sha: + completed_runs = _find_workflow_runs_by_status( + ctx.token, ctx.github_repository, sel_pr.head_sha, "completed" + ) + rerun_count = 0 + if completed_runs: + for run in completed_runs: + if _rerun_workflow_run(ctx.token, ctx.github_repository, run): + rerun_count += 1 + if rerun_count: + get_console().print( + f" [success]Rerun {rerun_count} workflow " + f"{'runs' if rerun_count != 1 else 'run'} for " + f"PR {_pr_link(sel_pr)}.[/]" + ) + ctx.stats.total_rerun += 1 + sel_entry.action_taken = "rerun" + else: + get_console().print( + f" [warning]No workflow runs found for PR {_pr_link(sel_pr)}.[/]" + ) + sel_entry.selected = False + else: + get_console().print("[warning]Dry run — skipping batch approval.[/]") + for sel_entry in selected_entries: + sel_entry.selected = False + + get_console().print("\n[dim]Press any key to return to TUI...[/]") + _read_char() + continue + + if entry is None: + continue + + pr = entry.pr + + if action == TUIAction.OPEN: + webbrowser.open(pr.url) + continue + + if action == TUIAction.SKIP: + # Already marked as skipped by the TUI + ctx.stats.total_skipped_action += 1 + continue + + if action == TUIAction.SELECT: + # Drop into the detailed review for this PR + get_console().clear() + get_console().print() + + if entry.category == PRCategory.FLAGGED or entry.category == PRCategory.LLM_FLAGGED: + assessment = assessment_map.get(pr.number) + if assessment: + _prompt_and_execute_flagged_pr(ctx, pr, assessment) + # Determine what action was taken based on stats changes + entry.action_taken = _infer_last_action(ctx.stats) + else: + get_console().print(f"[warning]No assessment available for PR #{pr.number}[/]") + get_console().print("[dim]Press any key to return...[/]") + _read_char() + elif entry.category == PRCategory.WORKFLOW_APPROVAL: + _review_single_workflow_pr(ctx, pr) + entry.action_taken = _infer_last_action(ctx.stats) + elif entry.category == PRCategory.PASSING: + author_profile = _fetch_author_profile(ctx.token, pr.author_login, ctx.github_repository) + _display_pr_info_panels(pr, author_profile) + console_print("[success]This looks like a PR that is ready for review.[/]") + + if not ctx.dry_run: + act = prompt_triage_action( + f"Action for PR {_pr_link(pr)}?", + default=TriageAction.READY, + forced_answer=ctx.answer_triage, + exclude={TriageAction.DRAFT} if pr.is_draft else None, + pr_url=pr.url, + token=ctx.token, + github_repository=ctx.github_repository, + pr_number=pr.number, + ) + if act == TriageAction.QUIT: + ctx.stats.quit_early = True + elif act == TriageAction.READY: + _execute_triage_action(ctx, pr, act, draft_comment="", close_comment="") + entry.action_taken = "ready" + elif act == TriageAction.SKIP: + entry.action_taken = "skipped" + else: + _execute_triage_action(ctx, pr, act, draft_comment="", close_comment="") + entry.action_taken = act.value + elif entry.category == PRCategory.ALREADY_TRIAGED: + get_console().print( + f"[dim]PR #{pr.number} was already triaged. Press any key to return...[/]" + ) + _read_char() + else: + get_console().print( + f"[dim]PR #{pr.number} — no action available. Press any key to return...[/]" + ) + _read_char() + + if ctx.stats.quit_early: + break + + # Move cursor to next entry after taking action + if entry.action_taken: + tui.move_cursor(1) + + # Final clear + get_console().clear() + + +def _infer_last_action(stats: TriageStats) -> str: + """Try to infer the last action from stats changes. + + This is a heuristic — we check which stat counter was last incremented. + """ + # We can't perfectly determine this, but we can use the stats object fields + # Just return a generic marker + if stats.total_converted > getattr(stats, "_prev_converted", 0): + stats._prev_converted = stats.total_converted # type: ignore[attr-defined] + return "drafted" + if stats.total_commented > getattr(stats, "_prev_commented", 0): + stats._prev_commented = stats.total_commented # type: ignore[attr-defined] + return "commented" + if stats.total_closed > getattr(stats, "_prev_closed", 0): + stats._prev_closed = stats.total_closed # type: ignore[attr-defined] + return "closed" + if stats.total_rebased > getattr(stats, "_prev_rebased", 0): + stats._prev_rebased = stats.total_rebased # type: ignore[attr-defined] + return "rebased" + if stats.total_rerun > getattr(stats, "_prev_rerun", 0): + stats._prev_rerun = stats.total_rerun # type: ignore[attr-defined] + return "rerun" + if stats.total_ready > getattr(stats, "_prev_ready", 0): + stats._prev_ready = stats.total_ready # type: ignore[attr-defined] + return "ready" + if stats.total_skipped_action > getattr(stats, "_prev_skipped", 0): + stats._prev_skipped = stats.total_skipped_action # type: ignore[attr-defined] + return "skipped" + return "" + + +def _review_single_workflow_pr(ctx: TriageContext, pr: PRData) -> None: + """Review a single PR that needs workflow approval (used by TUI mode).""" + author_profile = _fetch_author_profile(ctx.token, pr.author_login, ctx.github_repository) + pending_runs = _find_pending_workflow_runs(ctx.token, ctx.github_repository, pr.head_sha) + + check_counts: dict[str, int] = {} + if pr.head_sha: + check_counts = _fetch_check_status_counts(ctx.token, ctx.github_repository, pr.head_sha) + + _display_workflow_approval_panel(pr, author_profile, pending_runs, check_counts) + + if ctx.dry_run: + console_print("[warning]Dry run — skipping workflow approval.[/]") + return + + if not pending_runs: + console_print( + f" [info]No pending workflow runs for PR {_pr_link(pr)}. " + f"Attempting to rerun completed workflows...[/]" + ) + default_action = TriageAction.RERUN + else: + default_action = TriageAction.RERUN + + action = prompt_triage_action( + f"Action for PR {_pr_link(pr)}?", + default=default_action, + forced_answer=ctx.answer_triage, + exclude={TriageAction.DRAFT} if pr.is_draft else None, + pr_url=pr.url, + token=ctx.token, + github_repository=ctx.github_repository, + pr_number=pr.number, + ) + + if action == TriageAction.QUIT: + ctx.stats.quit_early = True + return + if action == TriageAction.SKIP: + console_print(f" [info]Skipping PR {_pr_link(pr)} — no action taken.[/]") + return + + if action == TriageAction.RERUN: + if pending_runs: + approved = _approve_workflow_runs(ctx.token, ctx.github_repository, pending_runs) + if approved: + console_print( + f" [success]Approved {approved} workflow " + f"{'runs' if approved != 1 else 'run'} for PR {_pr_link(pr)}.[/]" + ) + ctx.stats.total_workflows_approved += 1 + else: + # Try rerunning completed runs + if pr.head_sha: + completed_runs = _find_workflow_runs_by_status( + ctx.token, ctx.github_repository, pr.head_sha, "completed" + ) + rerun_count = 0 + if completed_runs: + for run in completed_runs: + if _rerun_workflow_run(ctx.token, ctx.github_repository, run): + console_print(f" [success]Rerun triggered for: {run.get('name', run['id'])}[/]") + rerun_count += 1 + if rerun_count: + ctx.stats.total_rerun += 1 + else: + console_print(f" [warning]No workflow runs found to rerun for PR {_pr_link(pr)}.[/]") + else: + _execute_triage_action(ctx, pr, action, draft_comment="", close_comment="") + + def _filter_candidate_prs( all_prs: list[PRData], *, @@ -3479,178 +3789,42 @@ def _enrich_candidate_details( if not candidate_prs: return - n = len(candidate_prs) - pr_word = "PRs" if n != 1 else "PR" - total_steps = 2 + (1 if run_api else 0) - step = 0 - - step += 1 - t_step = time.monotonic() - console_print(f" [info][{step}/{total_steps}] Fetching check details for {n} candidate {pr_word}...[/]") + console_print( + f"[info]Fetching check details for {len(candidate_prs)} " + f"{'PRs' if len(candidate_prs) != 1 else 'PR'}...[/]" + ) _fetch_check_details_batch(token, github_repository, candidate_prs) for pr in candidate_prs: if pr.checks_state == "FAILURE" and not pr.failed_checks and pr.head_sha: console_print( - f" [dim]Fetching full check details for PR {_pr_link(pr)} " + f" [dim]Fetching full check details for PR {_pr_link(pr)} " f"(failures beyond first 100 checks)...[/]" ) pr.failed_checks = _fetch_failed_checks(token, github_repository, pr.head_sha) - console_print(f" [dim]done ({_fmt_duration(time.monotonic() - t_step)})[/]") - step += 1 - t_step = time.monotonic() unknown_count = sum(1 for pr in candidate_prs if pr.mergeable == "UNKNOWN") if unknown_count: console_print( - f" [info][{step}/{total_steps}] Resolving merge conflict status " - f"for {unknown_count} {pr_word}...[/]" + f"[info]Resolving merge conflict status for {unknown_count} " + f"{'PRs' if unknown_count != 1 else 'PR'} with unknown status...[/]" ) resolved = _resolve_unknown_mergeable(token, github_repository, candidate_prs) remaining = unknown_count - resolved if remaining: console_print( - f" [dim]{resolved} resolved, {remaining} still unknown " - f"({_fmt_duration(time.monotonic() - t_step)})[/]" + f" [dim]{resolved} resolved, {remaining} still unknown " + f"(GitHub hasn't computed mergeability yet).[/]" ) else: - console_print(f" [dim]All {resolved} resolved ({_fmt_duration(time.monotonic() - t_step)})[/]") - else: - console_print(f" [info][{step}/{total_steps}] Merge conflict status: all known (skip)[/]") + console_print(f" [dim]All {resolved} resolved.[/]") if run_api: - step += 1 - t_step = time.monotonic() console_print( - f" [info][{step}/{total_steps}] Fetching review thread details for {n} candidate {pr_word}...[/]" + f"[info]Fetching review thread details for {len(candidate_prs)} " + f"{'PRs' if len(candidate_prs) != 1 else 'PR'}...[/]" ) _fetch_unresolved_comments_batch(token, github_repository, candidate_prs) - console_print(f" [dim]done ({_fmt_duration(time.monotonic() - t_step)})[/]") - - -def _prefetch_next_batch( - *, - token: str, - github_repository: str, - exact_labels: tuple[str, ...], - exact_exclude_labels: tuple[str, ...], - filter_user: str | None, - sort: str, - batch_size: int, - created_after: str | None, - created_before: str | None, - updated_after: str | None, - updated_before: str | None, - review_requested_user: str | None, - next_cursor: str | None, - wildcard_labels: list[str], - wildcard_exclude_labels: list[str], - include_collaborators: bool, - include_drafts: bool, - checks_state: str, - min_commits_behind: int, - max_num: int, - viewer_login: str, -) -> BatchPrefetchResult | None: - """Prefetch and prepare the next page of PRs in a background thread. - - Performs GraphQL fetch, wildcard filtering, commits-behind resolution, - mergeable status resolution, NOT_RUN reclassification, candidate filtering, - and triage classification — everything up to the point where interactive - review begins. - - Returns None if no PRs are found. - """ - from fnmatch import fnmatch - - all_prs, has_next_page, new_cursor = _fetch_prs_graphql( - token, - github_repository, - labels=exact_labels, - exclude_labels=exact_exclude_labels, - filter_user=filter_user, - sort=sort, - batch_size=batch_size, - created_after=created_after, - created_before=created_before, - updated_after=updated_after, - updated_before=updated_before, - review_requested=review_requested_user, - after_cursor=next_cursor, - ) - if not all_prs: - return None - - # Apply wildcard label filters client-side - if wildcard_labels: - all_prs = [ - pr for pr in all_prs if any(fnmatch(lbl, pat) for pat in wildcard_labels for lbl in pr.labels) - ] - if wildcard_exclude_labels: - all_prs = [ - pr - for pr in all_prs - if not any(fnmatch(lbl, pat) for pat in wildcard_exclude_labels for lbl in pr.labels) - ] - - if not all_prs: - return None - - # Enrich: commits behind - behind_map = _fetch_commits_behind_batch(token, github_repository, all_prs) - for pr in all_prs: - pr.commits_behind = behind_map.get(pr.number, 0) - - # Resolve unknown mergeable status - unknown_count = sum(1 for pr in all_prs if pr.mergeable == "UNKNOWN") - if unknown_count: - _resolve_unknown_mergeable(token, github_repository, all_prs) - - # Detect NOT_RUN reclassification - non_collab_success = [ - pr - for pr in all_prs - if pr.checks_state == "SUCCESS" - and pr.author_association not in _COLLABORATOR_ASSOCIATIONS - and not _is_bot_account(pr.author_login) - ] - reclassified_count = 0 - if non_collab_success: - _fetch_check_details_batch(token, github_repository, non_collab_success) - reclassified_count = sum(1 for pr in non_collab_success if pr.checks_state == "NOT_RUN") - - # Filter candidates - candidate_prs, accepted_prs, _, _, _ = _filter_candidate_prs( - all_prs, - include_collaborators=include_collaborators, - include_drafts=include_drafts, - checks_state=checks_state, - min_commits_behind=min_commits_behind, - max_num=max_num, - ) - - # Classify already triaged - triaged_classification = _classify_already_triaged_prs( - token, github_repository, candidate_prs, viewer_login - ) - - return BatchPrefetchResult( - all_prs=all_prs, - has_next_page=has_next_page, - next_cursor=new_cursor, - candidate_prs=candidate_prs, - accepted_prs=accepted_prs, - triaged_classification=triaged_classification, - reclassified_count=reclassified_count, - ) - - -def _start_next_batch_prefetch( - executor: ThreadPoolExecutor, - **kwargs, -) -> Future[BatchPrefetchResult | None]: - """Submit a background prefetch of the next batch to the given executor.""" - return executor.submit(_prefetch_next_batch, **kwargs) def _review_workflow_approval_prs(ctx: TriageContext, pending_approval: list[PRData]) -> None: @@ -5518,22 +5692,7 @@ def _fetch_failed_job_log_snippets( Returns a dict mapping failed check name -> LogSnippetInfo (snippet + job URL). Only fetches logs for checks in ``failed_check_names`` to limit API calls. - Results are cached by commit SHA so repeated runs skip the download. """ - # Check cache first - cached = _get_cached_log_snippets(github_repository, head_sha, failed_check_names) - if cached is not None: - get_console().print(f"[dim]Using cached CI logs for {head_sha[:8]} ({len(cached)} checks).[/]") - return { - name: LogSnippetInfo(snippet=info["snippet"], job_url=info["job_url"]) - for name, info in cached.items() - } - - check_word = "check" if len(failed_check_names) == 1 else "checks" - get_console().print( - f"[info]Downloading CI failure logs for {head_sha[:8]} " - f"({len(failed_check_names)} failed {check_word})...[/]" - ) import io import zipfile @@ -5636,10 +5795,6 @@ def _fetch_failed_job_log_snippets( if all(name in snippets for name in failed_check_names): break - # Cache results for future runs with the same commit - if snippets: - _save_log_cache(github_repository, head_sha, snippets) - return snippets @@ -5867,21 +6022,6 @@ def _fetch_main_canary_builds( return [r for r in runs if r.get("name") == "Tests"][:count] -def _platform_from_name(name: str) -> str: - """Determine platform (ARM/AMD) from the job name. - - Airflow CI job names typically contain 'ARM' or 'AMD' as a segment, - e.g. ``Tests / AMD Python 3.9 / ...`` or ``Tests / ARM Python 3.9 / ...``. - Falls back to ``AMD`` when the name does not contain a clear indicator. - """ - upper = name.upper() - if "ARM" in upper or "AARCH64" in upper: - return "ARM" - if "AMD" in upper or "X86" in upper or "X64" in upper: - return "AMD" - return "" - - def _platform_from_labels(labels: list[str]) -> str: """Determine platform (ARM/AMD) from GitHub Actions job runner labels.""" for label in labels: @@ -5920,25 +6060,18 @@ def _fetch_failed_jobs_for_run(token: str, github_repository: str, run_id: int) failed = [] for job in all_jobs: if job.get("conclusion") == "failure": - job_name = job.get("name", "unknown") - # Prefer platform detection from job name; fall back to runner labels - platform = _platform_from_name(job_name) or _platform_from_labels(job.get("labels") or []) failed.append( { - "name": job_name, - "platform": platform, + "name": job.get("name", "unknown"), + "platform": _platform_from_labels(job.get("labels") or []), "html_url": job.get("html_url", ""), } ) return failed -def _display_canary_builds_status(builds: list[dict]) -> None: - """Display a Rich table showing the status of recent scheduled Tests builds. - - Failed jobs are expected to be pre-fetched and embedded in each build dict - under the ``_failed_jobs`` key by ``_cached_fetch_main_canary_builds``. - """ +def _display_canary_builds_status(builds: list[dict], token: str, github_repository: str) -> None: + """Display a Rich table showing the status of recent scheduled Tests builds.""" from rich.table import Table console = get_console() @@ -5950,7 +6083,6 @@ def _display_canary_builds_status(builds: list[dict]) -> None: table = Table(title="Main Branch Tests Builds (scheduled)", expand=False) table.add_column("Status", justify="center") table.add_column("Started", justify="right") - table.add_column("Arch", justify="center") table.add_column("Failed Jobs", style="red") table.add_column("Link", style="dim") @@ -5984,22 +6116,17 @@ def _display_canary_builds_status(builds: list[dict]) -> None: # Clickable link to the workflow run page link = f"[link={html_url}]checks[/link]" if html_url else str(run_id) - # Use pre-fetched failed jobs (embedded by _cached_fetch_main_canary_builds) + # Fetch failed jobs for failed builds failed_jobs_display = "" - failed_jobs = build.get("_failed_jobs", []) - # Determine unique architectures from failed jobs - archs: set[str] = set() - if failed_jobs: - parts = [] - for fj in failed_jobs: - platform = fj.get("platform", "") - if platform: - archs.add(platform) - parts.append(fj["name"]) - failed_jobs_display = "\n".join(parts) - arch_display = ", ".join(sorted(archs)) if archs else "" - - table.add_row(status_display, age, arch_display, failed_jobs_display, link) + if conclusion == "failure" and run_id: + failed_jobs = _fetch_failed_jobs_for_run(token, github_repository, run_id) + if failed_jobs: + parts = [] + for fj in failed_jobs: + parts.append(f"{fj['name']} ({fj['platform']})") + failed_jobs_display = "\n".join(parts) + + table.add_row(status_display, age, failed_jobs_display, link) console.print(table) console.print() @@ -6317,7 +6444,7 @@ def _display_recent_pr_failure_panel( @click.option( "--llm-concurrency", type=int, - default=8, + default=4, show_default=True, help="Number of concurrent LLM assessment calls.", ) @@ -6397,7 +6524,6 @@ def auto_triage( ("review", _get_review_cache_dir), ("triage", _get_triage_cache_dir), ("status", _get_status_cache_dir), - ("log", _get_log_cache_dir), ]: cache_dir = get_dir(github_repository) if cache_dir.exists(): @@ -6441,17 +6567,12 @@ def auto_triage( # Refresh collaborators cache in the background on every run _refresh_collaborators_cache_in_background(token, github_repository) - # Preload main branch CI failure information and canary builds in parallel (both cached for 4 hours) - with ThreadPoolExecutor(max_workers=2) as startup_executor: - main_failures_future = startup_executor.submit( - _cached_fetch_recent_pr_failures, token, github_repository - ) - canary_builds_future = startup_executor.submit( - _cached_fetch_main_canary_builds, token, github_repository - ) - main_failures = main_failures_future.result() - canary_builds = canary_builds_future.result() - _display_canary_builds_status(canary_builds) + # Preload main branch CI failure information (cached for 4 hours) + main_failures = _cached_fetch_recent_pr_failures(token, github_repository) + + # Show status of recent scheduled (canary) builds on main branch (cached for 4 hours) + canary_builds = _cached_fetch_main_canary_builds(token, github_repository) + _display_canary_builds_status(canary_builds, token, github_repository) # Resolve review-requested filter: --reviews-for-me uses authenticated user, --reviews-for uses specified users review_requested_user: str | None = None @@ -6481,25 +6602,16 @@ def auto_triage( t_total_start = time.monotonic() - # Phase 1: Fetch and prepare PRs - console.print("\n[bold]Phase 1: Fetching and preparing PRs[/bold]") + # Phase 1: Lightweight fetch of PRs via GraphQL (no check contexts — fast) t_phase1_start = time.monotonic() has_next_page = False next_cursor: str | None = None - step_num = 0 - - # Step 1: Fetch PRs via GraphQL - step_num += 1 - t_step = time.monotonic() if pr_number: - console_print(f" [info][{step_num}/7] Fetching PR #{pr_number} via GraphQL...[/]") + console_print(f"[info]Fetching PR #{pr_number} via GraphQL...[/]") all_prs = [_fetch_single_pr_graphql(token, github_repository, pr_number)] elif len(review_requested_users) > 1: # Multiple reviewers: fetch PRs for each reviewer and merge (deduplicate) - console_print( - f" [info][{step_num}/7] Fetching PRs via GraphQL " - f"for {len(review_requested_users)} reviewers...[/]" - ) + console_print("[info]Fetching PRs via GraphQL for multiple reviewers...[/]") seen_numbers: set[int] = set() all_prs = [] for reviewer in review_requested_users: @@ -6524,7 +6636,7 @@ def auto_triage( # Disable pagination for multi-reviewer queries has_next_page = False else: - console_print(f" [info][{step_num}/7] Fetching PRs via GraphQL...[/]") + console_print("[info]Fetching PRs via GraphQL...[/]") all_prs, has_next_page, next_cursor = _fetch_prs_graphql( token, github_repository, @@ -6539,10 +6651,6 @@ def auto_triage( updated_before=updated_before, review_requested=review_requested_user, ) - console_print( - f" [dim]{len(all_prs)} PRs fetched" - f"{' (more pages available)' if has_next_page else ''} ({_fmt_duration(time.monotonic() - t_step)})[/]" - ) # Apply wildcard label filters client-side if wildcard_labels: @@ -6586,48 +6694,36 @@ def auto_triage( reviewed_by_prs.add(pr.number) if reviewed_by_prs: console.print( - f" [dim]Also found {len(reviewed_by_prs)} " + f"[info]Also found {len(reviewed_by_prs)} " f"{'PRs' if len(reviewed_by_prs) != 1 else 'PR'} " f"previously reviewed by {', '.join(review_requested_users)}.[/]" ) - # Step 2: Resolve how far behind base branch each PR is - step_num += 1 - t_step = time.monotonic() - console_print(f" [info][{step_num}/7] Checking how far behind base branch each PR is...[/]") + # Resolve how far behind base branch each PR is + console_print("[info]Checking how far behind base branch each PR is...[/]") behind_map = _fetch_commits_behind_batch(token, github_repository, all_prs) for pr in all_prs: pr.commits_behind = behind_map.get(pr.number, 0) - max_behind = max(behind_map.values()) if behind_map else 0 - console_print( - f" [dim]done (max {max_behind} commits behind) ({_fmt_duration(time.monotonic() - t_step)})[/]" - ) - # Step 3: Resolve UNKNOWN mergeable status before displaying the overview table - step_num += 1 - t_step = time.monotonic() + # Resolve UNKNOWN mergeable status before displaying the overview table unknown_count = sum(1 for pr in all_prs if pr.mergeable == "UNKNOWN") if unknown_count: console_print( - f" [info][{step_num}/7] Resolving merge conflict status " - f"for {unknown_count} {'PRs' if unknown_count != 1 else 'PR'}...[/]" + f"[info]Resolving merge conflict status for {unknown_count} " + f"{'PRs' if unknown_count != 1 else 'PR'} with unknown status...[/]" ) resolved = _resolve_unknown_mergeable(token, github_repository, all_prs) remaining = unknown_count - resolved if remaining: console_print( - f" [dim]{resolved} resolved, {remaining} still unknown " - f"({_fmt_duration(time.monotonic() - t_step)})[/]" + f" [dim]{resolved} resolved, {remaining} still unknown " + f"(GitHub hasn't computed mergeability yet).[/]" ) else: - console_print(f" [dim]All {resolved} resolved ({_fmt_duration(time.monotonic() - t_step)})[/]") - else: - console_print(f" [info][{step_num}/7] Merge conflict status: all known (skip)[/]") + console_print(f" [dim]All {resolved} resolved.[/]") - # Step 4: Detect PRs whose rollup state is SUCCESS but only have bot/labeler checks (no real CI). + # Detect PRs whose rollup state is SUCCESS but only have bot/labeler checks (no real CI). # These need to be reclassified as NOT_RUN so they get routed to workflow approval. - step_num += 1 - t_step = time.monotonic() non_collab_success = [ pr for pr in all_prs @@ -6637,23 +6733,19 @@ def auto_triage( ] if non_collab_success: console_print( - f" [info][{step_num}/7] Verifying CI status for {len(non_collab_success)} " - f"{'PRs' if len(non_collab_success) != 1 else 'PR'} showing SUCCESS...[/]" + f"[info]Verifying CI status for {len(non_collab_success)} " + f"{'PRs' if len(non_collab_success) != 1 else 'PR'} " + f"showing SUCCESS (checking for real test checks)...[/]" ) _fetch_check_details_batch(token, github_repository, non_collab_success) reclassified = sum(1 for pr in non_collab_success if pr.checks_state == "NOT_RUN") if reclassified: console_print( - f" [warning]{reclassified} reclassified to NOT_RUN " - f"(only bot/labeler checks) ({_fmt_duration(time.monotonic() - t_step)})[/]" + f" [warning]{reclassified} {'PRs' if reclassified != 1 else 'PR'} " + f"reclassified to NOT_RUN (only bot/labeler checks, no real CI).[/]" ) - else: - console_print(f" [dim]All verified ({_fmt_duration(time.monotonic() - t_step)})[/]") - else: - console_print(f" [info][{step_num}/7] CI status verification: none needed (skip)[/]") - # Step 5: Filter candidates - step_num += 1 + # Filter candidates first candidate_prs, accepted_prs, total_skipped_collaborator, total_skipped_bot, total_skipped_accepted = ( _filter_candidate_prs( all_prs, @@ -6665,20 +6757,9 @@ def auto_triage( also_accepted=reviewed_by_prs if review_mode else None, ) ) - console_print( - f" [info][{step_num}/7] Filtering candidates: " - f"{len(candidate_prs)} candidates, {len(accepted_prs)} accepted " - f"(skipped: {total_skipped_collaborator} collaborators, " - f"{total_skipped_bot} bots, {total_skipped_accepted} already accepted)[/]" - ) - # Step 6: Exclude PRs that already have a triage comment posted after the last commit - step_num += 1 - t_step = time.monotonic() - console_print( - f" [info][{step_num}/7] Checking for already-triaged PRs " - f"(no new commits since last triage comment)...[/]" - ) + # Exclude PRs that already have a triage comment posted after the last commit + console_print("[info]Checking for PRs already triaged (no new commits since last triage comment)...[/]") triaged_classification = _classify_already_triaged_prs( token, github_repository, candidate_prs, viewer_login ) @@ -6690,17 +6771,15 @@ def auto_triage( already_triaged = [pr for pr in candidate_prs if pr.number in already_triaged_nums] candidate_prs = [pr for pr in candidate_prs if pr.number not in already_triaged_nums] console_print( - f" [dim]Skipped {len(already_triaged)} already-triaged " + f"[info]Skipped {len(already_triaged)} already-triaged " + f"{'PRs' if len(already_triaged) != 1 else 'PR'} " f"({triaged_waiting_count} commented, " - f"{triaged_responded_count} author responded) " - f"({_fmt_duration(time.monotonic() - t_step)})[/]" + f"{triaged_responded_count} author responded).[/]" ) else: - console_print(f" [dim]None found ({_fmt_duration(time.monotonic() - t_step)})[/]") + console_print(" [dim]None found.[/]") - # Step 7: Display overview table - step_num += 1 - console_print(f" [info][{step_num}/7] Displaying overview table[/]") + # Display overview table (after triaged detection so we can mark actionable PRs) _display_pr_overview_table( all_prs, triaged_waiting_nums=triaged_classification["waiting"], @@ -6708,10 +6787,6 @@ def auto_triage( ) t_phase1_end = time.monotonic() - console.print( - f"[bold]Phase 1 complete:[/bold] {len(candidate_prs)} PRs to triage, " - f"{len(accepted_prs)} accepted ({_fmt_duration(t_phase1_end - t_phase1_start)})\n" - ) # --- Review mode: early exit into review flow for accepted PRs --- if review_mode: @@ -7072,56 +7147,81 @@ def auto_triage( log_futures=log_futures, ) - # Start prefetching next page in background while user reviews current batch - prefetch_executor = ThreadPoolExecutor(max_workers=1) if has_next_page and not pr_number else None - next_batch_future: Future[BatchPrefetchResult | None] | None = None - prefetch_kwargs = dict( - token=token, - github_repository=github_repository, - exact_labels=exact_labels, - exact_exclude_labels=exact_exclude_labels, - filter_user=filter_user, - sort=sort, - batch_size=batch_size, - created_after=created_after, - created_before=created_before, - updated_after=updated_after, - updated_before=updated_before, - review_requested_user=review_requested_user, - wildcard_labels=wildcard_labels, - wildcard_exclude_labels=wildcard_exclude_labels, - include_collaborators=include_collaborators, - include_drafts=include_drafts, - checks_state=checks_state, - min_commits_behind=min_commits_behind, - max_num=max_num, - viewer_login=viewer_login, - ) - if prefetch_executor and has_next_page: - next_batch_future = _start_next_batch_prefetch( - prefetch_executor, next_cursor=next_cursor, **prefetch_kwargs - ) + det_flagged_prs = [(pr, assessments[pr.number]) for pr in candidate_prs if pr.number in assessments] + det_flagged_prs.sort(key=lambda pair: (pair[0].author_login.lower(), pair[0].number)) + + # Use full-screen TUI when a TTY is available, otherwise fall back to sequential mode + use_tui = _has_tty() and not dry_run and not answer_triage try: - # Phase 4b: Present NOT_RUN PRs for workflow approval (LLM runs in background) - _review_workflow_approval_prs(ctx, pending_approval) + if use_tui: + # Full-screen TUI mode: show all PRs in an interactive full-screen view + # Build selection criteria description for TUI header + criteria_parts: list[str] = [] + if pr_number: + criteria_parts.append(f"PR #{pr_number}") + if labels: + criteria_parts.append(f"labels={','.join(labels)}") + if exclude_labels: + criteria_parts.append(f"exclude={','.join(exclude_labels)}") + if filter_user: + criteria_parts.append(f"user={filter_user}") + if review_requested_users: + criteria_parts.append(f"reviewer={','.join(review_requested_users)}") + if created_after: + criteria_parts.append(f"created>={created_after}") + if created_before: + criteria_parts.append(f"created<={created_before}") + if updated_after: + criteria_parts.append(f"updated>={updated_after}") + if updated_before: + criteria_parts.append(f"updated<={updated_before}") + if checks_state != "all": + criteria_parts.append(f"checks={checks_state}") + if min_commits_behind > 0: + criteria_parts.append(f"behind>={min_commits_behind}") + if include_drafts: + criteria_parts.append("include_drafts") + if include_collaborators: + criteria_parts.append("include_collaborators") + if sort != "created": + criteria_parts.append(f"sort={sort}") + criteria_parts.append(f"batch={batch_size}") + if triage_mode != "triage": + criteria_parts.append(f"mode={triage_mode}") + selection_criteria = " | ".join(criteria_parts) if criteria_parts else "defaults" + + _run_tui_triage( + ctx, + all_prs, + pending_approval=pending_approval, + det_flagged_prs=det_flagged_prs, + llm_candidates=llm_candidates, + passing_prs=passing_prs, + accepted_prs=accepted_prs, + already_triaged_nums=already_triaged_nums, + mode_desc=mode_desc.get(check_mode, check_mode), + selection_criteria=selection_criteria, + ) + else: + # Sequential mode (CI / forced answer / no TTY) + # Phase 4b: Present NOT_RUN PRs for workflow approval (LLM runs in background) + _review_workflow_approval_prs(ctx, pending_approval) - # Phase 5a: Present deterministically flagged PRs - det_flagged_prs = [(pr, assessments[pr.number]) for pr in candidate_prs if pr.number in assessments] - det_flagged_prs.sort(key=lambda pair: (pair[0].author_login.lower(), pair[0].number)) - _review_deterministic_flagged_prs(ctx, det_flagged_prs) + # Phase 5a: Present deterministically flagged PRs + _review_deterministic_flagged_prs(ctx, det_flagged_prs) - # Phase 5b: Present LLM-flagged PRs as they become ready (streaming) - _review_llm_flagged_prs(ctx, llm_candidates) + # Phase 5b: Present LLM-flagged PRs as they become ready (streaming) + _review_llm_flagged_prs(ctx, llm_candidates) - # Add LLM passing PRs to the passing list - passing_prs.extend(llm_passing) + # Add LLM passing PRs to the passing list + passing_prs.extend(llm_passing) - # Phase 5c: Present passing PRs for optional ready-for-review marking - _review_passing_prs(ctx, passing_prs) + # Phase 5c: Present passing PRs for optional ready-for-review marking + _review_passing_prs(ctx, passing_prs) - # Phase 5d: Check accepted PRs for stale CHANGES_REQUESTED reviews - _review_stale_review_requests(ctx, accepted_prs) + # Phase 5d: Check accepted PRs for stale CHANGES_REQUESTED reviews + _review_stale_review_requests(ctx, accepted_prs) except KeyboardInterrupt: console_print("\n[warning]Interrupted — shutting down.[/]") stats.quit_early = True @@ -7130,62 +7230,97 @@ def auto_triage( if llm_executor is not None: llm_executor.shutdown(wait=False, cancel_futures=True) - # Process subsequent batches using prefetched data - while not stats.quit_early and not pr_number and next_batch_future is not None: + # Fetch and process next batch if available and user hasn't quit + while has_next_page and not stats.quit_early and not pr_number: batch_num = getattr(stats, "_batch_count", 1) + 1 stats._batch_count = batch_num # type: ignore[attr-defined] - - # Wait for the prefetched result (should already be done or nearly done) - t_wait_start = time.monotonic() - was_ready = next_batch_future.done() - prefetch_result = next_batch_future.result() - t_wait = time.monotonic() - t_wait_start - if was_ready: - console_print(f"\n[info]Batch {batch_num}: next page already prefetched.[/]") - else: - console_print( - f"\n[info]Batch {batch_num}: waited {_fmt_duration(t_wait)} " - f"for background prefetch to complete.[/]" - ) - next_batch_future = None - - if prefetch_result is None: + console_print(f"\n[info]Batch complete. Fetching next batch (page {batch_num})...[/]\n") + all_prs, has_next_page, next_cursor = _fetch_prs_graphql( + token, + github_repository, + labels=exact_labels, + exclude_labels=exact_exclude_labels, + filter_user=filter_user, + sort=sort, + batch_size=batch_size, + created_after=created_after, + created_before=created_before, + updated_after=updated_after, + updated_before=updated_before, + review_requested=review_requested_user, + after_cursor=next_cursor, + ) + if not all_prs: console_print("[info]No more PRs to process.[/]") break - all_prs = prefetch_result.all_prs - has_next_page = prefetch_result.has_next_page - next_cursor = prefetch_result.next_cursor - candidate_prs = prefetch_result.candidate_prs - batch_accepted = prefetch_result.accepted_prs - accepted_prs.extend(batch_accepted) + # Apply wildcard label filters client-side + if wildcard_labels: + all_prs = [ + pr for pr in all_prs if any(fnmatch(lbl, pat) for pat in wildcard_labels for lbl in pr.labels) + ] + if wildcard_exclude_labels: + all_prs = [ + pr + for pr in all_prs + if not any(fnmatch(lbl, pat) for pat in wildcard_exclude_labels for lbl in pr.labels) + ] - console_print( - f"[info]Batch {batch_num}: {len(all_prs)} PRs fetched, " - f"{len(candidate_prs)} candidates" - f"{' (more pages available)' if has_next_page else ''}" - f" (wait: {_fmt_duration(t_wait)})[/]" - ) + # Enrich: commits behind, mergeable status + behind_map = _fetch_commits_behind_batch(token, github_repository, all_prs) + for pr in all_prs: + pr.commits_behind = behind_map.get(pr.number, 0) + unknown_count = sum(1 for pr in all_prs if pr.mergeable == "UNKNOWN") + if unknown_count: + _resolve_unknown_mergeable(token, github_repository, all_prs) - if prefetch_result.reclassified_count: + # Detect PRs whose rollup state is SUCCESS but only have bot/labeler checks + batch_non_collab_success = [ + pr + for pr in all_prs + if pr.checks_state == "SUCCESS" + and pr.author_association not in _COLLABORATOR_ASSOCIATIONS + and not _is_bot_account(pr.author_login) + ] + if batch_non_collab_success: console_print( - f" [warning]{prefetch_result.reclassified_count} " - f"{'PRs' if prefetch_result.reclassified_count != 1 else 'PR'} " - f"reclassified to NOT_RUN (only bot/labeler checks).[/]" + f"[info]Verifying CI status for {len(batch_non_collab_success)} " + f"{'PRs' if len(batch_non_collab_success) != 1 else 'PR'} " + f"showing SUCCESS...[/]" ) + _fetch_check_details_batch(token, github_repository, batch_non_collab_success) + reclassified = sum(1 for pr in batch_non_collab_success if pr.checks_state == "NOT_RUN") + if reclassified: + console_print( + f" [warning]{reclassified} {'PRs' if reclassified != 1 else 'PR'} " + f"reclassified to NOT_RUN (only bot/labeler checks).[/]" + ) + + ( + candidate_prs, + batch_accepted, + _, + _, + _, + ) = _filter_candidate_prs( + all_prs, + include_collaborators=include_collaborators, + include_drafts=include_drafts, + checks_state=checks_state, + min_commits_behind=min_commits_behind, + max_num=max_num, + ) + accepted_prs.extend(batch_accepted) if not candidate_prs: console_print("[info]No PRs to assess in this batch.[/]") _display_pr_overview_table(all_prs) - # Start prefetching the next page if available - if has_next_page and prefetch_executor: - next_batch_future = _start_next_batch_prefetch( - prefetch_executor, next_cursor=next_cursor, **prefetch_kwargs - ) continue - # Apply triage classification from prefetch - batch_triaged_cls = prefetch_result.triaged_classification + # Check already-triaged + batch_triaged_cls = _classify_already_triaged_prs( + token, github_repository, candidate_prs, viewer_login + ) batch_triaged_nums = batch_triaged_cls["waiting"] | batch_triaged_cls["responded"] if batch_triaged_nums: candidate_prs = [pr for pr in candidate_prs if pr.number not in batch_triaged_nums] @@ -7198,11 +7333,6 @@ def auto_triage( if not candidate_prs: console_print("[info]All PRs in this batch already triaged.[/]") - # Start prefetching the next page if available - if has_next_page and prefetch_executor: - next_batch_future = _start_next_batch_prefetch( - prefetch_executor, next_cursor=next_cursor, **prefetch_kwargs - ) continue # Enrich and assess @@ -7330,12 +7460,6 @@ def auto_triage( log_futures=batch_log_futures, ) - # Start prefetching the NEXT page before entering interactive review - if has_next_page and prefetch_executor: - next_batch_future = _start_next_batch_prefetch( - prefetch_executor, next_cursor=next_cursor, **prefetch_kwargs - ) - try: _review_workflow_approval_prs(batch_ctx, batch_pending) @@ -7357,13 +7481,6 @@ def auto_triage( if batch_executor is not None: batch_executor.shutdown(wait=False, cancel_futures=True) - # Clean up prefetch executor - if prefetch_executor is not None: - # Cancel any pending prefetch if user quit early - if next_batch_future is not None and not next_batch_future.done(): - next_batch_future.cancel() - prefetch_executor.shutdown(wait=False, cancel_futures=True) - # Display summary _display_triage_summary( all_prs, diff --git a/dev/breeze/src/airflow_breeze/utils/tui_display.py b/dev/breeze/src/airflow_breeze/utils/tui_display.py new file mode 100644 index 00000000000..543a5df6b5c --- /dev/null +++ b/dev/breeze/src/airflow_breeze/utils/tui_display.py @@ -0,0 +1,826 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Full-screen TUI display for PR auto-triage using Rich.""" + +from __future__ import annotations + +import io +import os +from enum import Enum +from typing import TYPE_CHECKING + +from rich.align import Align +from rich.columns import Columns +from rich.console import Console +from rich.panel import Panel +from rich.syntax import Syntax +from rich.table import Table +from rich.text import Text + +from airflow_breeze.utils.confirm import _has_tty +from airflow_breeze.utils.console import get_theme + +if TYPE_CHECKING: + from airflow_breeze.commands.pr_commands import PRData + + +class PRCategory(Enum): + """Category of a PR in the triage view.""" + + WORKFLOW_APPROVAL = "Needs Workflow" + FLAGGED = "Flagged" + LLM_FLAGGED = "LLM Flagged" + PASSING = "Passing" + STALE_REVIEW = "Stale Review" + ALREADY_TRIAGED = "Triaged" + SKIPPED = "Skipped" + + +# Category display styles +_CATEGORY_STYLES: dict[PRCategory, str] = { + PRCategory.WORKFLOW_APPROVAL: "bright_cyan", + PRCategory.FLAGGED: "red", + PRCategory.LLM_FLAGGED: "yellow", + PRCategory.PASSING: "green", + PRCategory.STALE_REVIEW: "yellow", + PRCategory.ALREADY_TRIAGED: "dim", + PRCategory.SKIPPED: "dim", +} + + +class TUIAction(Enum): + """Actions the user can take in the TUI.""" + + SELECT = "enter" + UP = "up" + DOWN = "down" + PAGE_UP = "page_up" + PAGE_DOWN = "page_down" + NEXT_PAGE = "next_page" + PREV_PAGE = "prev_page" + QUIT = "q" + OPEN = "o" + SHOW_DIFF = "w" + SKIP = "s" + NEXT_SECTION = "tab" + TOGGLE_SELECT = "space" + APPROVE_SELECTED = "approve" + + +class _FocusPanel(Enum): + """Which panel currently has focus for keyboard input.""" + + PR_LIST = "pr_list" + DIFF = "diff" + + +def _get_terminal_size() -> tuple[int, int]: + """Get terminal size (columns, rows).""" + try: + size = os.get_terminal_size() + return size.columns, size.lines + except OSError: + return 120, 40 + + +def _make_tui_console() -> Console: + """Create a Console instance for TUI rendering.""" + width, _ = _get_terminal_size() + return Console( + force_terminal=True, + color_system="standard", + width=width, + theme=get_theme(), + ) + + +def _read_tui_key() -> TUIAction | str: + """Read a keypress and map it to a TUIAction or return the raw character.""" + if not _has_tty(): + # No TTY — fall back to line input + try: + line = input() + return line.strip().lower() if line.strip() else TUIAction.SELECT + except (EOFError, KeyboardInterrupt): + return TUIAction.QUIT + + import click + + ch = click.getchar() + + # Arrow key escape sequences + if ch == "\x1b[A" or ch == "\x1bOA": + return TUIAction.UP + if ch == "\x1b[B" or ch == "\x1bOB": + return TUIAction.DOWN + if ch == "\x1b[5~": + return TUIAction.PAGE_UP + if ch == "\x1b[6~": + return TUIAction.PAGE_DOWN + if ch == " ": + return TUIAction.TOGGLE_SELECT + if ch in ("\r", "\n"): + return TUIAction.SELECT + if ch == "\t": + return TUIAction.NEXT_SECTION + if ch == "q" or ch == "Q": + return TUIAction.QUIT + if ch == "o" or ch == "O": + return TUIAction.OPEN + if ch == "w" or ch == "W": + return TUIAction.SHOW_DIFF + if ch == "s" or ch == "S": + return TUIAction.SKIP + if ch == "a" or ch == "A": + return TUIAction.APPROVE_SELECTED + if ch == "n" or ch == "N": + return TUIAction.NEXT_PAGE + if ch == "p" or ch == "P": + return TUIAction.PREV_PAGE + # k/j for vim-style navigation + if ch == "k": + return TUIAction.UP + if ch == "j": + return TUIAction.DOWN + + # Return raw character for other keys + return ch if len(ch) == 1 else "" + + +class PRListEntry: + """A PR entry in the TUI list with its category and optional metadata.""" + + def __init__(self, pr: PRData, category: PRCategory, *, action_taken: str = ""): + self.pr = pr + self.category = category + self.action_taken = action_taken + self.selected = False + + +class TriageTUI: + """Full-screen TUI for PR auto-triage overview.""" + + def __init__( + self, + title: str = "Auto-Triage", + *, + mode_desc: str = "", + github_repository: str = "", + selection_criteria: str = "", + ): + self.title = title + self.mode_desc = mode_desc + self.github_repository = github_repository + self.selection_criteria = selection_criteria + self.entries: list[PRListEntry] = [] + self.cursor: int = 0 + self.scroll_offset: int = 0 + self._visible_rows: int = 0 # set during render + self._console = _make_tui_console() + self._sections: dict[PRCategory, list[PRListEntry]] = {} + # Diff panel state + self._diff_text: str = "" + self._diff_lines: list[str] = [] + self._diff_scroll: int = 0 + self._diff_visible_lines: int = 20 + self._diff_pr_number: int | None = None # track which PR's diff is loaded + # Focus state — which panel receives keyboard navigation + self._focus: _FocusPanel = _FocusPanel.PR_LIST + # Track previous cursor to detect PR changes for diff auto-fetch + self._prev_cursor: int = -1 + + def set_entries( + self, + entries: list[PRListEntry], + ) -> None: + """Set the PR entries to display.""" + self.entries = entries + self.cursor = 0 + self.scroll_offset = 0 + self._sections.clear() + for entry in entries: + self._sections.setdefault(entry.category, []).append(entry) + + def _build_header(self, width: int) -> Panel: + """Build the header panel with title and stats.""" + # Count by category + counts: dict[PRCategory, int] = {} + for entry in self.entries: + counts[entry.category] = counts.get(entry.category, 0) + 1 + + parts = [f"[bold]{self.title}[/]"] + if self.github_repository: + parts[0] += f": [cyan]{self.github_repository}[/]" + + stats_parts = [] + total = len(self.entries) + stats_parts.append(f"Total: [bold]{total}[/]") + for cat in PRCategory: + if cat in counts: + style = _CATEGORY_STYLES.get(cat, "white") + stats_parts.append(f"[{style}]{cat.value}: {counts[cat]}[/]") + selected_count = sum(1 for e in self.entries if e.selected) + if selected_count: + stats_parts.append(f"[bold green]Selected: {selected_count}[/]") + if self.mode_desc: + stats_parts.append(f"Mode: [bold]{self.mode_desc}[/]") + + header_lines = [" | ".join(parts), " | ".join(stats_parts)] + if self.selection_criteria: + header_lines.append(f"[dim]Selection: {self.selection_criteria}[/]") + # Show sorting order: category order (only categories with PRs) then secondary keys + sorted_cats = sorted( + (cat for cat in counts), + key=lambda c: list(PRCategory).index(c), + ) + cat_order = " → ".join(f"[{_CATEGORY_STYLES.get(c, 'white')}]{c.value}[/]" for c in sorted_cats) + header_lines.append(f"[dim]Sort: {cat_order}, then by author, PR#[/]") + header_text = "\n".join(header_lines) + return Panel(header_text, border_style="bright_blue", padding=(0, 1)) + + def _page_info(self, visible_rows: int) -> tuple[int, int]: + """Return (current_page, total_pages) based on visible rows.""" + if not self.entries or visible_rows <= 0: + return 1, 1 + total_pages = max(1, -(-len(self.entries) // visible_rows)) # ceil division + current_page = (self.scroll_offset // visible_rows) + 1 + return min(current_page, total_pages), total_pages + + def _build_pr_table(self, visible_rows: int) -> Panel: + """Build the scrollable PR list table.""" + self._visible_rows = visible_rows + width, _ = _get_terminal_size() + + # Adjust scroll to keep cursor visible + if self.cursor < self.scroll_offset: + self.scroll_offset = self.cursor + elif self.cursor >= self.scroll_offset + visible_rows: + self.scroll_offset = self.cursor - visible_rows + 1 + + table = Table( + show_header=True, + header_style="bold", + expand=True, + show_edge=False, + pad_edge=False, + box=None, + ) + table.add_column("", width=2, no_wrap=True) # cursor indicator + table.add_column("#", style="cyan", width=6, no_wrap=True) + table.add_column("Category", width=15, no_wrap=True) + table.add_column("Status", width=8, no_wrap=True) + table.add_column("Title", ratio=1) + table.add_column("Author", width=16, no_wrap=True) + table.add_column("CI", width=8, no_wrap=True) + table.add_column("Behind", width=6, justify="right", no_wrap=True) + table.add_column("Action", width=10, no_wrap=True) + + end_idx = min(self.scroll_offset + visible_rows, len(self.entries)) + for i in range(self.scroll_offset, end_idx): + entry = self.entries[i] + pr = entry.pr + is_selected = i == self.cursor + + # Cursor indicator (show selection checkmark for workflow approval PRs) + if is_selected and entry.selected: + cursor_mark = "[bold green]>[/]" + elif is_selected: + cursor_mark = "[bold bright_white]>[/]" + elif entry.selected: + cursor_mark = "[green]*[/]" + else: + cursor_mark = " " + + # PR number + pr_num = f"[bold cyan]#{pr.number}[/]" if is_selected else f"#{pr.number}" + + # Category + cat_style = _CATEGORY_STYLES.get(entry.category, "white") + cat_text = f"[{cat_style}]{entry.category.value}[/]" + + # Overall status + if pr.is_draft: + status = "[yellow]Draft[/]" + elif pr.checks_state == "FAILURE" or pr.mergeable == "CONFLICTING": + status = "[red]Issues[/]" + elif pr.checks_state in ("UNKNOWN", "NOT_RUN"): + status = "[yellow]No CI[/]" + elif pr.checks_state == "PENDING": + status = "[yellow]Pending[/]" + else: + status = "[green]OK[/]" + + # Title - truncate based on available width + max_title = max(20, width - 85) + title = pr.title[:max_title] + if len(pr.title) > max_title: + title += "..." + if is_selected: + title = f"[bold]{title}[/]" + + # Author + author = pr.author_login[:16] + + # CI status + if pr.checks_state == "FAILURE": + ci = "[red]Fail[/]" + elif pr.checks_state == "PENDING": + ci = "[yellow]Pend[/]" + elif pr.checks_state in ("UNKNOWN", "NOT_RUN"): + ci = "[yellow]NotRun[/]" + elif pr.checks_state == "SUCCESS": + ci = "[green]Pass[/]" + else: + ci = f"[dim]{pr.checks_state[:6]}[/]" + + # Commits behind + behind = f"[yellow]{pr.commits_behind}[/]" if pr.commits_behind > 0 else "[green]0[/]" + + # Action taken + action_text = "" + if entry.action_taken: + action_styles = { + "drafted": "[yellow]drafted[/]", + "commented": "[yellow]commented[/]", + "closed": "[red]closed[/]", + "rebased": "[green]rebased[/]", + "rerun": "[green]rerun[/]", + "approved": "[green]approved[/]", + "ready": "[green]ready[/]", + "skipped": "[dim]skipped[/]", + "pinged": "[cyan]pinged[/]", + } + action_text = action_styles.get(entry.action_taken, f"[dim]{entry.action_taken}[/]") + + # Row style for selected + row_style = "on grey23" if is_selected else "" + + table.add_row( + cursor_mark, + pr_num, + cat_text, + status, + title, + author, + ci, + behind, + action_text, + style=row_style, + ) + + # Add scroll indicators + scroll_info = "" + if self.scroll_offset > 0: + scroll_info += f" [dim]... {self.scroll_offset} more above[/]" + remaining_below = len(self.entries) - end_idx + if remaining_below > 0: + if scroll_info: + scroll_info += " | " + scroll_info += f"[dim]{remaining_below} more below ...[/]" + + pos_text = f"[dim]{self.cursor + 1}/{len(self.entries)}[/]" + current_page, total_pages = self._page_info(visible_rows) + page_text = f"[dim]Page {current_page}/{total_pages}[/]" + + # Show focus indicator on PR list + focus_indicator = "[bold bright_white] FOCUS [/]" if self._focus == _FocusPanel.PR_LIST else "" + title_text = f"PR List {pos_text} {page_text} {focus_indicator}" + + border_style = "bold bright_blue" if self._focus == _FocusPanel.PR_LIST else "bright_blue" + return Panel(table, title=title_text, subtitle=scroll_info, border_style=border_style) + + def _build_detail_panel(self, panel_height: int) -> Panel: + """Build the detail panel for the currently selected PR.""" + if not self.entries: + return Panel("[dim]No PRs to display[/]", title="Details", border_style="dim") + + entry = self.entries[self.cursor] + pr = entry.pr + + lines = [] + # PR title and link + lines.append(f"[bold cyan]#{pr.number}[/] [bold]{pr.title}[/]") + lines.append(f"[link={pr.url}]{pr.url}[/link]") + lines.append("") + + # Author + lines.append(f"Author: [bold]{pr.author_login}[/] ([dim]{pr.author_association}[/])") + + # Timestamps + from airflow_breeze.commands.pr_commands import _human_readable_age + + lines.append( + f"Created: {_human_readable_age(pr.created_at)} | Updated: {_human_readable_age(pr.updated_at)}" + ) + + # Status info + lines.append("") + status_parts = [] + if pr.is_draft: + status_parts.append("[yellow]Draft[/]") + if pr.mergeable == "CONFLICTING": + status_parts.append("[red]Merge conflicts[/]") + elif pr.mergeable == "MERGEABLE": + status_parts.append("[green]Mergeable[/]") + if pr.commits_behind > 0: + status_parts.append( + f"[yellow]{pr.commits_behind} commit{'s' if pr.commits_behind != 1 else ''} behind[/]" + ) + if pr.checks_state == "FAILURE": + status_parts.append(f"[red]CI: Failing ({len(pr.failed_checks)} checks)[/]") + elif pr.checks_state == "SUCCESS": + status_parts.append("[green]CI: Passing[/]") + elif pr.checks_state == "PENDING": + status_parts.append("[yellow]CI: Pending[/]") + elif pr.checks_state in ("NOT_RUN", "UNKNOWN"): + status_parts.append("[yellow]CI: Not run[/]") + + if status_parts: + lines.append(" | ".join(status_parts)) + + # Failed checks + if pr.failed_checks: + lines.append("") + lines.append("[red]Failed checks:[/]") + for check in pr.failed_checks[:5]: + lines.append(f" [red]- {check}[/]") + if len(pr.failed_checks) > 5: + lines.append(f" [dim]... and {len(pr.failed_checks) - 5} more[/]") + + # Labels + if pr.labels: + lines.append("") + label_text = ", ".join(f"[dim]{lbl}[/]" for lbl in pr.labels[:5]) + if len(pr.labels) > 5: + label_text += f" (+{len(pr.labels) - 5} more)" + lines.append(f"Labels: {label_text}") + + # Unresolved threads + if pr.unresolved_threads: + lines.append("") + lines.append(f"[yellow]Unresolved review threads: {len(pr.unresolved_threads)}[/]") + for t in pr.unresolved_threads[:3]: + body_preview = t.comment_body[:80].replace("\n", " ") + if len(t.comment_body) > 80: + body_preview += "..." + lines.append(f" [dim]@{t.reviewer_login}:[/] {body_preview}") + + # Category + lines.append("") + cat_style = _CATEGORY_STYLES.get(entry.category, "white") + lines.append(f"Category: [{cat_style}]{entry.category.value}[/]") + + # Action taken + if entry.action_taken: + lines.append(f"Action: [bold]{entry.action_taken}[/]") + + # Truncate lines to fit panel height (subtract borders) + max_lines = max(1, panel_height - 2) + if len(lines) > max_lines: + lines = lines[:max_lines] + + content = "\n".join(lines) + return Panel(content, title="Details", border_style="cyan", padding=(0, 1)) + + def set_diff(self, pr_number: int, diff_text: str) -> None: + """Set the diff content for display.""" + self._diff_text = diff_text + self._diff_lines = diff_text.splitlines() + self._diff_scroll = 0 + self._diff_pr_number = pr_number + + def needs_diff_fetch(self) -> bool: + """Check if the current PR needs a diff fetch (cursor moved to new PR).""" + entry = self.get_selected_entry() + if not entry: + return False + return self._diff_pr_number != entry.pr.number + + def scroll_diff(self, delta: int, visible_lines: int = 20) -> None: + """Scroll the diff panel by delta lines.""" + if not self._diff_lines: + return + max_scroll = max(0, len(self._diff_lines) - visible_lines) + self._diff_scroll = max(0, min(max_scroll, self._diff_scroll + delta)) + + def _build_diff_panel(self, panel_height: int, panel_width: int) -> Panel: + """Build the scrollable diff panel.""" + is_focused = self._focus == _FocusPanel.DIFF + focus_indicator = "[bold bright_white] FOCUS [/]" if is_focused else "" + + if not self._diff_text: + return Panel( + "[dim]Loading diff...[/]", + title=f"Diff {focus_indicator}", + border_style="bold bright_cyan" if is_focused else "dim", + width=panel_width, + ) + + # Visible lines within the panel (subtract borders) + visible_lines = max(1, panel_height - 2) + self._diff_visible_lines = visible_lines + + # Clamp scroll + max_scroll = max(0, len(self._diff_lines) - visible_lines) + self._diff_scroll = min(self._diff_scroll, max_scroll) + + # Slice the diff text to the visible window + end = min(self._diff_scroll + visible_lines, len(self._diff_lines)) + visible_text = "\n".join(self._diff_lines[self._diff_scroll : end]) + + diff_content = Syntax(visible_text, "diff", theme="monokai", word_wrap=True) + + # Scroll info + scroll_info_parts = [] + if self._diff_scroll > 0: + scroll_info_parts.append(f"{self._diff_scroll} lines above") + remaining = len(self._diff_lines) - end + if remaining > 0: + scroll_info_parts.append(f"{remaining} lines below") + scroll_info = f"[dim]{' | '.join(scroll_info_parts)}[/]" if scroll_info_parts else "" + + pr_num = self._diff_pr_number or "?" + pos = f"{self._diff_scroll + 1}-{end}/{len(self._diff_lines)}" + title = f"Diff #{pr_num} [dim]{pos}[/] {focus_indicator}" + + border_style = "bold bright_cyan" if is_focused else "bright_cyan" + return Panel( + diff_content, + title=title, + subtitle=scroll_info, + border_style=border_style, + padding=(0, 1), + width=panel_width, + ) + + def _build_footer(self) -> Panel: + """Build the footer panel with context-sensitive available actions.""" + # Navigation keys depend on focus + if self._focus == _FocusPanel.DIFF: + nav = ( + "[bold]j/↓[/] Scroll down [bold]k/↑[/] Scroll up " + "[bold]PgDn/Space[/] Page down [bold]PgUp[/] Page up " + "[bold]Tab[/] Switch to PR list" + ) + else: + nav = ( + "[bold]j/↓[/] Down [bold]k/↑[/] Up " + "[bold]n[/] Next pg [bold]p[/] Prev pg " + "[bold]Tab[/] Switch to diff" + ) + + # Context-sensitive action keys based on selected PR category + entry = self.get_selected_entry() + selected_count = sum(1 for e in self.entries if e.selected) + if entry is not None: + cat = entry.category + action_parts: list[str] = [] + if cat in (PRCategory.FLAGGED, PRCategory.LLM_FLAGGED): + action_parts.append("[bold]Enter[/] Review flagged PR") + elif cat == PRCategory.WORKFLOW_APPROVAL: + action_parts.append("[bold]Space[/] Toggle select") + action_parts.append("[bold]Enter[/] Review workflow") + if selected_count: + action_parts.append(f"[bold green]a[/] [green]Approve selected ({selected_count})[/]") + else: + action_parts.append("[bold]a[/] Approve selected") + elif cat == PRCategory.PASSING: + action_parts.append("[bold]Enter[/] Triage PR") + elif cat == PRCategory.ALREADY_TRIAGED: + action_parts.append("[bold]Enter[/] View (triaged)") + else: + action_parts.append("[bold]Enter[/] View") + action_parts.append("[bold]o[/] Open in browser") + action_parts.append("[bold]s[/] Skip") + action_parts.append("[bold]q[/] Quit") + actions = " ".join(action_parts) + else: + actions = "[bold]q[/] Quit" + + footer_text = f"{nav}\n{actions}" + return Panel( + Align.center(Text.from_markup(footer_text)), + border_style="bright_blue", + padding=(0, 1), + ) + + def _build_bottom_panels(self, bottom_height: int, total_width: int) -> Columns: + """Build the side-by-side detail + diff panels.""" + # Split width: detail gets ~40%, diff gets ~60% + detail_width = max(30, int(total_width * 0.4)) + diff_width = max(30, total_width - detail_width - 1) # -1 for column gap + + detail_panel = self._build_detail_panel(bottom_height) + diff_panel = self._build_diff_panel(bottom_height, diff_width) + + return Columns( + [detail_panel, diff_panel], + expand=True, + equal=False, + padding=(0, 0), + ) + + def render(self) -> None: + """Render the full-screen TUI using a single buffered write to avoid flicker.""" + width, height = _get_terminal_size() + + # Build everything into a buffer console first, then output at once + buf = io.StringIO() + buf_console = Console( + file=buf, + force_terminal=True, + color_system="standard", + width=width, + theme=get_theme(), + ) + + # Calculate layout sizes + header_height = 6 if self.selection_criteria else 5 + footer_height = 4 # two lines of keys + border + available = height - header_height - footer_height - 2 + + # PR list gets ~50%, bottom panels get ~50% + list_height = max(5, int(available * 0.5)) + bottom_height = max(5, available - list_height) + visible_rows = list_height - 3 + + header = self._build_header(width) + pr_table = self._build_pr_table(visible_rows) + bottom = self._build_bottom_panels(bottom_height, width) + footer = self._build_footer() + + buf_console.print(header) + buf_console.print(pr_table, height=list_height) + buf_console.print(bottom, height=bottom_height) + buf_console.print(footer) + + # Single atomic write: move cursor to top-left and overwrite + output = buf.getvalue() + # Use real console for the actual write + self._console = _make_tui_console() + # Move to top of screen and overwrite (avoids clear+redraw flicker) + self._console.file.write("\033[H") # cursor to home position + self._console.file.write("\033[J") # clear from cursor to end + self._console.file.write(output) + self._console.file.flush() + + def get_selected_entry(self) -> PRListEntry | None: + """Return the currently selected PR entry.""" + if self.entries and 0 <= self.cursor < len(self.entries): + return self.entries[self.cursor] + return None + + def get_selected_entries(self) -> list[PRListEntry]: + """Return all PRs that have been selected (toggled) for batch actions.""" + return [e for e in self.entries if e.selected] + + def move_cursor(self, delta: int) -> None: + """Move the cursor by delta positions, clamping to bounds.""" + if not self.entries: + return + self.cursor = max(0, min(len(self.entries) - 1, self.cursor + delta)) + + def next_page(self) -> None: + """Move cursor to the start of the next page.""" + if not self.entries or self._visible_rows <= 0: + return + new_offset = self.scroll_offset + self._visible_rows + if new_offset < len(self.entries): + self.scroll_offset = new_offset + self.cursor = new_offset + + def prev_page(self) -> None: + """Move cursor to the start of the previous page.""" + if not self.entries or self._visible_rows <= 0: + return + new_offset = max(0, self.scroll_offset - self._visible_rows) + self.scroll_offset = new_offset + self.cursor = new_offset + + def mark_action(self, index: int, action: str) -> None: + """Mark a PR entry with the action taken.""" + if 0 <= index < len(self.entries): + self.entries[index].action_taken = action + + def cursor_changed(self) -> bool: + """Check if cursor moved to a different PR since last check. Resets tracking.""" + changed = self.cursor != self._prev_cursor + self._prev_cursor = self.cursor + return changed + + def run_interactive(self) -> tuple[PRListEntry | None, TUIAction | str]: + """Render and wait for user input. Returns (selected_entry, action).""" + self.render() + key = _read_tui_key() + + # Tab switches focus between PR list and diff panel + if key == TUIAction.NEXT_SECTION: + if self._focus == _FocusPanel.PR_LIST: + self._focus = _FocusPanel.DIFF + else: + self._focus = _FocusPanel.PR_LIST + return None, key + + # When diff panel has focus, navigation keys scroll the diff + if self._focus == _FocusPanel.DIFF: + visible = self._diff_visible_lines + if key == TUIAction.UP: + self.scroll_diff(-1, visible) + return None, key + if key == TUIAction.DOWN: + self.scroll_diff(1, visible) + return None, key + if key == TUIAction.PAGE_UP: + self.scroll_diff(-visible, visible) + return None, key + if key == TUIAction.PAGE_DOWN: + self.scroll_diff(visible, visible) + return None, key + if key == TUIAction.TOGGLE_SELECT: + # Space scrolls page-down in diff focus + self.scroll_diff(visible, visible) + return None, key + # Pass through action keys even in diff focus + if key == TUIAction.QUIT: + return None, key + if key == TUIAction.OPEN: + return self.get_selected_entry(), key + if key == TUIAction.SELECT: + return self.get_selected_entry(), key + if key == TUIAction.SKIP: + entry = self.get_selected_entry() + if entry: + entry.action_taken = "skipped" + self.move_cursor(1) + return entry, key + if key == TUIAction.APPROVE_SELECTED: + selected = [e for e in self.entries if e.selected] + if selected: + return selected[0], key + return None, key + # Ignore other keys in diff focus + return None, key + + # PR list has focus — standard navigation + if key == TUIAction.UP: + self.move_cursor(-1) + return None, key + if key == TUIAction.DOWN: + self.move_cursor(1) + return None, key + if key == TUIAction.PAGE_UP: + self.prev_page() + return None, key + if key == TUIAction.PAGE_DOWN: + self.next_page() + return None, key + if key == TUIAction.NEXT_PAGE: + self.next_page() + return None, key + if key == TUIAction.PREV_PAGE: + self.prev_page() + return None, key + if key == TUIAction.SELECT: + return self.get_selected_entry(), key + if key == TUIAction.QUIT: + return None, key + if key == TUIAction.OPEN: + return self.get_selected_entry(), key + if key == TUIAction.SHOW_DIFF: + # w key switches focus to diff panel + self._focus = _FocusPanel.DIFF + return None, key + if key == TUIAction.TOGGLE_SELECT: + entry = self.get_selected_entry() + if entry and entry.category == PRCategory.WORKFLOW_APPROVAL and not entry.action_taken: + entry.selected = not entry.selected + self.move_cursor(1) + return None, key + if key == TUIAction.APPROVE_SELECTED: + # Return selected entries for batch approval + selected = [e for e in self.entries if e.selected] + if selected: + return selected[0], key + return None, key + if key == TUIAction.SKIP: + entry = self.get_selected_entry() + if entry: + entry.action_taken = "skipped" + self.move_cursor(1) + return entry, key + # Unknown key — return it for caller to handle + return self.get_selected_entry(), key diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py index e20b8804f80..22361949f3b 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py +++ b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py @@ -81,7 +81,7 @@ def validate_sql( raise SQLSafetyError(f"SQL parse error: {e}") from e # sqlglot.parse can return [None] for empty input - parsed: list[exp.Expression] = [s for s in statements if s is not None] # type: ignore[misc] + parsed: list[exp.Expression] = [s for s in statements if s is not None] if not parsed: raise SQLSafetyError("Empty SQL input.") diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 611c4fc28ec..e26344a81e5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1765,11 +1765,7 @@ def finalize( try: SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) except Exception: - log.exception( - "Failed to set rendered fields during finalization", - task_id=ti.task_id, - dag_id=ti.dag_id, - ) + log.exception("Failed to set rendered fields during finalization", ti=ti, task=ti.task) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 47de56c384f..37d3963146b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2337,8 +2337,8 @@ class TestRuntimeTaskInstance: mock_log.exception.assert_called_once_with( "Failed to set rendered fields during finalization", - task_id=runtime_ti.task_id, - dag_id=runtime_ti.dag_id, + ti=runtime_ti, + task=task, ) @pytest.mark.parametrize(
