This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2919abe5b3 Add more ways to connect to weaviate (#35864)
2919abe5b3 is described below

commit 2919abe5b3f2d186c896aebbc51acf98d554ef33
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Tue Nov 28 20:30:06 2023 +0100

    Add more ways to connect to weaviate (#35864)
    
    * Add more ways to connect to weaviate
    
    There are other options for connecting to weaviate. This commit adds
    these other options and also improves the imports/typing.
    
    * fixup! Add more ways to connect to weaviate
    
    * fixup! fixup! Add more ways to connect to weaviate
    
    * add deprecation
    
    * remove mark as dbtest
---
 airflow/providers/weaviate/hooks/weaviate.py       |  50 ++++---
 .../connections.rst                                |  19 +++
 tests/providers/weaviate/hooks/test_weaviate.py    | 148 ++++++++++++++++++++-
 3 files changed, 198 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/weaviate/hooks/weaviate.py 
b/airflow/providers/weaviate/hooks/weaviate.py
index c8b0ed05d4..151aaabea6 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -17,10 +17,13 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import Any
 
-import weaviate
+from weaviate import Client as WeaviateClient
+from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, 
AuthClientPassword
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
 
@@ -40,19 +43,19 @@ class WeaviateHook(BaseHook):
         super().__init__(*args, **kwargs)
         self.conn_id = conn_id
 
-    @staticmethod
-    def get_connection_form_widgets() -> dict[str, Any]:
+    @classmethod
+    def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
         from flask_babel import lazy_gettext
         from wtforms import PasswordField
 
         return {
-            "token": PasswordField(lazy_gettext("Weaviate API Token"), 
widget=BS3PasswordFieldWidget()),
+            "token": PasswordField(lazy_gettext("Weaviate API Key"), 
widget=BS3PasswordFieldWidget()),
         }
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
             "hidden_fields": ["port", "schema"],
@@ -62,28 +65,43 @@ class WeaviateHook(BaseHook):
             },
         }
 
-    def get_client(self) -> weaviate.Client:
+    def get_conn(self) -> WeaviateClient:
         conn = self.get_connection(self.conn_id)
         url = conn.host
         username = conn.login or ""
         password = conn.password or ""
         extras = conn.extra_dejson
-        token = extras.pop("token", "")
+        access_token = extras.get("access_token", None)
+        refresh_token = extras.get("refresh_token", None)
+        expires_in = extras.get("expires_in", 60)
+        # previously token was used as api_key(backwards compatibility)
+        api_key = extras.get("api_key", None) or extras.get("token", None)
+        client_secret = extras.get("client_secret", None)
         additional_headers = extras.pop("additional_headers", {})
