This is an automated email from the ASF dual-hosted git repository. jhtimmins pushed a commit to branch v2-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 75c3e752667ef3bb0fe82d8edefafb444531b0e8 Author: Henry Zhang <m...@henry.dev> AuthorDate: Thu Jul 22 15:23:51 2021 -0700 Core: Enable the use of __init_subclass__ in subclasses of BaseOperator (#17027) This fixes a regression in 2.1 where subclasses of BaseOperator could no longer use `__init_subclass__` to customize classes at class-creation time (i.e. when a subclass is defined, optionally with class keyword arguments). Related BPO: https://bugs.python.org/issue29581 Fixes: https://github.com/apache/airflow/issues/17014 (cherry picked from commit 901513203f287d4f8152f028e9070a2dec73ad74) --- airflow/models/baseoperator.py | 4 ++-- tests/models/test_baseoperator.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 1fec8cf..360f264 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -186,8 +186,8 @@ class BaseOperatorMeta(abc.ABCMeta): return cast(T, apply_defaults) - def __new__(cls, name, bases, namespace): - new_cls = super().__new__(cls, name, bases, namespace) + def __new__(cls, name, bases, namespace, **kwargs): + new_cls = super().__new__(cls, name, bases, namespace, **kwargs) new_cls.__init__ = cls._apply_defaults(new_cls.__init__) return new_cls diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 04d3f54..87a9247 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -516,3 +516,29 @@ class TestXComArgsRelationsAreResolved: with pytest.raises(AirflowException): op1 = DummyOperator(task_id="op1") CustomOp(task_id="op2", field=op1.output) + + +class InitSubclassOp(DummyOperator): + def __init_subclass__(cls, class_arg=None, **kwargs) -> None: + cls._class_arg = class_arg + super().__init_subclass__(**kwargs) + + def execute(self, context): + self.context_arg = context + + +class TestInitSubclassOperator: + def test_init_subclass_args(self): + class_arg = "foo" + context = {"key": "value"} + + class 
ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg): + pass + + task = ConcreteSubclassOp(task_id="op1") + task_copy = task.prepare_for_execution() + + task_copy.execute(context) + + assert task_copy._class_arg == class_arg + assert task_copy.context_arg == context