This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a23bf4c06a plugins: load plugins from providers (#32692)
a23bf4c06a is described below
commit a23bf4c06a12b4c91f6f0501a8a624c612fd4455
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 25 11:17:37 2023 +0200
plugins: load plugins from providers (#32692)
Signed-off-by: Maciej Obuchowski <[email protected]>
---
airflow/__init__.py | 9 ++++-----
airflow/plugins_manager.py | 32 +++++++++++++++++++++++++++++-
airflow/providers_manager.py | 36 ++++++++++++++++++++++++++++++++++
tests/always/test_providers_manager.py | 30 +++++++++++++++++++++++++++-
tests/plugins/test_plugins_manager.py | 33 +++++++++++++++++++++++++++++++
5 files changed, 133 insertions(+), 7 deletions(-)
diff --git a/airflow/__init__.py b/airflow/__init__.py
index d330b9332a..2e8e088ebc 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -106,11 +106,6 @@ def __getattr__(name: str):
return val
-if not settings.LAZY_LOAD_PLUGINS:
- from airflow import plugins_manager
-
- plugins_manager.ensure_plugins_loaded()
-
if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager
@@ -118,6 +113,10 @@ if not settings.LAZY_LOAD_PROVIDERS:
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
+if not settings.LAZY_LOAD_PLUGINS:
+ from airflow import plugins_manager
+
+ plugins_manager.ensure_plugins_loaded()
# This is never executed, but tricks static analyzers (PyDev, PyCharm,)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 946f064269..17acdd245c 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -38,7 +38,7 @@ from types import ModuleType
from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
-from airflow.utils.module_loading import qualname
+from airflow.utils.module_loading import import_string, qualname
if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
@@ -50,6 +50,7 @@ log = logging.getLogger(__name__)
import_errors: dict[str, str] = {}
plugins: list[AirflowPlugin] | None = None
+loaded_plugins: set[str] = set()
# Plugin components to integrate as modules
registered_hooks: list[BaseHook] | None = None
@@ -205,10 +206,16 @@ def is_valid_plugin(plugin_obj):
def register_plugin(plugin_instance):
"""
Start plugin load and register it after success initialization.
+ If plugin is already registered, do nothing.
:param plugin_instance: subclass of AirflowPlugin
"""
global plugins
+
+ if plugin_instance.name in loaded_plugins:
+ return
+
+ loaded_plugins.add(plugin_instance.name)
plugin_instance.on_load()
plugins.append(plugin_instance)
@@ -267,6 +274,26 @@ def load_plugins_from_plugin_directory():
import_errors[file_path] = str(e)
+def load_providers_plugins():
+ from airflow.providers_manager import ProvidersManager
+
+ log.debug("Loading plugins from providers")
+ providers_manager = ProvidersManager()
+ providers_manager.initialize_providers_plugins()
+ for plugin in providers_manager.plugins:
+ log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class)
+
+ try:
+ plugin_instance = import_string(plugin.plugin_class)
+ if not is_valid_plugin(plugin_instance):
+ log.warning("Plugin %s is not a valid plugin", plugin.name)
+ continue
+ register_plugin(plugin_instance)
+ except ImportError:
+ log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class)
+ continue
+
+
def make_module(name: str, objects: list[Any]):
"""Creates new module."""
if not objects:
@@ -306,6 +333,9 @@ def ensure_plugins_loaded():
load_plugins_from_plugin_directory()
load_entrypoint_plugins()
+ if not settings.LAZY_LOAD_PROVIDERS:
+ load_providers_plugins()
+
# We don't do anything with these for now, but we want to keep track of
# them so we can integrate them in to the UI's Connection screens
for plugin in plugins:
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index bda6df3441..b7689a92da 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -217,6 +217,14 @@ class TriggerInfo(NamedTuple):
integration_name: str
+class PluginInfo(NamedTuple):
+ """Plugin class, name and provider it comes from."""
+
+ name: str
+ plugin_class: str
+ provider_name: str
+
+
class HookInfo(NamedTuple):
"""Hook information."""
@@ -421,6 +429,8 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._customized_form_fields_schema_validator = (
_create_customized_form_field_behaviours_schema_validator()
)
+ # Set of plugins contained in providers
+ self._plugins_set: set[PluginInfo] = set()
@provider_info_cache("list")
def initialize_providers_list(self):
@@ -516,6 +526,11 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self.initialize_providers_list()
self._discover_auth_backends()
+ @provider_info_cache("plugins")
+ def initialize_providers_plugins(self):
+ self.initialize_providers_list()
+ self._discover_plugins()
+
def _discover_all_providers_from_packages(self) -> None:
"""
Discover all providers by scanning packages installed.
@@ -1024,6 +1039,21 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
if provider.data.get("config"):
self._provider_configs[provider_package] = provider.data.get("config")
+ def _discover_plugins(self) -> None:
+ """Retrieve all plugins defined in the providers."""
+ for provider_package, provider in self._provider_dict.items():
+ for plugin_dict in provider.data.get("plugins", ()):
+ if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
+ log.warning("Plugin not loaded due to above correctness check problem.")
+ continue
+ self._plugins_set.add(
+ PluginInfo(
+ name=plugin_dict["name"],
+ plugin_class=plugin_dict["plugin-class"],
+ provider_name=provider_package,
+ )
+ )
+
@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
@@ -1062,6 +1092,12 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
# When we return hooks here it will only be used to retrieve hook information
return self._hooks_lazy_dict
+ @property
+ def plugins(self) -> list[PluginInfo]:
+ """Returns information about plugins available in providers."""
+ self.initialize_providers_plugins()
+ return sorted(self._plugins_set, key=lambda x: x.plugin_class)
+
@property
def taskflow_decorators(self) -> dict[str, TaskDecorator]:
self.initialize_providers_taskflow_decorator()
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index b99dbcb84f..7e05d1cfb1 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -28,7 +28,13 @@ from flask_babel import lazy_gettext
from wtforms import BooleanField, Field, StringField
from airflow.exceptions import AirflowOptionalProviderFeatureException
-from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
+from airflow.providers_manager import (
+ HookClassProvider,
+ LazyDictWithCache,
+ PluginInfo,
+ ProviderInfo,
+ ProvidersManager,
+)
class TestProviderManager:
@@ -157,6 +163,28 @@ class TestProviderManager:
" and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
) in self._caplog.records[0].message
+ def test_providers_manager_register_plugins(self):
+ providers_manager = ProvidersManager()
+ providers_manager._provider_dict["apache-airflow-providers-apache-hive"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "plugins": [
+ {
+ "name": "plugin1",
+ "plugin-class": "airflow.providers.apache.hive.plugins.hive.HivePlugin",
+ }
+ ]
+ },
+ package_or_source="package",
+ )
+ providers_manager._discover_plugins()
+ assert len(providers_manager._plugins_set) == 1
+ assert providers_manager._plugins_set.pop() == PluginInfo(
+ name="plugin1",
+ plugin_class="airflow.providers.apache.hive.plugins.hive.HivePlugin",
+ provider_name="apache-airflow-providers-apache-hive",
+ )
+
def test_hooks(self):
with pytest.warns(expected_warning=None) as warning_records:
with self._caplog.at_level(logging.WARNING):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 019e2a69a8..bf74863407 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -160,6 +160,13 @@ def test_flaskappbuilder_nomenu_views():
class TestPluginsManager:
+ @pytest.fixture(autouse=True, scope="function")
+ def clean_plugins(self):
+ from airflow import plugins_manager
+
+ plugins_manager.loaded_plugins = set()
+ plugins_manager.plugins = []
+
def test_no_log_when_no_plugins(self, caplog):
with mock_plugin_manager(plugins=[]):
@@ -378,6 +385,32 @@ class TestPluginsManager:
assert get_listener_manager().has_listeners
assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"
+ def test_should_import_plugin_from_providers(self):
+ from airflow import plugins_manager
+
+ with mock.patch("airflow.plugins_manager.plugins", []):
+ assert len(plugins_manager.plugins) == 0
+ plugins_manager.load_providers_plugins()
+ assert len(plugins_manager.plugins) >= 2
+
+ def test_does_not_double_import_entrypoint_provider_plugins(self):
+ from airflow import plugins_manager
+
+ mock_entrypoint = mock.Mock()
+ mock_entrypoint.name = "test-entrypoint-plugin"
+ mock_entrypoint.module = "module_name_plugin"
+
+ mock_dist = mock.Mock()
+ mock_dist.metadata = {"Name": "test-entrypoint-plugin"}
+ mock_dist.version = "1.0.0"
+ mock_dist.entry_points = [mock_entrypoint]
+
+ with mock.patch("airflow.plugins_manager.plugins", []):
+ assert len(plugins_manager.plugins) == 0
+ plugins_manager.load_entrypoint_plugins()
+ plugins_manager.load_providers_plugins()
+ assert len(plugins_manager.plugins) == 2
+
class TestPluginsDirectorySource:
def test_should_return_correct_path_name(self):