This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch load-plugins-from-providers
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1332b894651159268d632f5548f0ed8b2ac76d38
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Tue Jun 27 14:24:24 2023 +0200

    plugins: load plugins from providers
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 airflow/__init__.py                    |  9 ++++-----
 airflow/plugins_manager.py             | 32 +++++++++++++++++++++++++++++-
 airflow/providers_manager.py           | 36 ++++++++++++++++++++++++++++++++++
 tests/always/test_providers_manager.py | 30 +++++++++++++++++++++++++++-
 tests/plugins/test_plugins_manager.py  | 33 +++++++++++++++++++++++++++++++
 5 files changed, 133 insertions(+), 7 deletions(-)

diff --git a/airflow/__init__.py b/airflow/__init__.py
index d330b9332a..2e8e088ebc 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -106,11 +106,6 @@ def __getattr__(name: str):
     return val
 
 
-if not settings.LAZY_LOAD_PLUGINS:
-    from airflow import plugins_manager
-
-    plugins_manager.ensure_plugins_loaded()
-
 if not settings.LAZY_LOAD_PROVIDERS:
     from airflow import providers_manager
 
@@ -118,6 +113,10 @@ if not settings.LAZY_LOAD_PROVIDERS:
     manager.initialize_providers_list()
     manager.initialize_providers_hooks()
     manager.initialize_providers_extra_links()
+if not settings.LAZY_LOAD_PLUGINS:
+    from airflow import plugins_manager
+
+    plugins_manager.ensure_plugins_loaded()
 
 
 # This is never executed, but tricks static analyzers (PyDev, PyCharm,)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 946f064269..17acdd245c 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -38,7 +38,7 @@ from types import ModuleType
 from airflow import settings
 from airflow.utils.entry_points import entry_points_with_dist
 from airflow.utils.file import find_path_from_directory
-from airflow.utils.module_loading import qualname
+from airflow.utils.module_loading import import_string, qualname
 
 if TYPE_CHECKING:
     from airflow.hooks.base import BaseHook
@@ -50,6 +50,7 @@ log = logging.getLogger(__name__)
 import_errors: dict[str, str] = {}
 
 plugins: list[AirflowPlugin] | None = None
+loaded_plugins: set[str] = set()
 
 # Plugin components to integrate as modules
 registered_hooks: list[BaseHook] | None = None
@@ -205,10 +206,16 @@ def is_valid_plugin(plugin_obj):
 def register_plugin(plugin_instance):
     """
     Start plugin load and register it after success initialization.
+    If plugin is already registered, do nothing.
 
     :param plugin_instance: subclass of AirflowPlugin
     """
     global plugins
+
+    if plugin_instance.name in loaded_plugins:
+        return
+
+    loaded_plugins.add(plugin_instance.name)
     plugin_instance.on_load()
     plugins.append(plugin_instance)
 
@@ -267,6 +274,26 @@ def load_plugins_from_plugin_directory():
             import_errors[file_path] = str(e)
 
 
+def load_providers_plugins():
+    from airflow.providers_manager import ProvidersManager
+
+    log.debug("Loading plugins from providers")
+    providers_manager = ProvidersManager()
+    providers_manager.initialize_providers_plugins()
+    for plugin in providers_manager.plugins:
+        log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class)
+
+        try:
+            plugin_instance = import_string(plugin.plugin_class)
+            if not is_valid_plugin(plugin_instance):
+                log.warning("Plugin %s is not a valid plugin", plugin.name)
+                continue
+            register_plugin(plugin_instance)
+        except ImportError:
+            log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class)
+            continue
+
+
 def make_module(name: str, objects: list[Any]):
     """Creates new module."""
     if not objects:
@@ -306,6 +333,9 @@ def ensure_plugins_loaded():
         load_plugins_from_plugin_directory()
         load_entrypoint_plugins()
 
+        if not settings.LAZY_LOAD_PROVIDERS:
+            load_providers_plugins()
+
         # We don't do anything with these for now, but we want to keep track of
         # them so we can integrate them in to the UI's Connection screens
         for plugin in plugins:
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index bda6df3441..b7689a92da 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -217,6 +217,14 @@ class TriggerInfo(NamedTuple):
     integration_name: str
 
 
+class PluginInfo(NamedTuple):
+    """Plugin class, name and provider it comes from."""
+
+    name: str
+    plugin_class: str
+    provider_name: str
+
+
 class HookInfo(NamedTuple):
     """Hook information."""
 
@@ -421,6 +429,8 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self._customized_form_fields_schema_validator = (
             _create_customized_form_field_behaviours_schema_validator()
         )
+        # Set of plugins contained in providers
+        self._plugins_set: set[PluginInfo] = set()
 
     @provider_info_cache("list")
     def initialize_providers_list(self):
@@ -516,6 +526,11 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self.initialize_providers_list()
         self._discover_auth_backends()
 