-        scope = conn.extra_dejson.get("oidc_scope", "offline_access")
-
-        if token == "" and username != "":
-            auth_client_secret = weaviate.AuthClientPassword(
-                username=username, password=password, scope=scope
+        scope = extras.get("scope", None) or extras.get("oidc_scope", None)
+        if api_key:
+            auth_client_secret = AuthApiKey(api_key)
+        elif access_token:
+            auth_client_secret = AuthBearerToken(
+                access_token, expires_in=expires_in, 
refresh_token=refresh_token
             )
+        elif client_secret:
+            auth_client_secret = 
AuthClientCredentials(client_secret=client_secret, scope=scope)
         else:
-            auth_client_secret = weaviate.AuthApiKey(token)
+            auth_client_secret = AuthClientPassword(username=username, 
password=password, scope=scope)
 
-        client = weaviate.Client(
+        return WeaviateClient(
             url=url, auth_client_secret=auth_client_secret, 
additional_headers=additional_headers
         )
 
-        return client
+    def get_client(self) -> WeaviateClient:
+        # Keeping this for backwards compatibility
+        warnings.warn(
+            "The `get_client` method has been renamed to `get_conn`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.get_conn()
 
     def test_connection(self) -> tuple[bool, str]:
         try:
diff --git a/docs/apache-airflow-providers-weaviate/connections.rst 
b/docs/apache-airflow-providers-weaviate/connections.rst
index 5e16164ff6..081fe14d92 100644
--- a/docs/apache-airflow-providers-weaviate/connections.rst
+++ b/docs/apache-airflow-providers-weaviate/connections.rst
@@ -42,6 +42,8 @@ OIDC Password (optional)
 Extra (optional)
     Specify the extra parameters (as json dictionary) that can be used in the
     connection. All parameters are optional.
+    The extras are those parameters that are acceptable in the different 
authentication methods
+    here: `Authentication 
<https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html>`__
 
     * If you'd like to use Vectorizers for your class, configure the API keys 
to use the corresponding
       embedding API. The extras accepts a key ``additional_headers`` 
containing the dictionary
@@ -50,3 +52,20 @@ Extra (optional)
 
 Weaviate API Token (optional)
     Specify your Weaviate API Key to connect when API Key option is to be used 
for authentication.
+
+Supported Authentication Methods
+--------------------------------
+* API Key Authentication: This method uses the Weaviate API Key to 
authenticate the connection. You can either have the
+  API key in the ``Weaviate API Token`` field or in the extra field as a 
dictionary with key ``token`` or ``api_key`` and
+  value as the API key.
+
+* Bearer Token Authentication: This method uses the Access Token to 
authenticate the connection. You need to
+  have the Access Token in the extra field as a dictionary with key 
``access_token`` and value as the Access Token. Other
+  parameters such as ``expires_in`` and ``refresh_token`` are optional.
+
+* Client Credentials Authentication: This method uses the Client Credentials 
to authenticate the connection. You need to
+  have the Client Credentials in the extra field as a dictionary with key 
``client_secret`` and value as the Client Credentials.
+  The ``scope`` is optional.
+
+* Password Authentication: This method uses the username and password to 
authenticate the connection. You can specify the
+  scope in the extra field as a dictionary with key ``scope`` and value as the 
scope. The ``scope`` is optional.
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py 
b/tests/providers/weaviate/hooks/test_weaviate.py
index 56f57ebc9b..0274004fc0 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -16,10 +16,12 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 
+from airflow.models import Connection
 from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
 
 TEST_CONN_ID = "test_weaviate_conn"
@@ -38,13 +40,153 @@ def weaviate_hook():
     return hook
 
 
+@pytest.fixture
+def mock_auth_api_key():
+    with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as 
m:
+        yield m
+
+
+@pytest.fixture
+def mock_auth_bearer_token():
+    with 
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m:
+        yield m
+
+
+@pytest.fixture
+def mock_auth_client_credentials():
+    with 
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") 
as m:
+        yield m
+
+
+@pytest.fixture
+def mock_auth_client_password():
+    with 
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m:
+        yield m
+
+
+class TestWeaviateHook:
+    """
+    Test the WeaviateHook Hook.
+    """
+
+    @pytest.fixture(autouse=True)
+    def setup_method(self, monkeypatch):
+        """Set up the test method."""
+        self.weaviate_api_key1 = "weaviate_api_key1"
+        self.weaviate_api_key2 = "weaviate_api_key2"
+        self.api_key = "api_key"
+        self.weaviate_client_credentials = "weaviate_client_credentials"
+        self.client_secret = "client_secret"
+        self.scope = "scope1 scope2"
+        self.client_password = "client_password"
+        self.client_bearer_token = "client_bearer_token"
+        self.host = "http://localhost:8080"
+        conns = (
+            Connection(
+                conn_id=self.weaviate_api_key1,
+                host=self.host,
+                conn_type="weaviate",
+                extra={"api_key": self.api_key},
+            ),
+            Connection(
+                conn_id=self.weaviate_api_key2,
+                host=self.host,
+                conn_type="weaviate",
+                extra={"token": self.api_key},
+            ),
+            Connection(
+                conn_id=self.weaviate_client_credentials,
+                host=self.host,
+                conn_type="weaviate",
+                extra={"client_secret": self.client_secret, "scope": 
self.scope},
+            ),
+            Connection(
+                conn_id=self.client_password,
+                host=self.host,
+                conn_type="weaviate",
+                login="login",
+                password="password",
+            ),
+            Connection(
+                conn_id=self.client_bearer_token,
+                host=self.host,
+                conn_type="weaviate",
+                extra={
+                    "access_token": self.client_bearer_token,
+                    "expires_in": 30,
+                    "refresh_token": "refresh_token",
+                },
+            ),
+        )
+        for conn in conns:
+            monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", 
conn.get_uri())
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+    def test_get_conn_with_api_key_in_extra(self, mock_client, 
mock_auth_api_key):
+        hook = WeaviateHook(conn_id=self.weaviate_api_key1)
+        hook.get_conn()
+        mock_auth_api_key.assert_called_once_with(self.api_key)
+        mock_client.assert_called_once_with(
+            url=self.host, 
auth_client_secret=mock_auth_api_key(api_key=self.api_key), 
additional_headers={}
+        )
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+    def test_get_conn_with_token_in_extra(self, mock_client, 
mock_auth_api_key):
+        # when token is passed in extra
+        hook = WeaviateHook(conn_id=self.weaviate_api_key2)
+        hook.get_conn()
+        mock_auth_api_key.assert_called_once_with(self.api_key)
+        mock_client.assert_called_once_with(
+            url=self.host, 
auth_client_secret=mock_auth_api_key(api_key=self.api_key), 
additional_headers={}
+        )
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+    def test_get_conn_with_access_token_in_extra(self, mock_client, 
mock_auth_bearer_token):
+        hook = WeaviateHook(conn_id=self.client_bearer_token)
+        hook.get_conn()
+        mock_auth_bearer_token.assert_called_once_with(
+            self.client_bearer_token, expires_in=30, 
refresh_token="refresh_token"
+        )
+        mock_client.assert_called_once_with(
+            url=self.host,
+            auth_client_secret=mock_auth_bearer_token(
+                access_token=self.client_bearer_token, expires_in=30, 
refresh_token="refresh_token"
+            ),
+            additional_headers={},
+        )
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+    def test_get_conn_with_client_secret_in_extra(self, mock_client, 
mock_auth_client_credentials):
+        hook = WeaviateHook(conn_id=self.weaviate_client_credentials)
+        hook.get_conn()
+        mock_auth_client_credentials.assert_called_once_with(
+            client_secret=self.client_secret, scope=self.scope
+        )
+        mock_client.assert_called_once_with(
+            url=self.host,
+            
auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, 
scope=self.scope),
+            additional_headers={},
+        )
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+    def test_get_conn_with_client_password_in_extra(self, mock_client, 
mock_auth_client_password):
+        hook = WeaviateHook(conn_id=self.client_password)
+        hook.get_conn()
+        mock_auth_client_password.assert_called_once_with(username="login", 
password="password", scope=None)
+        mock_client.assert_called_once_with(
+            url=self.host,
+            auth_client_secret=mock_auth_client_password(username="login", 
password="password", scope=None),
+            additional_headers={},
+        )
+
+
 def test_create_class(weaviate_hook):
     """
     Test the create_class method of WeaviateHook.
     """
     # Mock the Weaviate Client
     mock_client = MagicMock()
-    weaviate_hook.get_client = MagicMock(return_value=mock_client)
+    weaviate_hook.get_conn = MagicMock(return_value=mock_client)
 
     # Define test class JSON
     test_class_json = {
@@ -65,7 +207,7 @@ def test_create_schema(weaviate_hook):
     """
     # Mock the Weaviate Client
     mock_client = MagicMock()
-    weaviate_hook.get_client = MagicMock(return_value=mock_client)
+    weaviate_hook.get_conn = MagicMock(return_value=mock_client)
 
     # Define test schema JSON
     test_schema_json = {
@@ -90,7 +232,7 @@ def test_batch_data(weaviate_hook):
     """
     # Mock the Weaviate Client
     mock_client = MagicMock()
-    weaviate_hook.get_client = MagicMock(return_value=mock_client)
+    weaviate_hook.get_conn = MagicMock(return_value=mock_client)
 
     # Define test data
     test_class_name = "TestClass"

Reply via email to