This is an automated email from the ASF dual-hosted git repository. jshao pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push: new 44271cbb3 [#4173] improvement(client-python): Add OAuth Error Handler and related exceptions, test cases in client-python (#4324) 44271cbb3 is described below commit 44271cbb3bdaea4179761a5aa1dcb620196a8166 Author: noidname01 <55401762+noidnam...@users.noreply.github.com> AuthorDate: Thu Aug 1 14:43:41 2024 +0800 [#4173] improvement(client-python): Add OAuth Error Handler and related exceptions, test cases in client-python (#4324) ### What changes were proposed in this pull request? * Add an OAuth error handler, related exceptions, and unit tests to `client-python`, based on `client-java` ### Why are the changes needed? Fix: #4173 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit tests; tested with `./gradlew client:client-python:test` --------- Co-authored-by: TimWang <tim.w...@pranaq.com> --- .../auth/default_oauth2_token_provider.py | 5 +- .../dto/responses/oauth2_error_response.py | 40 +++++++++++++++ clients/client-python/gravitino/exceptions/base.py | 8 +++ .../exceptions/handlers/oauth_error_handler.py | 58 ++++++++++++++++++++++ .../client-python/gravitino/utils/http_client.py | 13 ++++- .../tests/unittests/auth/mock_base.py | 26 ++++++++++ .../unittests/auth/test_oauth2_token_provider.py | 30 ++++++++++- 7 files changed, 175 insertions(+), 5 deletions(-) diff --git a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py index 3fb730395..beefc90c4 100644 --- a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py +++ b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py @@ -27,6 +27,7 @@ from gravitino.dto.requests.oauth2_client_credential_request import ( OAuth2ClientCredentialRequest, ) from gravitino.exceptions.base import GravitinoRuntimeException +from gravitino.exceptions.handlers.oauth_error_handler import OAUTH_ERROR_HANDLER 
CLIENT_CREDENTIALS = "client_credentials" CREDENTIAL_SPLITTER = ":" @@ -107,7 +108,9 @@ class DefaultOAuth2TokenProvider(OAuth2TokenProvider): ) resp = self._client.post_form( - self._path, data=client_credential_request.to_dict() + self._path, + data=client_credential_request, + error_handler=OAUTH_ERROR_HANDLER, ) oauth2_resp = OAuth2TokenResponse.from_json(resp.body, infer_missing=True) oauth2_resp.validate() diff --git a/clients/client-python/gravitino/dto/responses/oauth2_error_response.py b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py new file mode 100644 index 000000000..f7e472c13 --- /dev/null +++ b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py @@ -0,0 +1,40 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+""" + +from dataclasses import dataclass, field +from dataclasses_json import config + +from gravitino.dto.responses.error_response import ErrorResponse + + +@dataclass +class OAuth2ErrorResponse(ErrorResponse): + """Represents the response of an OAuth2 error.""" + + _type: str = field(metadata=config(field_name="error")) + _message: str = field(metadata=config(field_name="error_description")) + + def type(self): + return self._type + + def message(self): + return self._message + + def validate(self): + assert self._type is not None, "OAuthErrorResponse should contain type" diff --git a/clients/client-python/gravitino/exceptions/base.py b/clients/client-python/gravitino/exceptions/base.py index 418304d7d..7700e151a 100644 --- a/clients/client-python/gravitino/exceptions/base.py +++ b/clients/client-python/gravitino/exceptions/base.py @@ -85,3 +85,11 @@ class UnsupportedOperationException(GravitinoRuntimeException): class UnknownError(RuntimeError): """An exception thrown when other unknown exception is thrown""" + + +class UnauthorizedException(GravitinoRuntimeException): + """An exception thrown when a user is not authorized to perform an action.""" + + +class BadRequestException(GravitinoRuntimeException): + """An exception thrown when the request is invalid.""" diff --git a/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py new file mode 100644 index 000000000..ede4d58ae --- /dev/null +++ b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py @@ -0,0 +1,58 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +""" + +from gravitino.exceptions.base import UnauthorizedException, BadRequestException +from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse +from gravitino.exceptions.handlers.rest_error_handler import RestErrorHandler + +INVALID_CLIENT_ERROR = "invalid_client" +INVALID_REQUEST_ERROR = "invalid_request" +INVALID_GRANT_ERROR = "invalid_grant" +UNAUTHORIZED_CLIENT_ERROR = "unauthorized_client" +UNSUPPORTED_GRANT_TYPE_ERROR = "unsupported_grant_type" +INVALID_SCOPE_ERROR = "invalid_scope" + + +class OAuthErrorHandler(RestErrorHandler): + + def handle(self, error_response: OAuth2ErrorResponse): + + error_message = error_response.message() + exception_type = error_response.type() + + if exception_type == INVALID_CLIENT_ERROR: + raise UnauthorizedException( + f"Not authorized: {exception_type}: {error_message}" + ) + + if exception_type in [ + INVALID_REQUEST_ERROR, + INVALID_GRANT_ERROR, + UNAUTHORIZED_CLIENT_ERROR, + UNSUPPORTED_GRANT_TYPE_ERROR, + INVALID_SCOPE_ERROR, + ]: + raise BadRequestException( + f"Malformed request: {exception_type}: {error_message}" + ) + + super().handle(error_response) + + +OAUTH_ERROR_HANDLER = OAuthErrorHandler() diff --git a/clients/client-python/gravitino/utils/http_client.py b/clients/client-python/gravitino/utils/http_client.py index 89b75d641..678942bb4 100644 --- a/clients/client-python/gravitino/utils/http_client.py +++ b/clients/client-python/gravitino/utils/http_client.py @@ -37,6 +37,7 @@ from gravitino.typing import JSONType from gravitino.constants.timeout import TIMEOUT from 
gravitino.dto.responses.error_response import ErrorResponse +from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse from gravitino.exceptions.base import RESTException, UnknownError from gravitino.exceptions.handlers.error_handler import ErrorHandler @@ -145,11 +146,19 @@ class HTTPClient: ErrorResponse.generate_error_response(RESTException, err.reason), ) - err_resp = ErrorResponse.from_json(err_body, infer_missing=True) + err_resp = self._parse_error_response(err_body) err_resp.validate() return (False, err_resp) + def _parse_error_response(self, err_body: bytes) -> ErrorResponse: + json_err_body = _json.loads(err_body) + + if "code" in json_err_body: + return ErrorResponse.from_json(err_body, infer_missing=True) + + return OAuth2ErrorResponse.from_json(err_body, infer_missing=True) + # pylint: disable=too-many-locals def _request( self, @@ -228,7 +237,7 @@ class HTTPClient: def post_form(self, endpoint, data=None, error_handler=None, **kwargs): return self._request( - "post", endpoint, data=data, error_handler=error_handler**kwargs + "post", endpoint, data=data, error_handler=error_handler, **kwargs ) def close(self): diff --git a/clients/client-python/tests/unittests/auth/mock_base.py b/clients/client-python/tests/unittests/auth/mock_base.py index f7b66c6b3..2becd5457 100644 --- a/clients/client-python/tests/unittests/auth/mock_base.py +++ b/clients/client-python/tests/unittests/auth/mock_base.py @@ -28,6 +28,12 @@ from cryptography.hazmat.primitives import serialization as crypto_serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend as crypto_default_backend +from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse +from gravitino.exceptions.handlers.oauth_error_handler import ( + INVALID_CLIENT_ERROR, + INVALID_GRANT_ERROR, +) + @dataclass class TestResponse: @@ -61,6 +67,26 @@ JWT_PRIVATE_KEY = generate_private_key() GENERATED_TIME = 
int(time.time()) +def mock_authentication_invalid_client_error(): + return ( + False, + OAuth2ErrorResponse.from_json( + json.dumps({"error": INVALID_CLIENT_ERROR, "error_description": "invalid"}), + infer_missing=True, + ), + ) + + +def mock_authentication_invalid_grant_error(): + return ( + False, + OAuth2ErrorResponse.from_json( + json.dumps({"error": INVALID_GRANT_ERROR, "error_description": "invalid"}), + infer_missing=True, + ), + ) + + def mock_authentication_with_error_authentication_type(): return TestResponse( body=json.dumps( diff --git a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py index b60efbf04..7d9ef9e25 100644 --- a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py +++ b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py @@ -22,6 +22,7 @@ from unittest.mock import patch from gravitino.auth.auth_constants import AuthConstants from gravitino.auth.default_oauth2_token_provider import DefaultOAuth2TokenProvider +from gravitino.exceptions.base import BadRequestException, UnauthorizedException from tests.unittests.auth import mock_base OAUTH_PORT = 1082 @@ -40,8 +41,33 @@ class TestOAuth2TokenProvider(unittest.TestCase): with self.assertRaises(AssertionError): _ = DefaultOAuth2TokenProvider(uri="test", credential="xx", scope="test") - # TODO - # Error Test + @patch( + "gravitino.utils.http_client.HTTPClient._make_request", + return_value=mock_base.mock_authentication_invalid_client_error(), + ) + def test_authertication_invalid_client_error(self, *mock_methods): + + with self.assertRaises(UnauthorizedException): + _ = DefaultOAuth2TokenProvider( + uri=f"http://127.0.0.1:{OAUTH_PORT}", + credential="yy:xx", + path="oauth/token", + scope="test", + ) + + @patch( + "gravitino.utils.http_client.HTTPClient._make_request", + return_value=mock_base.mock_authentication_invalid_grant_error(), + ) + def 
test_authertication_invalid_grant_error(self, *mock_methods): + + with self.assertRaises(BadRequestException): + _ = DefaultOAuth2TokenProvider( + uri=f"http://127.0.0.1:{OAUTH_PORT}", + credential="yy:xx", + path="oauth/token", + scope="test", + ) @patch( "gravitino.utils.http_client.HTTPClient.post_form",