This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch load-plugins-from-providers in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 1332b894651159268d632f5548f0ed8b2ac76d38
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Tue Jun 27 14:24:24 2023 +0200

    plugins: load plugins from providers
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 airflow/__init__.py                    |  9 ++++-----
 airflow/plugins_manager.py             | 32 +++++++++++++++++++++++++++++-
 airflow/providers_manager.py           | 36 ++++++++++++++++++++++++++++++++++
 tests/always/test_providers_manager.py | 30 +++++++++++++++++++++++++++-
 tests/plugins/test_plugins_manager.py  | 33 +++++++++++++++++++++++++++++++
 5 files changed, 133 insertions(+), 7 deletions(-)

diff --git a/airflow/__init__.py b/airflow/__init__.py
index d330b9332a..2e8e088ebc 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -106,11 +106,6 @@ def __getattr__(name: str):
     return val
 
 
-if not settings.LAZY_LOAD_PLUGINS:
-    from airflow import plugins_manager
-
-    plugins_manager.ensure_plugins_loaded()
-
 if not settings.LAZY_LOAD_PROVIDERS:
     from airflow import providers_manager
 
@@ -118,6 +113,10 @@ if not settings.LAZY_LOAD_PROVIDERS:
     manager.initialize_providers_list()
     manager.initialize_providers_hooks()
     manager.initialize_providers_extra_links()
+if not settings.LAZY_LOAD_PLUGINS:
+    from airflow import plugins_manager
+
+    plugins_manager.ensure_plugins_loaded()
 
 
 # This is never executed, but tricks static analyzers (PyDev, PyCharm,)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 946f064269..17acdd245c 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -38,7 +38,7 @@ from types import ModuleType
 from airflow import settings
 from airflow.utils.entry_points import entry_points_with_dist
 from airflow.utils.file import find_path_from_directory
-from airflow.utils.module_loading import qualname
+from airflow.utils.module_loading import import_string, qualname
 
 if TYPE_CHECKING:
     from airflow.hooks.base import BaseHook
@@ -50,6 +50,7 @@ log = logging.getLogger(__name__)
 import_errors: dict[str, str] = {}
 
 plugins: list[AirflowPlugin] | None = None
+loaded_plugins: set[str] = set()
 
 # Plugin components to integrate as modules
 registered_hooks: list[BaseHook] | None = None
@@ -205,10 +206,16 @@ def is_valid_plugin(plugin_obj):
 def register_plugin(plugin_instance):
     """
     Start plugin load and register it after success initialization.
+    If plugin is already registered, do nothing.
 
     :param plugin_instance: subclass of AirflowPlugin
     """
     global plugins
+
+    if plugin_instance.name in loaded_plugins:
+        return
+
+    loaded_plugins.add(plugin_instance.name)
     plugin_instance.on_load()
     plugins.append(plugin_instance)
 
@@ -267,6 +274,26 @@ def load_plugins_from_plugin_directory():
             import_errors[file_path] = str(e)
 
 
+def load_providers_plugins():
+    from airflow.providers_manager import ProvidersManager
+
+    log.debug("Loading plugins from providers")
+    providers_manager = ProvidersManager()
+    providers_manager.initialize_providers_plugins()
+    for plugin in providers_manager.plugins:
+        log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class)
+
+        try:
+            plugin_instance = import_string(plugin.plugin_class)
+            if not is_valid_plugin(plugin_instance):
+                log.warning("Plugin %s is not a valid plugin", plugin.name)
+                continue
+            register_plugin(plugin_instance)
+        except ImportError:
+            log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class)
+            continue
+
+
 def make_module(name: str, objects: list[Any]):
     """Creates new module."""
     if not objects:
@@ -306,6 +333,9 @@ def ensure_plugins_loaded():
         load_plugins_from_plugin_directory()
         load_entrypoint_plugins()
 
+        if not settings.LAZY_LOAD_PROVIDERS:
+            load_providers_plugins()
+
         # We don't do anything with these for now, but we want to keep track of
         # them so we can integrate them in to the UI's Connection screens
         for plugin in plugins:
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index bda6df3441..b7689a92da 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -217,6 +217,14 @@ class TriggerInfo(NamedTuple):
     integration_name: str
 
 
+class PluginInfo(NamedTuple):
+    """Plugin class, name and provider it comes from."""
+
+    name: str
+    plugin_class: str
+    provider_name: str
+
+
 class HookInfo(NamedTuple):
     """Hook information."""
 
@@ -421,6 +429,8 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self._customized_form_fields_schema_validator = (
             _create_customized_form_field_behaviours_schema_validator()
         )
+        # Set of plugins contained in providers
+        self._plugins_set: set[PluginInfo] = set()
 
     @provider_info_cache("list")
     def initialize_providers_list(self):
@@ -516,6 +526,11 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self.initialize_providers_list()
         self._discover_auth_backends()
 
