Lee-W commented on code in PR #67150:
URL: https://github.com/apache/airflow/pull/67150#discussion_r3287734806
##########
scripts/ci/prek/known_provide_session_positional.txt:
##########
@@ -0,0 +1,89 @@
+airflow-core/src/airflow/api/common/delete_dag.py::1
Review Comment:
Are these things we want to exclude or do? Or both?
##########
.pre-commit-config.yaml:
##########
@@ -1064,6 +1064,12 @@ repos:
language: python
pass_filenames: true
files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$
+ - id: check-no-new-provide-session-positional
+ name: Check that no new @provide_session functions declare `session`
positionally
+ entry: ./scripts/ci/prek/check_provide_session_kwargs.py
+ language: python
+ pass_filenames: true
+ files:
^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$
Review Comment:
I thought we're not supposed to have provide_session in task-sdk.
##########
scripts/ci/prek/check_provide_session_kwargs.py:
##########
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "rich>=13.0.0",
+# ]
+# ///
+"""Check that no new ``@provide_session`` functions declare ``session``
positionally.
+
+The project convention is that any function decorated with ``@provide_session``
+must declare ``session`` as keyword-only (after a bare ``*`` in the signature),
+so callers cannot pass it positionally by accident. See
+``contributing-docs/05_pull_requests.rst#database-session-handling``.
+
+All *existing* offenders are recorded in
``known_provide_session_positional.txt``
+next to this script as ``relative/path::N`` entries (one per file), where ``N``
+is the maximum number of ``@provide_session`` functions with a positional
+``session`` argument allowed in that file. A file whose current count exceeds
+the recorded limit is treated as a violation – move the ``session`` argument
+behind a bare ``*`` instead.
+
+Modes
+-----
+Default (files passed by prek/pre-commit):
+ Check only the supplied files; fail if any file's count exceeds the limit.
+ When a file's count has *decreased*, the allowlist entry is tightened
+ automatically and the hook exits with a non-zero code so that pre-commit
+ reports the modified allowlist – just stage
+ ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run.
+
+``--all-files``:
+ Walk every ``.py`` file under the project source roots
+ (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``,
``shared``) —
+ the same scope the pre-commit hook applies to.
+
+``--cleanup``:
+ Remove entries for files that no longer exist. Safe to run at any time;
+ does not add new entries or raise limits.
+
+``--generate``:
+ Scan the same project source roots as ``--all-files`` and *rebuild* the
+ allowlist from scratch. Intended for the initial setup or after a
+ large-scale clean-up sprint.
+"""
+
+from __future__ import annotations
+
+import argparse
+import ast
+import subprocess
+import typing
+from pathlib import Path
+
+from rich.console import Console
+from rich.panel import Panel
+
+console = Console(color_system="standard", width=200)
+
+REPO_ROOT = Path(__file__).parents[3]
+
+_PROVIDE_SESSION_DECORATOR = "provide_session"
+
+# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in
sync with the
+# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``.
+_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk",
"providers", "shared")
+
+
+def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool:
+ """Whether one of ``nodes`` is a ``@provide_session`` decorator.
+
+ Accepts both bare names (``@provide_session``) and attribute access
+ (``@something.provide_session``).
+ """
+ for node in nodes:
+ if isinstance(node, ast.Name) and node.id ==
_PROVIDE_SESSION_DECORATOR:
+ return True
+ if isinstance(node, ast.Attribute) and node.attr ==
_PROVIDE_SESSION_DECORATOR:
+ return True
+ return False
+
+
+def _session_is_positional(args: ast.arguments) -> ast.arg | None:
+ """Return the ``session`` arg if it is positional (not keyword-only).
+
+ Covers both regular positional args and positional-only args (``def
f(session, /, ...)``).
+ """
+ for argument in (*args.posonlyargs, *args.args):
+ if argument.arg == "session":
+ return argument
+ return None
+
+
+def _iter_positional_session_in_provide_session(
+ path: Path,
+) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]:
+ """Yield ``@provide_session`` functions in *path* whose ``session`` is
positional."""
+ try:
+ source = path.read_text(encoding="utf-8", errors="replace")
+ except OSError:
+ return
+ try:
+ tree = ast.parse(source, str(path))
+ except SyntaxError:
+ return
+ for node in ast.walk(tree):
+ if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ continue
+ if not _has_provide_session_decorator(node.decorator_list):
+ continue
+ argument = _session_is_positional(node.args)
+ if argument is None:
+ continue
+ yield node, argument
+
+
+def _count_violations(path: Path) -> int:
+ return sum(1 for _ in _iter_positional_session_in_provide_session(path))
+
+
+def _is_safe_relative(rel: str) -> bool:
+ """Whether ``rel`` is a plain relative path that stays inside
``REPO_ROOT``.
+
+ Rejects absolute paths and any entry that resolves outside the repo root so
+ callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``.
+ """
+ candidate = Path(rel)
+ if candidate.is_absolute():
+ return False
+ try:
+ (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve())
+ except ValueError:
+ return False
+ return True
+
+
+class AllowlistManager:
+ def __init__(self, allowlist_file: Path) -> None:
+ self.allowlist_file = allowlist_file
+
+ @staticmethod
+ def parse(text: str) -> dict[str, int]:
+ """Parse allowlist *text* into a ``{rel_path: count}`` mapping.
+
+ Same validation rules as :meth:`load` so we can reuse parsing for the
+ on-disk allowlist *and* for the previous version fetched from git when
+ guarding against entry-removal bypasses.
+ """
+ result: dict[str, int] = {}
+ for raw_line in text.splitlines():
+ if not (stripped := raw_line.strip()):
+ continue
+
+ rel_str, _, count_str = stripped.rpartition("::")
+ if not rel_str or not count_str:
+ continue
+
+ try:
+ count = int(count_str)
+ except ValueError:
+ continue
+
+ if not _is_safe_relative(rel_str):
+ console.print(
+ f"[yellow]Ignoring unsafe allowlist entry (escapes repo
root):[/yellow] {rel_str}"
+ )
+ continue
+
+ result[rel_str] = count
+
+ return result
+
+ def load(self) -> dict[str, int]:
+ if not self.allowlist_file.exists():
+ return {}
+ return self.parse(self.allowlist_file.read_text())
+
+ def save(self, counts: dict[str, int]) -> None:
+ lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())]
+ self.allowlist_file.write_text("\n".join(lines) + "\n")
+
+ def generate(self) -> int:
+ roots = ", ".join(_PROJECT_SOURCE_ROOTS)
+ console.print(
+ f"Scanning project source roots ([cyan]{roots}[/cyan]) under
[cyan]{REPO_ROOT}[/cyan] "
+ "for @provide_session functions with positional session …"
+ )
+ counts: dict[str, int] = {}
+ for path in _iter_python_files():
+ n = _count_violations(path)
+ if n > 0:
+ counts[str(path.relative_to(REPO_ROOT))] = n
+
+ self.save(counts)
+ total = sum(counts.values())
+ console.print(
+ f"[green]Generated[/green]
[cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] "
+ f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold]
offenders."
+ )
+ return 0
+
+ def cleanup(self) -> int:
+ allowlist = self.load()
Review Comment:
By loading, I thought we were going to save it as a class variable 🤔
##########
.pre-commit-config.yaml:
##########
@@ -1064,6 +1064,12 @@ repos:
language: python
pass_filenames: true
files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$
+ - id: check-no-new-provide-session-positional
+ name: Check that no new @provide_session functions declare `session`
positionally
+ entry: ./scripts/ci/prek/check_provide_session_kwargs.py
+ language: python
+ pass_filenames: true
+ files:
^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$
Review Comment:
looks like so.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]