Copilot commented on code in PR #67150:
URL: https://github.com/apache/airflow/pull/67150#discussion_r3263896858


##########
scripts/tests/ci/prek/test_check_provide_session_kwargs.py:
##########
@@ -0,0 +1,395 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import textwrap
+from pathlib import Path
+
+import pytest
+from ci.prek.check_provide_session_kwargs import (
+    AllowlistManager,
+    _check_provide_session_kwargs,
+    _count_violations,
+    _expand_for_allowlist_edits,
+    _has_provide_session_decorator,
+    _iter_positional_session_in_provide_session,
+    _session_is_positional,
+)
+
+
[email protected]
+def find_violations(write_python_file):
+    """Factory fixture: write code to a temp file and return 
positional-session violations."""
+
+    def _check(code: str) -> list[tuple[ast.FunctionDef | 
ast.AsyncFunctionDef, ast.arg]]:
+        path = write_python_file(code)
+        return list(_iter_positional_session_in_provide_session(path))
+
+    return _check
+
+
[email protected]
+def fake_repo(tmp_path, monkeypatch):
+    """Create a fake repo layout and patch REPO_ROOT so paths resolve 
correctly."""
+    import ci.prek.check_provide_session_kwargs as hook
+
+    monkeypatch.setattr(hook, "REPO_ROOT", tmp_path)

Review Comment:
   The fixture imports a module inside the function body. This violates the
