This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch openlineage-interface in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a3a8d92a5b246245a3ed1e1540fb11a55cd00157 Author: Maciej Obuchowski <[email protected]> AuthorDate: Mon Jul 31 14:38:10 2023 +0200 add OpenLineage methods as mixin interface to BaseOperator Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/models/abstractoperator.py | 3 +- airflow/providers/openlineage/extractors/base.py | 29 ++++++++++------ .../providers/openlineage/extractors/manager.py | 20 +++++++---- airflow/providers/openlineage/plugins/listener.py | 10 ++++-- .../providers/openlineage/plugins/openlineage.py | 3 ++ airflow/utils/openlineage_mixin.py | 39 ++++++++++++++++++++++ .../extractors/test_default_extractor.py | 18 +++------- 7 files changed, 88 insertions(+), 34 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 11e9184735..daf1d45a25 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -33,6 +33,7 @@ from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query from airflow.utils.log.secrets_masker import redact +from airflow.utils.openlineage_mixin import OpenLineageMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import skip_locked, with_row_locks @@ -82,7 +83,7 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(Templater, DAGNode): +class AbstractOperator(Templater, DAGNode, OpenLineageMixin): """Common implementation for operators, including unmapped and mapped. 
This base class is more about sharing implementations, not defining a common diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 95d8fa6f28..43a0b584f6 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -18,6 +18,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import suppress from attrs import Factory, define from openlineage.client.facet import BaseFacet @@ -84,20 +85,28 @@ class DefaultExtractor(BaseExtractor): def extract(self) -> OperatorLineage | None: # OpenLineage methods are optional - if there's no method, return None - try: + with suppress(AttributeError): return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore - except AttributeError: - return None + return None def extract_on_complete(self, task_instance) -> OperatorLineage | None: + """ + For complete method, we want to handle on_failure and on_complete methods as priority. + If they are not implemented - which happens in older, pre-OpenLineageMixin + classes, we're falling back to on_start method. 
+ """ if task_instance.state == TaskInstanceState.FAILED: - on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) - if on_failed and callable(on_failed): - return self._get_openlineage_facets(on_failed, task_instance) - on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None) - if on_complete and callable(on_complete): - return self._get_openlineage_facets(on_complete, task_instance) - return self.extract() + with suppress(AttributeError): + return self._get_openlineage_facets( + self.operator.get_openlineage_facets_on_failure, task_instance + ) + with suppress(AttributeError): + return self._get_openlineage_facets( + self.operator.get_openlineage_facets_on_complete, task_instance + ) + with suppress(AttributeError): + return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start, task_instance) + return None def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None: try: diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 02a4124840..8f293ae545 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import ( from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.openlineage_mixin import OpenLineageMixin if TYPE_CHECKING: from airflow.models import Operator @@ -135,13 +136,9 @@ class ExtractorManager(LoggingMixin): if task.task_type in self.extractors: return self.extractors[task.task_type] - def method_exists(method_name): - method = getattr(task, method_name, None) - if method: - return callable(method) - - if method_exists("get_openlineage_facets_on_start") or method_exists( - 
"get_openlineage_facets_on_complete" + # We need to handle older OpenLineage implementations that do not implement OpenLineageMixin + if isinstance(task, OpenLineageMixin) or self._is_backwards_compatible_openlineage_implementation( + task ): return self.default_extractor return None @@ -199,3 +196,12 @@ class ExtractorManager(LoggingMixin): except AttributeError: self.log.error("Extractor returns non-valid metadata: %s", task_metadata) return None + + @staticmethod + def _is_backwards_compatible_openlineage_implementation(task) -> bool: + # Existence of those methods is a valid OL implementation. + return ( + hasattr(task, "get_openlineage_facets_on_start") + or hasattr(task, "get_openlineage_facets_on_complete") + or hasattr(task, "get_openlineage_facets_on_failure") + ) diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 99394863f5..516c8456bb 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -42,9 +42,15 @@ class OpenLineageListener: def __init__(self): self.log = logging.getLogger(__name__) - self.executor: Executor = None # type: ignore self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() + self._executor: Executor | None = None + + @property + def executor(self) -> Executor: + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + return self._executor @hookimpl def on_task_instance_running( @@ -151,7 +157,7 @@ class OpenLineageListener: @hookimpl def on_starting(self, component): self.log.debug("on_starting: %s", component.__class__.__name__) - self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + self.executor @hookimpl def before_stopping(self, component): diff --git a/airflow/providers/openlineage/plugins/openlineage.py b/airflow/providers/openlineage/plugins/openlineage.py index 
2ec0801147..d0ade8259d 100644 --- a/airflow/providers/openlineage/plugins/openlineage.py +++ b/airflow/providers/openlineage/plugins/openlineage.py @@ -16,12 +16,15 @@ # under the License. from __future__ import annotations +import logging import os from airflow.configuration import conf from airflow.plugins_manager import AirflowPlugin from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id +log = logging.getLogger("airflow") + def _is_disabled() -> bool: return ( diff --git a/airflow/utils/openlineage_mixin.py b/airflow/utils/openlineage_mixin.py new file mode 100644 index 0000000000..d3773e9e37 --- /dev/null +++ b/airflow/utils/openlineage_mixin.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from airflow.models import TaskInstance + from airflow.providers.openlineage.extractors import OperatorLineage + + +class OpenLineageMixin: + """ + This interface marks implementation of OpenLineage methods, + allowing us to check for its existence rather than existence of particular methods on BaseOperator. 
+ """ + + def get_openlineage_facets_on_start(self) -> OperatorLineage: + raise NotImplementedError() + + def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage: + return self.get_openlineage_facets_on_start() + + def get_openlineage_facets_on_fail(self, task_instance: TaskInstance) -> OperatorLineage: + return self.get_openlineage_facets_on_complete(task_instance) diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py b/tests/providers/openlineage/extractors/test_default_extractor.py index 41d477aa9c..e3c9221093 100644 --- a/tests/providers/openlineage/extractors/test_default_extractor.py +++ b/tests/providers/openlineage/extractors/test_default_extractor.py @@ -137,7 +137,7 @@ class OperatorWrongOperatorLineageClass(BaseOperator): ) -class BrokenOperator(BaseOperator): +class BrokenOperator: get_openlineage_facets = [] def execute(self, context) -> Any: @@ -145,7 +145,7 @@ class BrokenOperator(BaseOperator): def test_default_extraction(): - extractor = ExtractorManager().get_extractor_class(ExampleOperator) + extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(ExampleOperator(task_id="test")).extract() @@ -172,7 +172,7 @@ def test_default_extraction(): def test_extraction_without_on_complete(): - extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(OperatorWithoutComplete(task_id="test")).extract() @@ -196,7 +196,7 @@ def test_extraction_without_on_complete(): def test_extraction_without_on_start(): - extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test")) assert extractor is DefaultExtractor metadata = 
extractor(OperatorWithoutStart(task_id="test")).extract() @@ -217,16 +217,6 @@ def test_extraction_without_on_start(): ) -def test_does_not_use_default_extractor_when_not_a_method(): - extractor_class = ExtractorManager().get_extractor_class(BrokenOperator(task_id="a")) - assert extractor_class is None - - -def test_does_not_use_default_extractor_when_no_get_openlineage_facets(): - extractor_class = ExtractorManager().get_extractor_class(BaseOperator(task_id="b")) - assert extractor_class is None - - def test_does_not_use_default_extractor_when_explicite_extractor(): extractor_class = ExtractorManager().get_extractor_class( PythonOperator(task_id="c", python_callable=lambda: 7)
