Copilot commented on code in PR #67150: URL: https://github.com/apache/airflow/pull/67150#discussion_r3263896858
########## scripts/tests/ci/prek/test_check_provide_session_kwargs.py: ########## @@ -0,0 +1,395 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import textwrap +from pathlib import Path + +import pytest +from ci.prek.check_provide_session_kwargs import ( + AllowlistManager, + _check_provide_session_kwargs, + _count_violations, + _expand_for_allowlist_edits, + _has_provide_session_decorator, + _iter_positional_session_in_provide_session, + _session_is_positional, +) + + +@pytest.fixture +def find_violations(write_python_file): + """Factory fixture: write code to a temp file and return positional-session violations.""" + + def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + path = write_python_file(code) + return list(_iter_positional_session_in_provide_session(path)) + + return _check + + +@pytest.fixture +def fake_repo(tmp_path, monkeypatch): + """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly.""" + import ci.prek.check_provide_session_kwargs as hook + + monkeypatch.setattr(hook, "REPO_ROOT", tmp_path) Review Comment: The fixture performs a module import inside the function body.
This violates the repo's general rule of keeping imports at module top-level (unless needed for circular-import avoidance or lazy loading). Since names from this module are already imported at the top of the file, consider moving `import ci.prek.check_provide_session_kwargs as hook` to the top level (or reusing an already-imported module object) and patching that instead. ########## scripts/ci/prek/check_provide_session_kwargs.py: ########## @@ -0,0 +1,352 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "rich>=13.0.0", +# ] +# /// +"""Check that no new ``@provide_session`` functions declare ``session`` positionally. + +The project convention is that any function decorated with ``@provide_session`` +must declare ``session`` as keyword-only (after a bare ``*`` in the signature), +so callers cannot pass it positionally by accident. See +``contributing-docs/05_pull_requests.rst#database-session-handling``.
+ +All *existing* offenders are recorded in ``known_provide_session_positional.txt`` +next to this script as ``relative/path::N`` entries (one per file), where ``N`` +is the maximum number of ``@provide_session`` functions with a positional +``session`` argument allowed in that file. A file whose current count exceeds +the recorded limit is treated as a violation – move the ``session`` argument +behind a bare ``*`` instead. + +Modes +----- +Default (files passed by prek/pre-commit): + Check only the supplied files; fail if any file's count exceeds the limit. + When a file's count has *decreased*, the allowlist entry is tightened + automatically and the hook exits with a non-zero code so that pre-commit + reports the modified allowlist – just stage + ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run. + +``--all-files``: + Walk every ``.py`` file under the project source roots + (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, ``shared``) — + the same scope the pre-commit hook applies to. + +``--cleanup``: + Remove entries for files that no longer exist. Safe to run at any time; + does not add new entries or raise limits. + +``--generate``: + Scan the same project source roots as ``--all-files`` and *rebuild* the + allowlist from scratch. Intended for the initial setup or after a + large-scale clean-up sprint. +""" + +from __future__ import annotations + +import argparse +import ast +import typing +from pathlib import Path + +from rich.console import Console +from rich.panel import Panel + +console = Console(color_system="standard", width=200) + +REPO_ROOT = Path(__file__).parents[3] + +_PROVIDE_SESSION_DECORATOR = "provide_session" + +# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the +# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``. 
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared") + + +def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: + """Whether one of ``nodes`` is a ``@provide_session`` decorator. + + Accepts both bare names (``@provide_session``) and attribute access + (``@something.provide_session``). + """ + for node in nodes: + if isinstance(node, ast.Name) and node.id == _PROVIDE_SESSION_DECORATOR: + return True + if isinstance(node, ast.Attribute) and node.attr == _PROVIDE_SESSION_DECORATOR: + return True + return False + + +def _session_is_positional(args: ast.arguments) -> ast.arg | None: + """Return the ``session`` arg if it is positional (not keyword-only). + + Covers both regular positional args and positional-only args (``def f(session, /, ...)``). + """ + for argument in (*args.posonlyargs, *args.args): + if argument.arg == "session": + return argument + return None + + +def _iter_positional_session_in_provide_session( + path: Path, +) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + """Yield ``@provide_session`` functions in *path* whose ``session`` is positional.""" + try: + source = path.read_text(encoding="utf-8") + except OSError: + return + try: + tree = ast.parse(source, str(path)) + except SyntaxError: + return + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not _has_provide_session_decorator(node.decorator_list): + continue + argument = _session_is_positional(node.args) + if argument is None: + continue + yield node, argument + + +def _count_violations(path: Path) -> int: + return sum(1 for _ in _iter_positional_session_in_provide_session(path)) + + +class AllowlistManager: + def __init__(self, allowlist_file: Path) -> None: + self.allowlist_file = allowlist_file + + def load(self) -> dict[str, int]: + if not self.allowlist_file.exists(): + return {} + + result: dict[str, int] = {} + for raw_line in 
self.allowlist_file.read_text().splitlines(): + if not (stripped := raw_line.strip()): + continue + + rel_str, _, count_str = stripped.rpartition("::") + if not rel_str or not count_str: + continue + + try: + result[rel_str] = int(count_str) + except ValueError: + continue + + return result + + def save(self, counts: dict[str, int]) -> None: + lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] + self.allowlist_file.write_text("\n".join(lines) + "\n") + + def generate(self) -> int: + roots = ", ".join(_PROJECT_SOURCE_ROOTS) + console.print( + f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] " + "for @provide_session functions with positional session …" + ) + counts: dict[str, int] = {} + for path in _iter_python_files(): + n = _count_violations(path) + if n > 0: + counts[str(path.relative_to(REPO_ROOT))] = n + + self.save(counts) + total = sum(counts.values()) + console.print( + f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " + f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders." 
+ ) + return 0 + + def cleanup(self) -> int: + allowlist = self.load() + if not allowlist: + console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]") + return 0 + + stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] + if stale: + console.print( + f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" + ) + for s in sorted(stale): + console.print(f" [dim]-[/dim] {s}") + for s in stale: + del allowlist[s] + self.save(allowlist) + console.print( + f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" + ) + else: + console.print("[green]No stale entries found.[/green]") + return 0 + + +def _iter_python_files() -> list[Path]: + candidates: list[Path] = [] + for top in _PROJECT_SOURCE_ROOTS: + candidates.extend( + p.resolve() + for p in (REPO_ROOT / top).rglob("*.py") + if ".tox" not in p.parts and "__pycache__" not in p.parts + ) + return candidates + + +def _check_provide_session_kwargs( + files: list[Path], allowlist: dict[str, int], manager: AllowlistManager +) -> int: + violations: list[tuple[Path, int, int]] = [] + tightened: list[tuple[str, int, int]] = [] + + for path in files: + if not path.exists() or path.suffix != ".py": + continue + actual = _count_violations(path) + rel = str(path.relative_to(REPO_ROOT)) + allowed = allowlist.get(rel, 0) Review Comment: `path.relative_to(REPO_ROOT)` will raise `ValueError` if a file path is outside the repo root. This can happen if `known_provide_session_positional.txt` is edited to contain a path with `..` segments or an absolute path (and `_expand_for_allowlist_edits()` resolves it). Consider guarding the `relative_to()` call (or validating allowlist entries when loading/expanding) so the hook fails gracefully with a clear error instead of crashing. 
########## scripts/ci/prek/check_provide_session_kwargs.py: ########## @@ -0,0 +1,352 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "rich>=13.0.0", +# ] +# /// +"""Check that no new ``@provide_session`` functions declare ``session`` positionally. + +The project convention is that any function decorated with ``@provide_session`` +must declare ``session`` as keyword-only (after a bare ``*`` in the signature), +so callers cannot pass it positionally by accident. See +``contributing-docs/05_pull_requests.rst#database-session-handling``. + +All *existing* offenders are recorded in ``known_provide_session_positional.txt`` +next to this script as ``relative/path::N`` entries (one per file), where ``N`` +is the maximum number of ``@provide_session`` functions with a positional +``session`` argument allowed in that file. A file whose current count exceeds +the recorded limit is treated as a violation – move the ``session`` argument +behind a bare ``*`` instead. + +Modes +----- +Default (files passed by prek/pre-commit): + Check only the supplied files; fail if any file's count exceeds the limit. 
+ When a file's count has *decreased*, the allowlist entry is tightened + automatically and the hook exits with a non-zero code so that pre-commit + reports the modified allowlist – just stage + ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run. + +``--all-files``: + Walk every ``.py`` file under the project source roots + (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, ``shared``) — + the same scope the pre-commit hook applies to. + +``--cleanup``: + Remove entries for files that no longer exist. Safe to run at any time; + does not add new entries or raise limits. + +``--generate``: + Scan the same project source roots as ``--all-files`` and *rebuild* the + allowlist from scratch. Intended for the initial setup or after a + large-scale clean-up sprint. +""" + +from __future__ import annotations + +import argparse +import ast +import typing +from pathlib import Path + +from rich.console import Console +from rich.panel import Panel + +console = Console(color_system="standard", width=200) + +REPO_ROOT = Path(__file__).parents[3] + +_PROVIDE_SESSION_DECORATOR = "provide_session" + +# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the +# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``. +_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared") + + +def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: + """Whether one of ``nodes`` is a ``@provide_session`` decorator. + + Accepts both bare names (``@provide_session``) and attribute access + (``@something.provide_session``). + """ + for node in nodes: + if isinstance(node, ast.Name) and node.id == _PROVIDE_SESSION_DECORATOR: + return True + if isinstance(node, ast.Attribute) and node.attr == _PROVIDE_SESSION_DECORATOR: + return True + return False + + +def _session_is_positional(args: ast.arguments) -> ast.arg | None: + """Return the ``session`` arg if it is positional (not keyword-only). 
+ + Covers both regular positional args and positional-only args (``def f(session, /, ...)``). + """ + for argument in (*args.posonlyargs, *args.args): + if argument.arg == "session": + return argument + return None + + +def _iter_positional_session_in_provide_session( + path: Path, +) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + """Yield ``@provide_session`` functions in *path* whose ``session`` is positional.""" + try: + source = path.read_text(encoding="utf-8") + except OSError: + return + try: + tree = ast.parse(source, str(path)) + except SyntaxError: + return + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not _has_provide_session_decorator(node.decorator_list): + continue + argument = _session_is_positional(node.args) + if argument is None: + continue + yield node, argument + + +def _count_violations(path: Path) -> int: + return sum(1 for _ in _iter_positional_session_in_provide_session(path)) + + +class AllowlistManager: + def __init__(self, allowlist_file: Path) -> None: + self.allowlist_file = allowlist_file + + def load(self) -> dict[str, int]: + if not self.allowlist_file.exists(): + return {} + + result: dict[str, int] = {} + for raw_line in self.allowlist_file.read_text().splitlines(): + if not (stripped := raw_line.strip()): + continue + + rel_str, _, count_str = stripped.rpartition("::") + if not rel_str or not count_str: + continue + + try: + result[rel_str] = int(count_str) + except ValueError: + continue + + return result + + def save(self, counts: dict[str, int]) -> None: + lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] + self.allowlist_file.write_text("\n".join(lines) + "\n") + + def generate(self) -> int: + roots = ", ".join(_PROJECT_SOURCE_ROOTS) + console.print( + f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] " + "for @provide_session functions with positional session …" + ) + 
counts: dict[str, int] = {} + for path in _iter_python_files(): + n = _count_violations(path) + if n > 0: + counts[str(path.relative_to(REPO_ROOT))] = n + + self.save(counts) + total = sum(counts.values()) + console.print( + f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " + f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders." + ) + return 0 + + def cleanup(self) -> int: + allowlist = self.load() + if not allowlist: + console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]") + return 0 + + stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] + if stale: + console.print( + f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" + ) + for s in sorted(stale): + console.print(f" [dim]-[/dim] {s}") + for s in stale: + del allowlist[s] + self.save(allowlist) + console.print( + f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" + ) + else: + console.print("[green]No stale entries found.[/green]") + return 0 + + +def _iter_python_files() -> list[Path]: + candidates: list[Path] = [] + for top in _PROJECT_SOURCE_ROOTS: + candidates.extend( + p.resolve() + for p in (REPO_ROOT / top).rglob("*.py") + if ".tox" not in p.parts and "__pycache__" not in p.parts + ) + return candidates + + +def _check_provide_session_kwargs( + files: list[Path], allowlist: dict[str, int], manager: AllowlistManager +) -> int: + violations: list[tuple[Path, int, int]] = [] + tightened: list[tuple[str, int, int]] = [] + + for path in files: + if not path.exists() or path.suffix != ".py": + continue + actual = _count_violations(path) Review Comment: `_check_provide_session_kwargs()` silently skips non-existent paths. If the allowlist file itself is deleted/renamed in a commit, prek/pre-commit will still pass its (now non-existent) path to this hook, and the hook can end up doing no checks and returning success. 
Consider explicitly detecting when the configured allowlist file is among the input paths but is missing, and failing with an instruction to restore or regenerate it. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