repo's general rule of keeping imports at module top level; the exceptions are
circular-import avoidance and deliberate lazy loading, and neither applies here.
The `from ci.prek.check_provide_session_kwargs import (...)` statement at the top
of the file already loads this module, but does not bind the module object to a name.
Consider moving `import ci.prek.check_provide_session_kwargs as hook` to the top
level as well, and patching `hook.REPO_ROOT` through that module object.



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,352 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in 
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", 
"providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+    """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+    Accepts both bare names (``@provide_session``) and attribute access
+    (``@something.provide_session``).
+    """
+    for node in nodes:
+        if isinstance(node, ast.Name) and node.id == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+        if isinstance(node, ast.Attribute) and node.attr == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+    return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def 
f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is 
positional."""
+    try:
+        source = path.read_text(encoding="utf-8")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+    return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+class AllowlistManager:
+    def __init__(self, allowlist_file: Path) -> None:
+        self.allowlist_file = allowlist_file
+
+    def load(self) -> dict[str, int]:
+        if not self.allowlist_file.exists():
+            return {}
+
+        result: dict[str, int] = {}
+        for raw_line in self.allowlist_file.read_text().splitlines():
+            if not (stripped := raw_line.strip()):
+                continue
+
+            rel_str, _, count_str = stripped.rpartition("::")
+            if not rel_str or not count_str:
+                continue
+
+            try:
+                result[rel_str] = int(count_str)
+            except ValueError:
+                continue
+
+        return result
+
+    def save(self, counts: dict[str, int]) -> None:
+        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+        self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+    def generate(self) -> int:
+        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+        console.print(
+            f"Scanning project source roots ([cyan]{roots}[/cyan]) under 
[cyan]{REPO_ROOT}[/cyan] "
+            "for @provide_session functions with positional session …"
+        )
+        counts: dict[str, int] = {}
+        for path in _iter_python_files():
+            n = _count_violations(path)
+            if n > 0:
+                counts[str(path.relative_to(REPO_ROOT))] = n
+
+        self.save(counts)
+        total = sum(counts.values())
+        console.print(
+            f"[green]Generated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] 
offenders."
+        )
+        return 0
+
+    def cleanup(self) -> int:
+        allowlist = self.load()
+        if not allowlist:
+            console.print("[yellow]Allowlist is empty - nothing to clean 
up.[/yellow]")
+            return 0
+
+        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / 
rel).exists()]
+        if stale:
+            console.print(
+                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) 
== 1 else 'ies'}:[/yellow]"
+            )
+            for s in sorted(stale):
+                console.print(f"  [dim]-[/dim] {s}")
+            for s in stale:
+                del allowlist[s]
+            self.save(allowlist)
+            console.print(
+                f"\n[green]Updated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+            )
+        else:
+            console.print("[green]No stale entries found.[/green]")
+        return 0
+
+
+def _iter_python_files() -> list[Path]:
+    candidates: list[Path] = []
+    for top in _PROJECT_SOURCE_ROOTS:
+        candidates.extend(
+            p.resolve()
+            for p in (REPO_ROOT / top).rglob("*.py")
+            if ".tox" not in p.parts and "__pycache__" not in p.parts
+        )
+    return candidates
+
+
+def _check_provide_session_kwargs(
+    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+    violations: list[tuple[Path, int, int]] = []
+    tightened: list[tuple[str, int, int]] = []
+
+    for path in files:
+        if not path.exists() or path.suffix != ".py":
+            continue
+        actual = _count_violations(path)
+        rel = str(path.relative_to(REPO_ROOT))
+        allowed = allowlist.get(rel, 0)

Review Comment:
   `path.relative_to(REPO_ROOT)` raises `ValueError` when `path` is not located
under the repo root. This can happen if `known_provide_session_positional.txt`
is edited to contain an absolute path or a path with `..` segments, and
`_expand_for_allowlist_edits()` then resolves that entry to a location outside
the repo. Consider catching the `ValueError` around the `relative_to()` call, or
rejecting such entries when the allowlist is loaded or expanded. Either way, the
hook then fails with a clear error message instead of crashing with an uncaught
traceback.



##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,352 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#   "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session`` 
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in 
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+    Check only the supplied files; fail if any file's count exceeds the limit.
+    When a file's count has *decreased*, the allowlist entry is tightened
+    automatically and the hook exits with a non-zero code so that pre-commit
+    reports the modified allowlist – just stage
+    ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+    Walk every ``.py`` file under the project source roots
+    (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, 
``shared``) —
+    the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+    Remove entries for files that no longer exist. Safe to run at any time;
+    does not add new entries or raise limits.
+
+``--generate``:
+    Scan the same project source roots as ``--all-files`` and *rebuild* the
+    allowlist from scratch. Intended for the initial setup or after a
+    large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in 
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", 
"providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+    """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+    Accepts both bare names (``@provide_session``) and attribute access
+    (``@something.provide_session``).
+    """
+    for node in nodes:
+        if isinstance(node, ast.Name) and node.id == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+        if isinstance(node, ast.Attribute) and node.attr == 
_PROVIDE_SESSION_DECORATOR:
+            return True
+    return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+    """Return the ``session`` arg if it is positional (not keyword-only).
+
+    Covers both regular positional args and positional-only args (``def 
f(session, /, ...)``).
+    """
+    for argument in (*args.posonlyargs, *args.args):
+        if argument.arg == "session":
+            return argument
+    return None
+
+
+def _iter_positional_session_in_provide_session(
+    path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+    """Yield ``@provide_session`` functions in *path* whose ``session`` is 
positional."""
+    try:
+        source = path.read_text(encoding="utf-8")
+    except OSError:
+        return
+    try:
+        tree = ast.parse(source, str(path))
+    except SyntaxError:
+        return
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not _has_provide_session_decorator(node.decorator_list):
+            continue
+        argument = _session_is_positional(node.args)
+        if argument is None:
+            continue
+        yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+    return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+class AllowlistManager:
+    def __init__(self, allowlist_file: Path) -> None:
+        self.allowlist_file = allowlist_file
+
+    def load(self) -> dict[str, int]:
+        if not self.allowlist_file.exists():
+            return {}
+
+        result: dict[str, int] = {}
+        for raw_line in self.allowlist_file.read_text().splitlines():
+            if not (stripped := raw_line.strip()):
+                continue
+
+            rel_str, _, count_str = stripped.rpartition("::")
+            if not rel_str or not count_str:
+                continue
+
+            try:
+                result[rel_str] = int(count_str)
+            except ValueError:
+                continue
+
+        return result
+
+    def save(self, counts: dict[str, int]) -> None:
+        lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+        self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+    def generate(self) -> int:
+        roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+        console.print(
+            f"Scanning project source roots ([cyan]{roots}[/cyan]) under 
[cyan]{REPO_ROOT}[/cyan] "
+            "for @provide_session functions with positional session …"
+        )
+        counts: dict[str, int] = {}
+        for path in _iter_python_files():
+            n = _count_violations(path)
+            if n > 0:
+                counts[str(path.relative_to(REPO_ROOT))] = n
+
+        self.save(counts)
+        total = sum(counts.values())
+        console.print(
+            f"[green]Generated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+            f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] 
offenders."
+        )
+        return 0
+
+    def cleanup(self) -> int:
+        allowlist = self.load()
+        if not allowlist:
+            console.print("[yellow]Allowlist is empty - nothing to clean 
up.[/yellow]")
+            return 0
+
+        stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / 
rel).exists()]
+        if stale:
+            console.print(
+                f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) 
== 1 else 'ies'}:[/yellow]"
+            )
+            for s in sorted(stale):
+                console.print(f"  [dim]-[/dim] {s}")
+            for s in stale:
+                del allowlist[s]
+            self.save(allowlist)
+            console.print(
+                f"\n[green]Updated[/green] 
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]"
+            )
+        else:
+            console.print("[green]No stale entries found.[/green]")
+        return 0
+
+
+def _iter_python_files() -> list[Path]:
+    candidates: list[Path] = []
+    for top in _PROJECT_SOURCE_ROOTS:
+        candidates.extend(
+            p.resolve()
+            for p in (REPO_ROOT / top).rglob("*.py")
+            if ".tox" not in p.parts and "__pycache__" not in p.parts
+        )
+    return candidates
+
+
+def _check_provide_session_kwargs(
+    files: list[Path], allowlist: dict[str, int], manager: AllowlistManager
+) -> int:
+    violations: list[tuple[Path, int, int]] = []
+    tightened: list[tuple[str, int, int]] = []
+
+    for path in files:
+        if not path.exists() or path.suffix != ".py":
+            continue
+        actual = _count_violations(path)

Review Comment:
   `_check_provide_session_kwargs()` silently skips paths that do not exist.
Suppose a commit deletes or renames the allowlist file and prek/pre-commit passes
its now-missing path to this hook. The hook may then run no checks at all and
report success. Consider explicitly detecting the case where the configured
allowlist file is among the input paths but no longer exists. In that case, fail
with an instruction to restore it or regenerate it with `--generate`.

Note that this detection must run before the `path.suffix != ".py"` filter. The
allowlist is a `.txt` file, so that filter would drop it as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to