This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 49958a5000a Add static check ensuring trigger `__init__()` and `serialize()` stay in sync (#66960)
49958a5000a is described below

commit 49958a5000a25fe627efb42917d94240ba5bf4df
Author: Shahar Epstein <[email protected]>
AuthorDate: Sat May 16 23:58:05 2026 +0300

    Add static check ensuring trigger `__init__()` and `serialize()` stay in sync (#66960)
    
    * Add static check ensuring trigger __init__ and serialize() stay in sync
    
    Trigger __init__ and serialize() are written as a pair: any __init__
    parameter that serialize() does not return is silently dropped when the
    triggerer re-instantiates the trigger, falling back to the parameter's
    default. This adds an AST-based prek static check over provider triggers
    that flags such mismatches, resolving __init__/serialize() pairs through
    in-file base classes (including **super().serialize() spreads).
    
    Five existing violations are excluded as KNOWN_VIOLATIONS pending a
    follow-up fix; three by-design cases (deprecated/aliased params folded
    into their replacement at construction time) are permanently excluded.
    
    * Update scripts/ci/prek/check_trigger_serialize_init.py
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * Rename noun-style helper functions to _get_/get_ form
    
    Rename based on Lee-W's review feedback:
    - _init_param_names → _get_init_param_names
    - _base_simple_names → _get_base_simple_names
    - _method → _get_method
    - _serialize_keys → _get_serialize_keys
    - _super_serialize_keys → _get_super_serialize_keys
    - violations → get_violations
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 providers/.pre-commit-config.yaml               |   7 +
 scripts/ci/prek/check_trigger_serialize_init.py | 283 ++++++++++++++++++++++++
 2 files changed, 290 insertions(+)

diff --git a/providers/.pre-commit-config.yaml b/providers/.pre-commit-config.yaml
index 4dd1b892de0..652477809de 100644
--- a/providers/.pre-commit-config.yaml
+++ b/providers/.pre-commit-config.yaml
@@ -30,6 +30,13 @@ repos:
         entry: ../scripts/ci/prek/check_deferrable_default.py
         pass_filenames: false
         files: ^(.*/)?airflow/.*/(sensors|operators)/.*\.py$
+      - id: check-trigger-serialize-init
+        name: Check trigger __init__ and serialize() are in sync
+        description: Every trigger __init__ parameter must appear in its serialize() return dict
+        language: python
+        entry: ../scripts/ci/prek/check_trigger_serialize_init.py
+        pass_filenames: true
+        files: ^.*/src/airflow/providers/.*/triggers/[^/]+\.py$
       - id: update-providers-dependencies
         name: Update dependencies for providers
         entry: ../scripts/ci/prek/update_providers_dependencies.py
diff --git a/scripts/ci/prek/check_trigger_serialize_init.py b/scripts/ci/prek/check_trigger_serialize_init.py
new file mode 100755
index 00000000000..3866d779269
--- /dev/null
+++ b/scripts/ci/prek/check_trigger_serialize_init.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10,<3.11"
+# dependencies = [
+#   "rich>=13.6.0",
+# ]
+# ///
+"""
+Check that provider ``Trigger`` classes keep ``__init__`` and ``serialize()`` 
in sync.
+
+``__init__`` and ``serialize()`` are written as a pair: a trigger is 
instantiated once when the
+operator defers, then serialized and re-instantiated on whichever triggerer 
process runs it. Any
+``__init__`` parameter that ``serialize()`` does not return is silently 
dropped on that round-trip
+-- the reconstructed trigger falls back to the parameter's default. See
+https://github.com/apache/airflow/blob/main/airflow-core/docs/authoring-and-scheduling/deferring.rst
+
+This check parses each provider trigger module with ``ast`` and, for every 
trigger class whose
+``__init__`` signature *and* ``serialize()`` return dict can both be resolved 
statically (including
+through in-file base classes), flags any ``__init__`` parameter missing from 
the ``serialize()``
+return dict.
+
+Classes whose ``serialize()`` is built dynamically (``**spread`` of a 
non-``super()`` value,
+``.update()``, returning a variable, ...) or that inherit 
``__init__``/``serialize()`` from a base
+class defined in another file cannot be resolved statically and are skipped -- 
the check never
+guesses, so it produces no false positives.
+"""
+
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+from common_prek_utils import AIRFLOW_PROVIDERS_ROOT_PATH, console
+
+DEFERRING_DOC = (
+    
"https://github.com/apache/airflow/blob/main/airflow-core/docs/authoring-and-scheduling/deferring.rst";
+)
+
+# Key format for both sets below: "<path relative to the providers/ 
directory>::<ClassName>".
+
+# Trigger classes that genuinely violate the __init__/serialize() contract 
today. They predate
+# the check and are excluded so it can be enabled without a tree-wide fix; 
each is tracked for a
+# follow-up fix in a separate PR. Do NOT add new entries here -- fix the 
trigger instead.
+KNOWN_VIOLATIONS: set[str] = {
+    # `caller` is passed straight to DatabricksHook(caller=...) and never 
stored/serialized, so it
+    # falls back to the class-name default on a triggerer round-trip.
+    
"databricks/src/airflow/providers/databricks/triggers/databricks.py::DatabricksExecutionTrigger",
+    
"databricks/src/airflow/providers/databricks/triggers/databricks.py::DatabricksSQLStatementExecutionTrigger",
+    # `dataset_id`, `table_id`, `poll_interval` are forwarded to the parent 
__init__ and used, but
+    # the overridden serialize() omits them.
+    
"google/src/airflow/providers/google/cloud/triggers/bigquery.py::BigQueryIntervalCheckTrigger",
+    # `poll_interval` and `impersonation_chain` are stored and used but 
missing from serialize().
+    
"google/src/airflow/providers/google/cloud/triggers/datafusion.py::DataFusionStartPipelineTrigger",
+    # `endpoint_prefix` is stored as self._endpoint_prefix and used in run() 
but missing from serialize().
+    
"apache/livy/src/airflow/providers/apache/livy/triggers/livy.py::LivyTrigger",
+}
+
+# Trigger classes that the check flags but which are correct *by design*: an 
old/aliased parameter
+# is folded into its replacement at construction time, so it does not need to 
round-trip. These are
+# permanent exclusions, not tech debt.
+BY_DESIGN_EXCLUSIONS: set[str] = {
+    # `delta` is converted to an absolute `moment` in __init__ and the class 
deliberately serializes
+    # as a DateTimeTrigger (see its docstring) -- there is no `delta` to 
reconstruct.
+    
"standard/src/airflow/providers/standard/triggers/temporal.py::TimeDeltaTrigger",
+    # `pod_name` is a deprecated alias folded into `pod_names` in __init__; 
serializing only
+    # `pod_names` preserves the value and avoids re-triggering the deprecation 
path on restart.
+    
"google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py::GKEJobTrigger",
+    
"cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py::KubernetesJobTrigger",
+}
+
+_EXCLUDED = KNOWN_VIOLATIONS | BY_DESIGN_EXCLUSIONS
+
+
+def _get_init_param_names(func: ast.FunctionDef) -> set[str]:
+    """Return the names of reconstructable __init__ parameters 
(``*args``/``**kwargs`` excluded)."""
+    args = func.args
+    names = {a.arg for a in (*args.posonlyargs, *args.args, *args.kwonlyargs)}
+    names.discard("self")
+    names.discard("cls")
+    return names
+
+
+def _get_base_simple_names(cls: ast.ClassDef) -> list[str]:
+    """Return the simple (last-component) names of a class's bases."""
+    out: list[str] = []
+    for base in cls.bases:
+        if isinstance(base, ast.Name):
+            out.append(base.id)
+        elif isinstance(base, ast.Attribute):
+            out.append(base.attr)
+    return out
+
+
+def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef | None:
+    for node in cls.body:
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and 
node.name == name:
+            return node if isinstance(node, ast.FunctionDef) else None
+    return None
+
+
+class ModuleAnalyzer:
+    """Resolves trigger __init__/serialize() pairs within a single module, 
following in-file bases."""
+
+    def __init__(self, path: Path) -> None:
+        self.path = path
+        tree = ast.parse(path.read_text("utf-8"), str(path))
+        self.classes: dict[str, ast.ClassDef] = {
+            node.name: node for node in ast.walk(tree) if isinstance(node, 
ast.ClassDef)
+        }
+
+    def _in_file_base(self, cls: ast.ClassDef) -> ast.ClassDef | None:
+        for name in _get_base_simple_names(cls):
+            if name in self.classes:
+                return self.classes[name]
+        return None
+
+    def is_trigger(self, cls: ast.ClassDef, _seen: set[str] | None = None) -> 
bool:
+        if cls.name.endswith("Trigger"):
+            return True
+        _seen = _seen or set()
+        if cls.name in _seen:
+            return False
+        _seen.add(cls.name)
+        for name in _get_base_simple_names(cls):
+            if "Trigger" in name:
+                return True
+            base = self.classes.get(name)
+            if base is not None and self.is_trigger(base, _seen):
+                return True
+        return False
+
+    def _resolve_method(self, cls: ast.ClassDef, name: str) -> 
tuple[ast.FunctionDef, ast.ClassDef] | None:
+        """Walk in-file bases until *name* is found; return (method, defining 
class)."""
+        current: ast.ClassDef | None = cls
+        seen: set[str] = set()
+        while current is not None and current.name not in seen:
+            seen.add(current.name)
+            method = _get_method(current, name)
+            if method is not None:
+                return method, current
+            current = self._in_file_base(current)
+        return None
+
+    def _get_serialize_keys(self, cls: ast.ClassDef, _seen: set[str] | None = 
None) -> set[str] | None:
+        """
+        Statically resolve the keys of the dict returned by *cls*'s effective 
``serialize()``.
+
+        Returns ``None`` when the dict cannot be resolved statically (dynamic 
construction,
+        external base class, ``**`` spread of a non-``super()`` value, ...).
+        """
+        _seen = _seen or set()
+        if cls.name in _seen:
+            return None
+        _seen.add(cls.name)
+
+        resolved = self._resolve_method(cls, "serialize")
+        if resolved is None:
+            return None
+        serialize, defining_cls = resolved
+
+        returns = [n for n in ast.walk(serialize) if isinstance(n, ast.Return)]
+        if len(returns) != 1 or returns[0].value is None:
+            return None
+        ret = returns[0].value
+        if not isinstance(ret, (ast.Tuple, ast.List)) or len(ret.elts) != 2:
+            return None
+        payload = ret.elts[1]
+        if not isinstance(payload, ast.Dict):
+            return None
+
+        keys: set[str] = set()
+        for key, value in zip(payload.keys, payload.values):
+            if key is None:
+                # ``**spread`` entry -- only resolvable when it spreads 
super().serialize().
+                spread = self._get_super_serialize_keys(value, defining_cls, 
_seen)
+                if spread is None:
+                    return None
+                keys |= spread
+                continue
+            if not (isinstance(key, ast.Constant) and isinstance(key.value, 
str)):
+                return None
+            keys.add(key.value)
+        return keys
+
+    def _get_super_serialize_keys(
+        self, value: ast.expr, defining_cls: ast.ClassDef, _seen: set[str]
+    ) -> set[str] | None:
+        """Resolve keys for a ``**super().serialize()[1]`` / 
``**super().serialize()`` spread."""
+        node = value
+        if isinstance(node, ast.Subscript):  # super().serialize()[1]
+            node = node.value
+        if not (
+            isinstance(node, ast.Call)
+            and isinstance(node.func, ast.Attribute)
+            and node.func.attr == "serialize"
+            and isinstance(node.func.value, ast.Call)
+            and isinstance(node.func.value.func, ast.Name)
+            and node.func.value.func.id == "super"
+        ):
+            return None
+        base = self._in_file_base(defining_cls)
+        if base is None:
+            return None
+        return self._get_serialize_keys(base, _seen)
+
+    def get_violations(self) -> list[tuple[str, list[str]]]:
+        """Return ``(class_name, [missing params])`` for every resolvable 
trigger that violates."""
+        results: list[tuple[str, list[str]]] = []
+        rel = self.path.relative_to(AIRFLOW_PROVIDERS_ROOT_PATH).as_posix()
+        for name, cls in self.classes.items():
+            if f"{rel}::{name}" in _EXCLUDED:
+                continue
+            if not self.is_trigger(cls):
+                continue
+            init_resolved = self._resolve_method(cls, "__init__")
+            if init_resolved is None:
+                continue  # __init__ inherited from an out-of-file base -- 
cannot resolve.
+            params = _get_init_param_names(init_resolved[0])
+            serialize_keys = self._get_serialize_keys(cls)
+            if serialize_keys is None:
+                continue  # serialize() dynamic or inherited from an 
out-of-file base.
+            missing = sorted(params - serialize_keys)
+            if missing:
+                results.append((name, missing))
+        return results
+
+
+def _iter_files(argv: list[str]) -> list[Path]:
+    if argv:
+        return [Path(a).resolve() for a in argv]
+    return 
sorted(AIRFLOW_PROVIDERS_ROOT_PATH.glob("*/src/airflow/providers/**/triggers/*.py"))
 + sorted(
+        
AIRFLOW_PROVIDERS_ROOT_PATH.glob("*/*/src/airflow/providers/**/triggers/*.py")
+    )
+
+
+def main(argv: list[str]) -> int:
+    error_count = 0
+    for path in _iter_files(argv):
+        if path.name == "__init__.py":
+            continue
+        try:
+            analyzer = ModuleAnalyzer(path)
+        except SyntaxError as exc:
+            console.print(f"[red]Could not parse {path}: {exc}")
+            error_count += 1
+            continue
+        rel = path.relative_to(AIRFLOW_PROVIDERS_ROOT_PATH).as_posix()
+        for class_name, missing in analyzer.get_violations():
+            error_count += 1
+            console.print(
+                f"[red]{rel}::{class_name}[/] -- __init__ parameter(s) "
+                f"{', '.join(repr(m) for m in missing)} missing from 
serialize() return dict"
+            )
+    if error_count:
+        console.print(
+            f"\n[red]Found {error_count} trigger(s) whose serialize() drops 
__init__ parameters.[/]\n"
+            "Every __init__ parameter must appear in the serialize() return 
dict, otherwise it is "
+            "silently lost when the triggerer re-instantiates the trigger.\n"
+            f"See: {DEFERRING_DOC}\n"
+        )
+    return 1 if error_count else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))

Reply via email to