This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2919abe5b3 Add more ways to connect to weaviate (#35864) 2919abe5b3 is described below commit 2919abe5b3f2d186c896aebbc51acf98d554ef33 Author: Ephraim Anierobi <splendidzig...@gmail.com> AuthorDate: Tue Nov 28 20:30:06 2023 +0100 Add more ways to connect to weaviate (#35864) * Add more ways to connect to weaviate There are other options for connecting to weaviate. This commit adds these other options and also improves the imports/typing * fixup! Add more ways to connect to weaviate * fixup! fixup! Add more ways to connect to weaviate * add deprecation * remove mark as dbtest --- airflow/providers/weaviate/hooks/weaviate.py | 50 ++++--- .../connections.rst | 19 +++ tests/providers/weaviate/hooks/test_weaviate.py | 148 ++++++++++++++++++++- 3 files changed, 198 insertions(+), 19 deletions(-) diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py index c8b0ed05d4..151aaabea6 100644 --- a/airflow/providers/weaviate/hooks/weaviate.py +++ b/airflow/providers/weaviate/hooks/weaviate.py @@ -17,10 +17,13 @@ from __future__ import annotations +import warnings from typing import Any -import weaviate +from weaviate import Client as WeaviateClient +from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook @@ -40,19 +43,19 @@ class WeaviateHook(BaseHook): super().__init__(*args, **kwargs) self.conn_id = conn_id - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: """Returns connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField return { - "token": 
PasswordField(lazy_gettext("Weaviate API Token"), 
widget=BS3PasswordFieldWidget()), + "token": PasswordField(lazy_gettext("Weaviate API Key"), widget=BS3PasswordFieldWidget()), } - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["port", "schema"], @@ -62,28 +65,43 @@ class WeaviateHook(BaseHook): }, } - def get_client(self) -> weaviate.Client: + def get_conn(self) -> WeaviateClient: conn = self.get_connection(self.conn_id) url = conn.host username = conn.login or "" password = conn.password or "" extras = conn.extra_dejson - token = extras.pop("token", "") + access_token = extras.get("access_token", None) + refresh_token = extras.get("refresh_token", None) + expires_in = extras.get("expires_in", 60) + # previously token was used as api_key(backwards compatibility) + api_key = extras.get("api_key", None) or extras.get("token", None) + client_secret = extras.get("client_secret", None) additional_headers = extras.pop("additional_headers", {}) - scope = conn.extra_dejson.get("oidc_scope", "offline_access") - - if token == "" and username != "": - auth_client_secret = weaviate.AuthClientPassword( - username=username, password=password, scope=scope + scope = extras.get("scope", None) or extras.get("oidc_scope", None) + if api_key: + auth_client_secret = AuthApiKey(api_key) + elif access_token: + auth_client_secret = AuthBearerToken( + access_token, expires_in=expires_in, refresh_token=refresh_token ) + elif client_secret: + auth_client_secret = AuthClientCredentials(client_secret=client_secret, scope=scope) else: - auth_client_secret = weaviate.AuthApiKey(token) + auth_client_secret = AuthClientPassword(username=username, password=password, scope=scope) - client = weaviate.Client( + return WeaviateClient( url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers ) - return client + def get_client(self) -> WeaviateClient: + # Keeping this for 
backwards compatibility + warnings.warn( + "The `get_client` method has been renamed to `get_conn`", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + return self.get_conn() def test_connection(self) -> tuple[bool, str]: try: diff --git a/docs/apache-airflow-providers-weaviate/connections.rst b/docs/apache-airflow-providers-weaviate/connections.rst index 5e16164ff6..081fe14d92 100644 --- a/docs/apache-airflow-providers-weaviate/connections.rst +++ b/docs/apache-airflow-providers-weaviate/connections.rst @@ -42,6 +42,8 @@ OIDC Password (optional) Extra (optional) Specify the extra parameters (as json dictionary) that can be used in the connection. All parameters are optional. + The extras are those parameters that are acceptable in the different authentication methods + here: `Authentication <https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html>`__ * If you'd like to use Vectorizers for your class, configure the API keys to use the corresponding embedding API. The extras accepts a key ``additional_headers`` containing the dictionary @@ -50,3 +52,20 @@ Extra (optional) Weaviate API Token (optional) Specify your Weaviate API Key to connect when API Key option is to be used for authentication. + +Supported Authentication Methods +-------------------------------- +* API Key Authentication: This method uses the Weaviate API Key to authenticate the connection. You can either have the + API key in the ``Weaviate API Token`` field or in the extra field as a dictionary with key ``token`` or ``api_key`` and + value as the API key. + +* Bearer Token Authentication: This method uses the Access Token to authenticate the connection. You need to + have the Access Token in the extra field as a dictionary with key ``access_token`` and value as the Access Token. Other + parameters such as ``expires_in`` and ``refresh_token`` are optional. + +* Client Credentials Authentication: This method uses the Client Credentials to authenticate the connection. 
You need to + have the Client Credentials in the extra field as a dictionary with key ``client_secret`` and value as the Client Credentials. + The ``scope`` is optional. + +* Password Authentication: This method uses the username and password to authenticate the connection. You can specify the + scope in the extra field as a dictionary with key ``scope`` and value as the scope. The ``scope`` is optional. diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py index 56f57ebc9b..0274004fc0 100644 --- a/tests/providers/weaviate/hooks/test_weaviate.py +++ b/tests/providers/weaviate/hooks/test_weaviate.py @@ -16,10 +16,12 @@ # under the License. from __future__ import annotations +from unittest import mock from unittest.mock import MagicMock, Mock, patch import pytest +from airflow.models import Connection from airflow.providers.weaviate.hooks.weaviate import WeaviateHook TEST_CONN_ID = "test_weaviate_conn" @@ -38,13 +40,153 @@ def weaviate_hook(): return hook +@pytest.fixture +def mock_auth_api_key(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as m: + yield m + + +@pytest.fixture +def mock_auth_bearer_token(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m: + yield m + + +@pytest.fixture +def mock_auth_client_credentials(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials") as m: + yield m + + +@pytest.fixture +def mock_auth_client_password(): + with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m: + yield m + + +class TestWeaviateHook: + """ + Test the WeaviateHook Hook. 
+ """ + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch): + """Set up the test method.""" + self.weaviate_api_key1 = "weaviate_api_key1" + self.weaviate_api_key2 = "weaviate_api_key2" + self.api_key = "api_key" + self.weaviate_client_credentials = "weaviate_client_credentials" + self.client_secret = "client_secret" + self.scope = "scope1 scope2" + self.client_password = "client_password" + self.client_bearer_token = "client_bearer_token" + self.host = "http://localhost:8080" + conns = ( + Connection( + conn_id=self.weaviate_api_key1, + host=self.host, + conn_type="weaviate", + extra={"api_key": self.api_key}, + ), + Connection( + conn_id=self.weaviate_api_key2, + host=self.host, + conn_type="weaviate", + extra={"token": self.api_key}, + ), + Connection( + conn_id=self.weaviate_client_credentials, + host=self.host, + conn_type="weaviate", + extra={"client_secret": self.client_secret, "scope": self.scope}, + ), + Connection( + conn_id=self.client_password, + host=self.host, + conn_type="weaviate", + login="login", + password="password", + ), + Connection( + conn_id=self.client_bearer_token, + host=self.host, + conn_type="weaviate", + extra={ + "access_token": self.client_bearer_token, + "expires_in": 30, + "refresh_token": "refresh_token", + }, + ), + ) + for conn in conns: + monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_conn_with_api_key_in_extra(self, mock_client, mock_auth_api_key): + hook = WeaviateHook(conn_id=self.weaviate_api_key1) + hook.get_conn() + mock_auth_api_key.assert_called_once_with(self.api_key) + mock_client.assert_called_once_with( + url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_conn_with_token_in_extra(self, mock_client, mock_auth_api_key): + # when token is passed 
in extra + hook = WeaviateHook(conn_id=self.weaviate_api_key2) + hook.get_conn() + mock_auth_api_key.assert_called_once_with(self.api_key) + mock_client.assert_called_once_with( + url=self.host, auth_client_secret=mock_auth_api_key(api_key=self.api_key), additional_headers={} + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_conn_with_access_token_in_extra(self, mock_client, mock_auth_bearer_token): + hook = WeaviateHook(conn_id=self.client_bearer_token) + hook.get_conn() + mock_auth_bearer_token.assert_called_once_with( + self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + ) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_bearer_token( + access_token=self.client_bearer_token, expires_in=30, refresh_token="refresh_token" + ), + additional_headers={}, + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_conn_with_client_secret_in_extra(self, mock_client, mock_auth_client_credentials): + hook = WeaviateHook(conn_id=self.weaviate_client_credentials) + hook.get_conn() + mock_auth_client_credentials.assert_called_once_with( + client_secret=self.client_secret, scope=self.scope + ) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret, scope=self.scope), + additional_headers={}, + ) + + @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient") + def test_get_conn_with_client_password_in_extra(self, mock_client, mock_auth_client_password): + hook = WeaviateHook(conn_id=self.client_password) + hook.get_conn() + mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None) + mock_client.assert_called_once_with( + url=self.host, + auth_client_secret=mock_auth_client_password(username="login", password="password", scope=None), + additional_headers={}, + ) + + def test_create_class(weaviate_hook): """ 
Test the create_class method of WeaviateHook. """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test class JSON test_class_json = { @@ -65,7 +207,7 @@ def test_create_schema(weaviate_hook): """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test schema JSON test_schema_json = { @@ -90,7 +232,7 @@ def test_batch_data(weaviate_hook): """ # Mock the Weaviate Client mock_client = MagicMock() - weaviate_hook.get_client = MagicMock(return_value=mock_client) + weaviate_hook.get_conn = MagicMock(return_value=mock_client) # Define test data test_class_name = "TestClass"