Copilot commented on code in PR #67150:
URL: https://github.com/apache/airflow/pull/67150#discussion_r3264141921


##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,390 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
console = Console(color_system="standard", width=200)

# Checkout root: this script lives at ``scripts/ci/prek/<name>.py``, so the
# repository root is the fourth ancestor of the file.
REPO_ROOT = Path(__file__).parents[3]

# Decorator name matched against ``ast.Name.id`` / ``ast.Attribute.attr``.
_PROVIDE_SESSION_DECORATOR = "provide_session"

# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the
# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
_PROJECT_SOURCE_ROOTS = (
    "airflow-core",
    "airflow-ctl",
    "task-sdk",
    "providers",
    "shared",
)
+
+
def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
    """Return ``True`` when any decorator expression in *nodes* is ``provide_session``.

    Both the bare form (``@provide_session``) and dotted access
    (``@airflow.utils.session.provide_session``) are recognised. Any other node
    type (for example a call expression) never matches.
    """

    def _decorator_name(node: ast.expr) -> str | None:
        # Only plain names and attribute lookups carry a usable decorator name.
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    return any(_decorator_name(node) == _PROVIDE_SESSION_DECORATOR for node in nodes)
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def 
f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is 
positional."""
+    try:
+        source = path.read_text(encoding="utf-8", errors="replace")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
def _count_violations(path: Path) -> int:
    """Return how many ``@provide_session`` functions in *path* take ``session`` positionally."""
    total = 0
    for _match in _iter_positional_session_in_provide_session(path):
        total += 1
    return total
+
+
def _is_safe_relative(rel: str) -> bool:
    """Tell whether allowlist entry *rel* is a relative path contained in ``REPO_ROOT``.

    Absolute paths are rejected outright. A relative path is joined onto
    ``REPO_ROOT`` and resolved, which collapses ``..`` segments and follows
    symlinks. It is accepted only if the result still lies under the resolved
    repository root, so callers can ``relative_to(REPO_ROOT)`` it safely.
    """
    if Path(rel).is_absolute():
        return False
    target = (REPO_ROOT / rel).resolve()
    root = REPO_ROOT.resolve()
    try:
        target.relative_to(root)
    except ValueError:
        return False
    else:
        return True
+
+
class AllowlistManager:
    """Read, write and maintain the ``relative/path::N`` allowlist file.

    Each non-empty line maps a repository-relative, ``/``-separated path to the
    maximum number of ``@provide_session`` functions with a positional
    ``session`` argument that file may contain.
    """

    def __init__(self, allowlist_file: Path) -> None:
        self.allowlist_file = allowlist_file

    def load(self) -> dict[str, int]:
        """Parse the allowlist into ``{relative_path: allowed_count}``.

        A missing file yields an empty mapping. Blank lines, lines without a
        ``::`` separator, lines with an empty path or count, and non-integer
        counts are skipped silently. Entries that would escape the repository
        root are skipped with a warning.
        """
        if not self.allowlist_file.exists():
            return {}

        result: dict[str, int] = {}
        for raw_line in self.allowlist_file.read_text(encoding="utf-8").splitlines():
            if not (stripped := raw_line.strip()):
                continue

            # ``rpartition`` splits on the last ``::`` so the count is always the tail.
            rel_str, _, count_str = stripped.rpartition("::")
            if not rel_str or not count_str:
                continue

            try:
                count = int(count_str)
            except ValueError:
                continue

            if not _is_safe_relative(rel_str):
                console.print(f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}")
                continue

            result[rel_str] = count

        return result

    def save(self, counts: dict[str, int]) -> None:
        """Write *counts* as ``path::N`` lines, sorted by path, with a trailing newline."""
        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
        self.allowlist_file.write_text("\n".join(lines) + "\n", encoding="utf-8")

    def generate(self) -> int:
        """Rebuild the allowlist from a scan of every project source file; returns 0."""
        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
        console.print(
            f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] "
            "for @provide_session functions with positional session …"
        )
        # Scanned paths are resolved, so compare against the resolved root: an
        # unresolved REPO_ROOT makes ``relative_to`` raise when the checkout path
        # contains a symlink. ``as_posix`` keeps keys ``/``-separated on every OS,
        # matching the committed allowlist format.
        repo_root = REPO_ROOT.resolve()
        counts: dict[str, int] = {}
        for path in _iter_python_files():
            n = _count_violations(path)
            if n > 0:
                counts[path.resolve().relative_to(repo_root).as_posix()] = n

        self.save(counts)
        total = sum(counts.values())
        console.print(
            f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders."
        )
        return 0

    def cleanup(self) -> int:
        """Drop entries whose files no longer exist under ``REPO_ROOT``; returns 0."""
        allowlist = self.load()
        if not allowlist:
            console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]")
            return 0

        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()]
        if stale:
            console.print(
                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]"
            )
            for s in sorted(stale):
                console.print(f"  [dim]-[/dim] {s}")
            for s in stale:
                del allowlist[s]
            self.save(allowlist)
            console.print(f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]")
        else:
            console.print("[green]No stale entries found.[/green]")
        return 0
+
+
def _iter_python_files() -> list[Path]:
    """Collect every ``.py`` file under the project source roots, as resolved paths.

    Files with a ``.tox`` or ``__pycache__`` component anywhere in their path
    are excluded. Roots are visited in ``_PROJECT_SOURCE_ROOTS`` order.
    """
    excluded_dirs = {".tox", "__pycache__"}
    return [
        candidate.resolve()
        for root_name in _PROJECT_SOURCE_ROOTS
        for candidate in (REPO_ROOT / root_name).rglob("*.py")
        if excluded_dirs.isdisjoint(candidate.parts)
    ]
+
+
def _check_provide_session_kwargs(
    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
) -> int:
    """Check *files* against the per-file limits recorded in *allowlist*.

    Paths that do not exist or are not ``.py`` files are skipped. For every
    other file, the number of ``@provide_session`` functions with a positional
    ``session`` is compared with its allowlist entry (``0`` when absent):

    * more than allowed: reported as a violation;
    * fewer than allowed: the entry is lowered (or removed when the count is
      ``0``) in *allowlist*, and the mutated mapping is saved via *manager*.

    :param files: candidate paths to inspect.
    :param allowlist: ``{relative_path: allowed_count}``; mutated in place when
        entries are tightened.
    :param manager: owner of the allowlist file, used to save tightened entries.
    :return: ``1`` if the allowlist file was passed in but is missing, if any
        violation was found, or if the allowlist was tightened (so pre-commit
        reports the rewritten file); otherwise ``0``.
    """
    allowlist_file = manager.allowlist_file.resolve()
    if any(p.resolve() == allowlist_file for p in files) and not allowlist_file.exists():
        console.print(
            Panel.fit(
                f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n"
                "It was passed to the hook but cannot be read, so the check cannot proceed.\n"
                "Restore it from git or regenerate it with:\n\n"
                "  [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]",
                title="[red]Check failed[/red]",
                border_style="red",
            )
        )
        return 1

    # Allowlist keys are ``/``-separated paths relative to the *resolved* repo
    # root. Callers usually pass resolved paths, so relating them to an
    # unresolved REPO_ROOT would raise ValueError whenever the checkout path
    # contains a symlink, and ``str()`` would yield ``\`` separators on Windows.
    repo_root = REPO_ROOT.resolve()
    violations: list[tuple[Path, str, int, int]] = []
    tightened: list[tuple[str, int, int]] = []

    for path in files:
        if not path.exists() or path.suffix != ".py":
            continue
        actual = _count_violations(path)
        rel = path.resolve().relative_to(repo_root).as_posix()
        allowed = allowlist.get(rel, 0)
        if actual > allowed:
            violations.append((path, rel, actual, allowed))
        elif actual < allowed:
            if actual == 0:
                del allowlist[rel]
            else:
                allowlist[rel] = actual
            tightened.append((rel, allowed, actual))

    if tightened:
        manager.save(allowlist)
        console.print(
            f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} "
            f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] "
            "(stage the updated file):"
        )
        for rel, old, new in tightened:
            console.print(f"  [cyan]{rel}[/cyan]  {old} -> {new}")

    if violations:
        console.print(
            Panel.fit(
                "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n"
                "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n"
                "  [cyan]@provide_session\n"
                "  def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n"
                "If this usage is intentional and pre-existing, run:\n\n"
                "  [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n"
                "to regenerate the allowlist, then commit the updated\n"
                "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].",
                title="[red]Check failed[/red]",
                border_style="red",
            )
        )
        for path, rel, actual, allowed in violations:
            console.print(f"  [cyan]{rel}[/cyan]  count={actual} (allowed={allowed})")
            # ``ast.walk`` yields nodes in no specified order; list offenders top-to-bottom.
            offenders = sorted(
                _iter_positional_session_in_provide_session(path),
                key=lambda item: item[1].lineno,
            )
            for func, argument in offenders:
                console.print(f"      [dim]L{argument.lineno}[/dim] def {func.name}(...)")
        return 1

    return 1 if tightened else 0
+
+
def main(argv: list[str] | None = None) -> int:
    """Parse command-line options and dispatch to the requested mode.

    ``--generate`` and ``--cleanup`` are handled first (in that order), then
    ``--all-files``, then explicit FILE arguments.

    :param argv: arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.
    :return: the hook's exit status.
    """
    parser = argparse.ArgumentParser(
        description="Prevent new @provide_session functions from declaring `session` positionally.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("files", nargs="*", metavar="FILE", help="Files to check (provided by prek)")
    parser.add_argument(
        "--all-files",
        action="store_true",
        help=(
            "Check every Python file under the project source roots "
            "(airflow-core, airflow-ctl, task-sdk, providers, shared)"
        ),
    )
    parser.add_argument(
        "--cleanup",
        action="store_true",
        help="Remove stale entries from the allowlist and exit",
    )
    parser.add_argument(
        "--generate",
        action="store_true",
        help="Regenerate the allowlist from the current codebase and exit",
    )
    args = parser.parse_args(argv)

    manager = AllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt")

    if args.generate:
        return manager.generate()

    if args.cleanup:
        return manager.cleanup()

    allowlist = manager.load()

    if args.all_files:
        return _check_provide_session_kwargs(_iter_python_files(), allowlist, manager)

    if not args.files:
        console.print(
            "[yellow]No files provided. Pass filenames or use --all-files to scan the whole repo.[/yellow]"
        )
        return 0

    paths = [Path(f).resolve() for f in args.files]
    allowlist_edited = manager.allowlist_file.resolve() in paths
    paths = _expand_for_allowlist_edits(paths, manager, allowlist)
    if allowlist_edited:
        # Re-checking only the files named in the *current* allowlist is not
        # enough. Deleting an entry, or re-pointing it at another file, would
        # silently drop still-offending files from the check. Re-validate every
        # project source file too, so no allowlist edit can reduce coverage.
        # ``dict.fromkeys`` de-duplicates while preserving order, so no file is
        # counted twice.
        paths = list(dict.fromkeys([*paths, *_iter_python_files()]))
    return _check_provide_session_kwargs(paths, allowlist, manager)
+
+
+def _expand_for_allowlist_edits(
+    paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int]
+) -> list[Path]:
+    """Add allowlisted files when the allowlist itself is being changed.
+
+    Without this, a contributor could raise counts in
+    ``known_provide_session_positional.txt`` and the hook would do no 
validation
+    (since only the ``.txt`` file is passed), letting the loosened allowlist
+    sail through.
+
+    Both sides of the allowlist-file comparison are resolved so the detection 
is
+    robust to symlinks and unresolved inputs (the hook can be invoked with 
either).
+    """
+    allowlist_file = manager.allowlist_file.resolve()
+    if not any(p.resolve() == allowlist_file for p in paths):
+        return paths
+
+    expanded = list(paths)
+    seen = {p.resolve() for p in paths if p.suffix == ".py"}
+    for rel in allowlist:
+        candidate = (REPO_ROOT / rel).resolve()
+        if candidate.exists() and candidate not in seen:
+            seen.add(candidate)
+            expanded.append(candidate)

Review Comment:
   When the allowlist file is edited, `_expand_for_allowlist_edits()` only adds 
Python files that are present in the *current* loaded allowlist (`for rel in 
allowlist`). This allows a bypass: removing an entry from 
`known_provide_session_positional.txt` (or adding a brand-new entry for a file 
with violations) can make the hook validate nothing or validate only the newly 
allowlisted file, and still exit 0 even if there are positional-`session` 
offenders that are now unallowlisted.
   
   To prevent this, consider expanding to all project Python files (or at least 
union the *previous* allowlist entries via `git show HEAD:<file>` with the 
current allowlist) whenever the allowlist file itself is among the changed 
paths, so edits to the allowlist cannot reduce coverage.
   



##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,413 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek import check_provide_session_kwargs as hook
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _session_is_positional,
+)
+
+
@pytest.fixture
def find_violations(write_python_file):
    """Factory fixture: write *code* to a temp file and return its positional-session violations.

    Relies on a ``write_python_file`` fixture that this module does not define
    (presumably from a ``conftest.py`` — confirm).
    """

    def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
        path = write_python_file(code)
        return list(_iter_positional_session_in_provide_session(path))

    return _check
+
+
@pytest.fixture
def fake_repo(tmp_path, monkeypatch):
    """Point ``hook.REPO_ROOT`` at *tmp_path* and return a file writer for that fake repo.

    The returned ``_write(rel, code)`` creates ``tmp_path / rel`` (including
    parent directories), writes the dedented *code* into it and returns the path.
    """
    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)

    def _write(rel: str, code: str) -> Path:
        path = tmp_path / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(textwrap.dedent(code))
        return path

    return _write
+
+
class TestHasProvideSessionDecorator:
    """Decorator detection for bare, dotted, missing and unrelated decorators."""

    @staticmethod
    def _decorators_of(source: str) -> list[ast.expr]:
        # Each snippet holds exactly one function definition.
        return ast.parse(source).body[0].decorator_list

    def test_provide_session_name(self):
        decorators = self._decorators_of("@provide_session\ndef foo(): pass")
        assert _has_provide_session_decorator(decorators) is True

    def test_provide_session_attribute(self):
        decorators = self._decorators_of("@utils.provide_session\ndef foo(): pass")
        assert _has_provide_session_decorator(decorators) is True

    def test_no_decorator(self):
        decorators = self._decorators_of("def foo(): pass")
        assert _has_provide_session_decorator(decorators) is False

    def test_unrelated_decorator(self):
        decorators = self._decorators_of("@staticmethod\ndef foo(): pass")
        assert _has_provide_session_decorator(decorators) is False

    def test_multiple_decorators_including_provide_session(self):
        decorators = self._decorators_of("@staticmethod\n@provide_session\ndef foo(): pass")
        assert _has_provide_session_decorator(decorators) is True
+
+
class TestSessionIsPositional:
    """Which parameter kinds count as a positional ``session``."""

    @staticmethod
    def _signature(source: str) -> ast.arguments:
        # Each snippet holds exactly one function definition.
        return ast.parse(source).body[0].args

    def test_no_session_arg(self):
        assert _session_is_positional(self._signature("def foo(x, y): pass")) is None

    def test_session_positional(self):
        found = _session_is_positional(self._signature("def foo(session=NEW_SESSION): pass"))
        assert found is not None
        assert found.arg == "session"

    def test_session_keyword_only(self):
        assert _session_is_positional(self._signature("def foo(*, session=NEW_SESSION): pass")) is None

    def test_session_positional_among_other_args(self):
        found = _session_is_positional(self._signature("def foo(x, y, session=NEW_SESSION): pass"))
        assert found is not None
        assert found.arg == "session"

    def test_session_kwonly_after_other_positional(self):
        signature = self._signature("def foo(x, y, *, session=NEW_SESSION): pass")
        assert _session_is_positional(signature) is None

    def test_session_positional_only(self):
        found = _session_is_positional(self._signature("def foo(session, /, x): pass"))
        assert found is not None
        assert found.arg == "session"
+
+
class TestIterPositionalSessionInProvideSession:
    """Detection over real files, plus per-file counting and resilience to bad input."""

    def test_keyword_only_session_is_clean(self, find_violations):
        source = """\
        @provide_session
        def foo(*, session=NEW_SESSION):
            pass
        """
        assert find_violations(source) == []

    def test_positional_session_is_flagged(self, find_violations):
        source = """\
        @provide_session
        def foo(session=NEW_SESSION):
            pass
        """
        found = find_violations(source)
        assert len(found) == 1
        [(func_node, session_arg)] = found
        assert func_node.name == "foo"
        assert session_arg.arg == "session"

    def test_no_provide_session_decorator_is_ignored(self, find_violations):
        source = """\
        def foo(session=NEW_SESSION):
            pass
        """
        assert find_violations(source) == []

    def test_async_function_with_positional_session_is_flagged(self, find_violations):
        source = """\
        @provide_session
        async def foo(session=NEW_SESSION):
            pass
        """
        assert len(find_violations(source)) == 1

    def test_method_with_positional_session_is_flagged(self, find_violations):
        source = """\
        class C:
            @provide_session
            def foo(self, session=NEW_SESSION):
                pass
        """
        found = find_violations(source)
        assert len(found) == 1
        func_node, _ = found[0]
        assert func_node.name == "foo"

    def test_attribute_decorator_is_recognised(self, find_violations):
        source = """\
        @airflow.utils.session.provide_session
        def foo(session=NEW_SESSION):
            pass
        """
        assert len(find_violations(source)) == 1

    def test_count_violations_multiple_in_file(self, write_python_file):
        source = """\
        @provide_session
        def a(session=NEW_SESSION):
            pass

        @provide_session
        def b(x, session=NEW_SESSION):
            pass

        @provide_session
        def c(*, session=NEW_SESSION):
            pass
        """
        module_path = write_python_file(source)
        # ``a`` and ``b`` are positional offenders; ``c`` is keyword-only.
        assert _count_violations(module_path) == 2

    def test_syntax_error_returns_no_violations(self, write_python_file):
        module_path = write_python_file("def foo(:\n    pass")
        assert _count_violations(module_path) == 0

    def test_invalid_utf8_does_not_crash(self, tmp_path):
        module_path = tmp_path / "invalid_utf8.py"
        module_path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef foo(session=NEW_SESSION):\n    pass\n")
        # The undecodable byte sits in a comment, so the function is still found.
        assert _count_violations(module_path) == 1
+
+
class TestAllowlistManager:
    """Loading, saving and sanitising the allowlist file."""

    def test_load_missing_file_returns_empty(self, tmp_path):
        assert AllowlistManager(tmp_path / "missing.txt").load() == {}

    def test_save_and_load_round_trip(self, tmp_path):
        allowlist_path = tmp_path / "allowlist.txt"
        mgr = AllowlistManager(allowlist_path)
        mgr.save({"b/file.py": 2, "a/file.py": 1})
        # Entries are written sorted by path, regardless of insertion order.
        assert allowlist_path.read_text().splitlines() == ["a/file.py::1", "b/file.py::2"]
        assert mgr.load() == {"a/file.py": 1, "b/file.py": 2}

    def test_load_skips_blank_and_malformed_lines(self, tmp_path):
        allowlist_path = tmp_path / "allowlist.txt"
        allowlist_path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n")
        assert AllowlistManager(allowlist_path).load() == {"valid/file.py": 3}

    def test_load_skips_unsafe_entries(self, fake_repo, tmp_path):
        """Absolute entries and entries whose ``..`` segments leave REPO_ROOT are dropped."""
        # Requesting ``fake_repo`` points REPO_ROOT at tmp_path, which is what makes
        # the "escapes the repo root" check meaningful here.
        allowlist_path = tmp_path / "allowlist.txt"
        allowlist_path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n")
        assert AllowlistManager(allowlist_path).load() == {"airflow-core/src/airflow/safe.py": 1}
+
+
class TestCheckProvideSessionKwargs:
    """Per-file comparison against the allowlist, including automatic tightening."""

    @staticmethod
    def _manager(tmp_path: Path) -> AllowlistManager:
        # Every test keeps its allowlist file inside its own tmp directory.
        return AllowlistManager(tmp_path / "allowlist.txt")

    def test_no_violations_in_clean_file(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/clean.py",
            """\
            @provide_session
            def foo(*, session=NEW_SESSION):
                pass
            """,
        )
        assert _check_provide_session_kwargs([module], {}, self._manager(tmp_path)) == 0

    def test_new_violation_fails(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/bad.py",
            """\
            @provide_session
            def foo(session=NEW_SESSION):
                pass
            """,
        )
        assert _check_provide_session_kwargs([module], {}, self._manager(tmp_path)) == 1

    def test_violation_within_allowlist_passes(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/grandfathered.py",
            """\
            @provide_session
            def foo(session=NEW_SESSION):
                pass
            """,
        )
        limits = {"airflow-core/src/airflow/grandfathered.py": 1}
        assert _check_provide_session_kwargs([module], limits, self._manager(tmp_path)) == 0

    def test_exceeding_allowlist_fails(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/grew.py",
            """\
            @provide_session
            def a(session=NEW_SESSION):
                pass

            @provide_session
            def b(session=NEW_SESSION):
                pass
            """,
        )
        limits = {"airflow-core/src/airflow/grew.py": 1}
        assert _check_provide_session_kwargs([module], limits, self._manager(tmp_path)) == 1

    def test_reducing_violations_tightens_allowlist(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/improved.py",
            """\
            @provide_session
            def foo(session=NEW_SESSION):
                pass

            @provide_session
            def bar(*, session=NEW_SESSION):
                pass
            """,
        )
        mgr = self._manager(tmp_path)
        limits = {"airflow-core/src/airflow/improved.py": 2}
        # A tightened allowlist is reported with a non-zero exit so pre-commit shows the change.
        assert _check_provide_session_kwargs([module], limits, mgr) == 1
        assert mgr.load() == {"airflow-core/src/airflow/improved.py": 1}

    def test_fixing_all_violations_removes_entry(self, fake_repo, tmp_path):
        module = fake_repo(
            "airflow-core/src/airflow/fixed.py",
            """\
            @provide_session
            def foo(*, session=NEW_SESSION):
                pass
            """,
        )
        mgr = self._manager(tmp_path)
        limits = {"airflow-core/src/airflow/fixed.py": 1}
        assert _check_provide_session_kwargs([module], limits, mgr) == 1
        assert mgr.load() == {}

    def test_non_python_file_is_skipped(self, fake_repo, tmp_path):
        text_file = fake_repo(
            "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n"
        )
        assert _check_provide_session_kwargs([text_file], {}, self._manager(tmp_path)) == 0

    def test_missing_allowlist_file_fails_loudly(self, fake_repo, tmp_path):
        """Passing the allowlist path when the file is missing must fail, not silently pass."""
        allowlist_path = tmp_path / "allowlist.txt"
        mgr = AllowlistManager(allowlist_path)
        assert not allowlist_path.exists()
        assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, mgr) == 1
+
+
+class TestExpandForAllowlistEdits:
+    def test_unchanged_when_allowlist_not_in_paths(self, fake_repo, tmp_path):
+        py = fake_repo("airflow-core/src/airflow/x.py", "pass")
+        manager = AllowlistManager(tmp_path / "allowlist.txt")
+        assert _expand_for_allowlist_edits([py], manager, 
{"airflow-core/src/airflow/x.py": 1}) == [py]
+
+    def test_appends_allowlisted_files_when_allowlist_edited(self, fake_repo, 
tmp_path):
+        allowlist_path = tmp_path / "allowlist.txt"
+        manager = AllowlistManager(allowlist_path)
+        listed = fake_repo("airflow-core/src/airflow/listed.py", "pass")
+        # Pass a resolved path — matches production behavior (``main()`` 
resolves argv).
+        result = _expand_for_allowlist_edits(
+            [allowlist_path.resolve()],
+            manager,
+            {"airflow-core/src/airflow/listed.py": 1, 
"airflow-core/src/airflow/gone.py": 1},
+        )
+        assert allowlist_path.resolve() in result
+        assert listed in result
+        # File in allowlist that does not exist on disk should be ignored.
+        assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not 
in result
+

Review Comment:
   The allowlist-edit expansion behavior is covered for "loosen count" 
scenarios, but there isn't a regression test for the bypass where an allowlist 
entry is *removed* (or a brand-new entry is added) and only the allowlist file 
is passed to the hook. Adding a test for those cases would help ensure edits to 
`known_provide_session_positional.txt` cannot silently skip validation of 
affected Python files.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to