This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-interface
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a3a8d92a5b246245a3ed1e1540fb11a55cd00157
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Mon Jul 31 14:38:10 2023 +0200

    Add OpenLineage methods to BaseOperator as a mixin interface
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/models/abstractoperator.py                 |  3 +-
 airflow/providers/openlineage/extractors/base.py   | 29 ++++++++++------
 .../providers/openlineage/extractors/manager.py    | 20 +++++++----
 airflow/providers/openlineage/plugins/listener.py  | 10 ++++--
 .../providers/openlineage/plugins/openlineage.py   |  3 ++
 airflow/utils/openlineage_mixin.py                 | 39 ++++++++++++++++++++++
 .../extractors/test_default_extractor.py           | 18 +++-------
 7 files changed, 88 insertions(+), 34 deletions(-)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 11e9184735..daf1d45a25 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -33,6 +33,7 @@ from airflow.template.templater import Templater
 from airflow.utils.context import Context
 from airflow.utils.db import exists_query
 from airflow.utils.log.secrets_masker import redact
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.sqlalchemy import skip_locked, with_row_locks
@@ -82,7 +83,7 @@ class NotMapped(Exception):
     """Raise if a task is neither mapped nor has any parent mapped groups."""
 
 
-class AbstractOperator(Templater, DAGNode):
+class AbstractOperator(Templater, DAGNode, OpenLineageMixin):
     """Common implementation for operators, including unmapped and mapped.
 
     This base class is more about sharing implementations, not defining a common
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 95d8fa6f28..43a0b584f6 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from contextlib import suppress
 
 from attrs import Factory, define
 from openlineage.client.facet import BaseFacet
@@ -84,20 +85,28 @@ class DefaultExtractor(BaseExtractor):
 
     def extract(self) -> OperatorLineage | None:
         # OpenLineage methods are optional - if there's no method, return None
-        try:
+        with suppress(AttributeError):
             return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # type: ignore
-        except AttributeError:
-            return None
+        return None
 
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
+        """
+        For complete method, we want to handle on_failure and on_complete methods as priority.
+        If they are not implemented - which happens in older, pre-OpenLineageMixin
+        classes, we're falling back to on_start method.
+        """
         if task_instance.state == TaskInstanceState.FAILED:
-            on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
-            if on_failed and callable(on_failed):
-                return self._get_openlineage_facets(on_failed, task_instance)
-        on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None)
-        if on_complete and callable(on_complete):
-            return self._get_openlineage_facets(on_complete, task_instance)
-        return self.extract()
+            with suppress(AttributeError):
+                return self._get_openlineage_facets(
+                    self.operator.get_openlineage_facets_on_failure, task_instance
+                )
+        with suppress(AttributeError):
+            return self._get_openlineage_facets(
+                self.operator.get_openlineage_facets_on_complete, task_instance
+            )
+        with suppress(AttributeError):
+            return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start, task_instance)
+        return None
 
     def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
         try:
diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py
index 02a4124840..8f293ae545 100644
--- a/airflow/providers/openlineage/extractors/manager.py
+++ b/airflow/providers/openlineage/extractors/manager.py
@@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import (
 from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 if TYPE_CHECKING:
     from airflow.models import Operator
@@ -135,13 +136,9 @@ class ExtractorManager(LoggingMixin):
         if task.task_type in self.extractors:
             return self.extractors[task.task_type]
 
-        def method_exists(method_name):
-            method = getattr(task, method_name, None)
-            if method:
-                return callable(method)
-
-        if method_exists("get_openlineage_facets_on_start") or method_exists(
-            "get_openlineage_facets_on_complete"
+        # We need to handle older OpenLineage implementations that do not implement OpenLineageMixin
+        if isinstance(task, OpenLineageMixin) or self._is_backwards_compatible_openlineage_implementation(
+            task
         ):
             return self.default_extractor
         return None
@@ -199,3 +196,12 @@ class ExtractorManager(LoggingMixin):
         except AttributeError:
             self.log.error("Extractor returns non-valid metadata: %s", task_metadata)
             return None
+
+    @staticmethod
+    def _is_backwards_compatible_openlineage_implementation(task) -> bool:
+        # Existence of those methods is a valid OL implementation.
+        return (
+            hasattr(task, "get_openlineage_facets_on_start")
+            or hasattr(task, "get_openlineage_facets_on_complete")
+            or hasattr(task, "get_openlineage_facets_on_failure")
+        )
diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index 99394863f5..516c8456bb 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -42,9 +42,15 @@ class OpenLineageListener:
 
     def __init__(self):
         self.log = logging.getLogger(__name__)
-        self.executor: Executor = None  # type: ignore
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
+        self._executor: Executor | None = None
+
+    @property
+    def executor(self) -> Executor:
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
+        return self._executor
 
     @hookimpl
     def on_task_instance_running(
@@ -151,7 +157,7 @@ class OpenLineageListener:
     @hookimpl
     def on_starting(self, component):
         self.log.debug("on_starting: %s", component.__class__.__name__)
-        self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
+        self.executor
 
     @hookimpl
     def before_stopping(self, component):
diff --git a/airflow/providers/openlineage/plugins/openlineage.py b/airflow/providers/openlineage/plugins/openlineage.py
index 2ec0801147..d0ade8259d 100644
--- a/airflow/providers/openlineage/plugins/openlineage.py
+++ b/airflow/providers/openlineage/plugins/openlineage.py
@@ -16,12 +16,15 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 import os
 
 from airflow.configuration import conf
 from airflow.plugins_manager import AirflowPlugin
 from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id
 
+log = logging.getLogger("airflow")
+
 
 def _is_disabled() -> bool:
     return (
diff --git a/airflow/utils/openlineage_mixin.py b/airflow/utils/openlineage_mixin.py
new file mode 100644
index 0000000000..d3773e9e37
--- /dev/null
+++ b/airflow/utils/openlineage_mixin.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+if typing.TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
+
+class OpenLineageMixin:
+    """
+    This interface marks implementation of OpenLineage methods,
+    allowing us to check for its existence rather than existence of particular methods on BaseOperator.
+    """
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+        raise NotImplementedError()
+
+    def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
+        return self.get_openlineage_facets_on_start()
+
+    def get_openlineage_facets_on_fail(self, task_instance: TaskInstance) -> OperatorLineage:
+        return self.get_openlineage_facets_on_complete(task_instance)
diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py b/tests/providers/openlineage/extractors/test_default_extractor.py
index 41d477aa9c..e3c9221093 100644
--- a/tests/providers/openlineage/extractors/test_default_extractor.py
+++ b/tests/providers/openlineage/extractors/test_default_extractor.py
@@ -137,7 +137,7 @@ class OperatorWrongOperatorLineageClass(BaseOperator):
         )
 
 
-class BrokenOperator(BaseOperator):
+class BrokenOperator:
     get_openlineage_facets = []
 
     def execute(self, context) -> Any:
@@ -145,7 +145,7 @@ class BrokenOperator(BaseOperator):
 
 
 def test_default_extraction():
-    extractor = ExtractorManager().get_extractor_class(ExampleOperator)
+    extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(ExampleOperator(task_id="test")).extract()
@@ -172,7 +172,7 @@ def test_default_extraction():
 
 
 def test_extraction_without_on_complete():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete)
+    extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutComplete(task_id="test")).extract()
@@ -196,7 +196,7 @@ def test_extraction_without_on_complete():
 
 
 def test_extraction_without_on_start():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart)
+    extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutStart(task_id="test")).extract()
@@ -217,16 +217,6 @@ def test_extraction_without_on_start():
     )
 
 
-def test_does_not_use_default_extractor_when_not_a_method():
-    extractor_class = ExtractorManager().get_extractor_class(BrokenOperator(task_id="a"))
-    assert extractor_class is None
-
-
-def test_does_not_use_default_extractor_when_no_get_openlineage_facets():
-    extractor_class = ExtractorManager().get_extractor_class(BaseOperator(task_id="b"))
-    assert extractor_class is None
-
-
 def test_does_not_use_default_extractor_when_explicite_extractor():
     extractor_class = ExtractorManager().get_extractor_class(
         PythonOperator(task_id="c", python_callable=lambda: 7)

Reply via email to