+    @provider_info_cache("plugins")
+    def initialize_providers_plugins(self):
+        self.initialize_providers_list()
+        self._discover_plugins()
+
     def _discover_all_providers_from_packages(self) -> None:
         """
         Discover all providers by scanning packages installed.
@@ -1024,6 +1039,21 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
             if provider.data.get("config"):
                 self._provider_configs[provider_package] = provider.data.get("config")
 
+    def _discover_plugins(self) -> None:
+        """Retrieve all plugins defined in the providers."""
+        for provider_package, provider in self._provider_dict.items():
+            for plugin_dict in provider.data.get("plugins", ()):
+                if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
+                    log.warning("Plugin not loaded due to above correctness check problem.")
+                    continue
+                self._plugins_set.add(
+                    PluginInfo(
+                        name=plugin_dict["name"],
+                        plugin_class=plugin_dict["plugin-class"],
+                        provider_name=provider_package,
+                    )
+                )
+
     @provider_info_cache("triggers")
     def initialize_providers_triggers(self):
         """Initialization of providers triggers."""
@@ -1062,6 +1092,12 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         # When we return hooks here it will only be used to retrieve hook information
         return self._hooks_lazy_dict
 
+    @property
+    def plugins(self) -> list[PluginInfo]:
+        """Returns information about plugins available in providers."""
+        self.initialize_providers_plugins()
+        return sorted(self._plugins_set, key=lambda x: x.plugin_class)
+
     @property
     def taskflow_decorators(self) -> dict[str, TaskDecorator]:
         self.initialize_providers_taskflow_decorator()
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index b99dbcb84f..7e05d1cfb1 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -28,7 +28,13 @@ from flask_babel import lazy_gettext
 from wtforms import BooleanField, Field, StringField
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
-from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
+from airflow.providers_manager import (
+    HookClassProvider,
+    LazyDictWithCache,
+    PluginInfo,
+    ProviderInfo,
+    ProvidersManager,
+)
 
 
 class TestProviderManager:
@@ -157,6 +163,28 @@ class TestProviderManager:
             " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
         ) in self._caplog.records[0].message
 
+    def test_providers_manager_register_plugins(self):
+        providers_manager = ProvidersManager()
+        providers_manager._provider_dict["apache-airflow-providers-apache-hive"] = ProviderInfo(
+            version="0.0.1",
+            data={
+                "plugins": [
+                    {
+                        "name": "plugin1",
+                        "plugin-class": "airflow.providers.apache.hive.plugins.hive.HivePlugin",
+                    }
+                ]
+            },
+            package_or_source="package",
+        )
+        providers_manager._discover_plugins()
+        assert len(providers_manager._plugins_set) == 1
+        assert providers_manager._plugins_set.pop() == PluginInfo(
+            name="plugin1",
+            plugin_class="airflow.providers.apache.hive.plugins.hive.HivePlugin",
+            provider_name="apache-airflow-providers-apache-hive",
+        )
+
     def test_hooks(self):
         with pytest.warns(expected_warning=None) as warning_records:
             with self._caplog.at_level(logging.WARNING):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 019e2a69a8..bf74863407 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -160,6 +160,13 @@ def test_flaskappbuilder_nomenu_views():
 
 
 class TestPluginsManager:
+    @pytest.fixture(autouse=True, scope="function")
+    def clean_plugins(self):
+        from airflow import plugins_manager
+
+        plugins_manager.loaded_plugins = set()
+        plugins_manager.plugins = []
+
     def test_no_log_when_no_plugins(self, caplog):
 
         with mock_plugin_manager(plugins=[]):
@@ -378,6 +385,32 @@ class TestPluginsManager:
             assert get_listener_manager().has_listeners
             assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"
 
+    def test_should_import_plugin_from_providers(self):
+        from airflow import plugins_manager
+
+        with mock.patch("airflow.plugins_manager.plugins", []):
+            assert len(plugins_manager.plugins) == 0
+            plugins_manager.load_providers_plugins()
+            assert len(plugins_manager.plugins) >= 2
+
+    def test_does_not_double_import_entrypoint_provider_plugins(self):
+        from airflow import plugins_manager
+
+        mock_entrypoint = mock.Mock()
+        mock_entrypoint.name = "test-entrypoint-plugin"
+        mock_entrypoint.module = "module_name_plugin"
+
+        mock_dist = mock.Mock()
+        mock_dist.metadata = {"Name": "test-entrypoint-plugin"}
+        mock_dist.version = "1.0.0"
+        mock_dist.entry_points = [mock_entrypoint]
+
+        with mock.patch("airflow.plugins_manager.plugins", []):
+            assert len(plugins_manager.plugins) == 0
+            plugins_manager.load_entrypoint_plugins()
+            plugins_manager.load_providers_plugins()
+            assert len(plugins_manager.plugins) == 2
+
 
 class TestPluginsDirectorySource:
     def test_should_return_correct_path_name(self):