+    @provider_info_cache("plugins")
+    def initialize_providers_plugins(self):
+        self.initialize_providers_list()
+        self._discover_plugins()
+
     def _discover_all_providers_from_packages(self) -> None:
         """
         Discover all providers by scanning packages installed.
@@ -1024,6 +1039,21 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
             if provider.data.get("config"):
                 self._provider_configs[provider_package] = provider.data.get("config")
 
+    def _discover_plugins(self) -> None:
+        """Retrieve all plugins defined in the providers."""
+        for provider_package, provider in self._provider_dict.items():
+            for plugin_dict in provider.data.get("plugins", ()):
+                if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
+                    log.warning("Plugin not loaded due to above correctness check problem.")
+                    continue
+                self._plugins_set.add(
+                    PluginInfo(
+                        name=plugin_dict["name"],
+                        plugin_class=plugin_dict["plugin-class"],
+                        provider_name=provider_package,
+                    )
+                )
+
     @provider_info_cache("triggers")
     def initialize_providers_triggers(self):
         """Initialization of providers triggers."""
@@ -1062,6 +1092,12 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         # When we return hooks here it will only be used to retrieve hook information
         return self._hooks_lazy_dict
 
+    @property
+    def plugins(self) -> list[PluginInfo]:
+        """Returns information about plugins available in providers."""
+        self.initialize_providers_plugins()
+        return sorted(self._plugins_set, key=lambda x: x.plugin_class)
+
     @property
     def taskflow_decorators(self) -> dict[str, TaskDecorator]:
         self.initialize_providers_taskflow_decorator()
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index b99dbcb84f..7e05d1cfb1 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -28,7 +28,13 @@ from flask_babel import lazy_gettext
 from wtforms import BooleanField, Field, StringField
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
-from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
+from airflow.providers_manager import (
+    HookClassProvider,
+    LazyDictWithCache,
+    PluginInfo,
+    ProviderInfo,
+    ProvidersManager,
+)
 
 
 class TestProviderManager:
@@ -157,6 +163,28 @@ class TestProviderManager:
             " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
         ) in self._caplog.records[0].message
 
+    def test_providers_manager_register_plugins(self):
+        providers_manager = ProvidersManager()
+        providers_manager._provider_dict["apache-airflow-providers-apache-hive"] = ProviderInfo(
+            version="0.0.1",
+            data={
+                "plugins": [
+                    {
+                        "name": "plugin1",
+                        "plugin-class": "airflow.providers.apache.hive.plugins.hive.HivePlugin",
+                    }
+                ]
+            },
+            package_or_source="package",
+        )
+        providers_manager._discover_plugins()
+        assert len(providers_manager._plugins_set) == 1
+        assert providers_manager._plugins_set.pop() == PluginInfo(
+            name="plugin1",
+            plugin_class="airflow.providers.apache.hive.plugins.hive.HivePlugin",
+            provider_name="apache-airflow-providers-apache-hive",
+        )
+
     def test_hooks(self):
         with pytest.warns(expected_warning=None) as warning_records:
             with self._caplog.at_level(logging.WARNING):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 019e2a69a8..bf74863407 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -160,6 +160,13 @@ def test_flaskappbuilder_nomenu_views():
 
 
 class TestPluginsManager:
+    @pytest.fixture(autouse=True, scope="function")
+    def clean_plugins(self):
+        from airflow import plugins_manager
+
+        plugins_manager.loaded_plugins = set()
+        plugins_manager.plugins = []
+
     def test_no_log_when_no_plugins(self, caplog):
 
         with mock_plugin_manager(plugins=[]):
@@ -378,6 +385,32 @@ class TestPluginsManager:
             assert get_listener_manager().has_listeners
             assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"
 
+    def test_should_import_plugin_from_providers(self):
+        from airflow import plugins_manager
+
+        with mock.patch("airflow.plugins_manager.plugins", []):
+            assert len(plugins_manager.plugins) == 0
+            plugins_manager.load_providers_plugins()
+            assert len(plugins_manager.plugins) >= 2
+
+    def test_does_not_double_import_entrypoint_provider_plugins(self):
+        from airflow import plugins_manager
+
+        mock_entrypoint = mock.Mock()
+        mock_entrypoint.name = "test-entrypoint-plugin"
+        mock_entrypoint.module = "module_name_plugin"
+
+        mock_dist = mock.Mock()
+        mock_dist.metadata = {"Name": "test-entrypoint-plugin"}
+        mock_dist.version = "1.0.0"
+        mock_dist.entry_points = [mock_entrypoint]
+
+        with mock.patch("airflow.plugins_manager.plugins", []):
+            assert len(plugins_manager.plugins) == 0
+            plugins_manager.load_entrypoint_plugins()
+            plugins_manager.load_providers_plugins()
+            assert len(plugins_manager.plugins) == 2
+
 
 class TestPluginsDirectorySource:
     def test_should_return_correct_path_name(self):

Reply via email to