This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git


The following commit(s) were added to refs/heads/main by this push:
     new 44271cbb3 [#4173] improvement(client-python): Add OAuth Error Handler and related exceptions, test cases in client-python (#4324)
44271cbb3 is described below

commit 44271cbb3bdaea4179761a5aa1dcb620196a8166
Author: noidname01 <55401762+noidnam...@users.noreply.github.com>
AuthorDate: Thu Aug 1 14:43:41 2024 +0800

    [#4173] improvement(client-python): Add OAuth Error Handler and related exceptions, test cases in client-python (#4324)
    
    ### What changes were proposed in this pull request?
    
    * Add an OAuth error handler, the related exceptions, and unit tests to
    `client-python`, based on the existing `client-java` implementation.
    
    ### Why are the changes needed?
    
    Fix: #4173
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests and ran them with `./gradlew client:client-python:test`.
    
    ---------
    
    Co-authored-by: TimWang <tim.w...@pranaq.com>
---
 .../auth/default_oauth2_token_provider.py          |  5 +-
 .../dto/responses/oauth2_error_response.py         | 40 +++++++++++++++
 clients/client-python/gravitino/exceptions/base.py |  8 +++
 .../exceptions/handlers/oauth_error_handler.py     | 58 ++++++++++++++++++++++
 .../client-python/gravitino/utils/http_client.py   | 13 ++++-
 .../tests/unittests/auth/mock_base.py              | 26 ++++++++++
 .../unittests/auth/test_oauth2_token_provider.py   | 30 ++++++++++-
 7 files changed, 175 insertions(+), 5 deletions(-)

diff --git 
a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py 
b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
index 3fb730395..beefc90c4 100644
--- a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
+++ b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
@@ -27,6 +27,7 @@ from gravitino.dto.requests.oauth2_client_credential_request 
import (
     OAuth2ClientCredentialRequest,
 )
 from gravitino.exceptions.base import GravitinoRuntimeException
+from gravitino.exceptions.handlers.oauth_error_handler import 
OAUTH_ERROR_HANDLER
 
 CLIENT_CREDENTIALS = "client_credentials"
 CREDENTIAL_SPLITTER = ":"
@@ -107,7 +108,9 @@ class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
         )
 
         resp = self._client.post_form(
-            self._path, data=client_credential_request.to_dict()
+            self._path,
+            data=client_credential_request,
+            error_handler=OAUTH_ERROR_HANDLER,
         )
         oauth2_resp = OAuth2TokenResponse.from_json(resp.body, 
infer_missing=True)
         oauth2_resp.validate()
diff --git 
a/clients/client-python/gravitino/dto/responses/oauth2_error_response.py 
b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py
new file mode 100644
index 000000000..f7e472c13
--- /dev/null
+++ b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py
@@ -0,0 +1,40 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from dataclasses import dataclass, field
+from dataclasses_json import config
+
+from gravitino.dto.responses.error_response import ErrorResponse
+
+
+@dataclass
+class OAuth2ErrorResponse(ErrorResponse):
+    """Represents the error body returned by an OAuth2 token endpoint.
+
+    The JSON keys ``error`` and ``error_description`` are mapped onto the
+    ``_type`` and ``_message`` fields through dataclasses_json field metadata,
+    so an instance can be built with ``OAuth2ErrorResponse.from_json(...)``.
+    """
+
+    # Read from the JSON "error" key (an OAuth2 error code, e.g. "invalid_client").
+    _type: str = field(metadata=config(field_name="error"))
+    # Read from the JSON "error_description" key.
+    _message: str = field(metadata=config(field_name="error_description"))
+
+    def type(self):
+        """Return the OAuth2 error code (the JSON ``error`` value)."""
+        return self._type
+
+    def message(self):
+        """Return the error description (the JSON ``error_description`` value)."""
+        return self._message
+
+    def validate(self):
+        """Check that the error code is present.
+
+        Raises:
+            AssertionError: if ``_type`` is None.
+        """
+        # NOTE(review): `assert` is stripped under `python -O`, so this check
+        # vanishes in optimized runs; the message also says "OAuthErrorResponse"
+        # although the class is OAuth2ErrorResponse.
+        assert self._type is not None, "OAuthErrorResponse should contain type"
diff --git a/clients/client-python/gravitino/exceptions/base.py 
b/clients/client-python/gravitino/exceptions/base.py
index 418304d7d..7700e151a 100644
--- a/clients/client-python/gravitino/exceptions/base.py
+++ b/clients/client-python/gravitino/exceptions/base.py
@@ -85,3 +85,11 @@ class 
UnsupportedOperationException(GravitinoRuntimeException):
 
 class UnknownError(RuntimeError):
     """An exception thrown when other unknown exception is thrown"""
+    # NOTE(review): this derives from RuntimeError rather than
+    # GravitinoRuntimeException (the base of the exceptions added below), so an
+    # ``except GravitinoRuntimeException`` clause will not catch it.
+
+
+class UnauthorizedException(GravitinoRuntimeException):
+    """An exception thrown when a user is not authorized to perform an action.
+
+    For example, OAuthErrorHandler raises it for an ``invalid_client`` error.
+    """
+
+
+class BadRequestException(GravitinoRuntimeException):
+    """An exception thrown when the request is invalid.
+
+    For example, OAuthErrorHandler raises it for OAuth2 errors such as
+    ``invalid_request``, ``invalid_grant`` or ``invalid_scope``.
+    """
diff --git 
a/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py 
b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py
new file mode 100644
index 000000000..ede4d58ae
--- /dev/null
+++ b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py
@@ -0,0 +1,58 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from gravitino.exceptions.base import UnauthorizedException, 
BadRequestException
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
+from gravitino.exceptions.handlers.rest_error_handler import RestErrorHandler
+
+# OAuth2 ``error`` codes recognized by OAuthErrorHandler. The values match the
+# token-endpoint error codes defined in RFC 6749, section 5.2.
+INVALID_CLIENT_ERROR = "invalid_client"
+INVALID_REQUEST_ERROR = "invalid_request"
+INVALID_GRANT_ERROR = "invalid_grant"
+UNAUTHORIZED_CLIENT_ERROR = "unauthorized_client"
+UNSUPPORTED_GRANT_TYPE_ERROR = "unsupported_grant_type"
+INVALID_SCOPE_ERROR = "invalid_scope"
+
+
+class OAuthErrorHandler(RestErrorHandler):
+    """Map OAuth2 token-endpoint error responses to client exceptions.
+
+    ``invalid_client`` becomes UnauthorizedException; the other OAuth2 error
+    codes defined in this module become BadRequestException; any other
+    ``type()`` value is delegated to ``RestErrorHandler.handle``.
+    """
+
+    def handle(self, error_response: OAuth2ErrorResponse):
+        """Raise the exception corresponding to ``error_response``.
+
+        Args:
+            error_response: the parsed error body; only its ``type()`` and
+                ``message()`` are read here. NOTE(review): HTTPClient's
+                ``_parse_error_response`` returns a plain ErrorResponse when the
+                body has a ``code`` key, so this may not always be an
+                OAuth2ErrorResponse -- confirm ErrorResponse exposes the same
+                ``type()``/``message()`` methods.
+
+        Raises:
+            UnauthorizedException: for ``invalid_client``.
+            BadRequestException: for ``invalid_request``, ``invalid_grant``,
+                ``unauthorized_client``, ``unsupported_grant_type`` and
+                ``invalid_scope``.
+        """
+
+        error_message = error_response.message()
+        exception_type = error_response.type()
+
+        # Client authentication failure is reported as "not authorized".
+        if exception_type == INVALID_CLIENT_ERROR:
+            raise UnauthorizedException(
+                f"Not authorized: {exception_type}: {error_message}"
+            )
+
+        # The remaining known OAuth2 error codes are treated as bad requests.
+        if exception_type in [
+            INVALID_REQUEST_ERROR,
+            INVALID_GRANT_ERROR,
+            UNAUTHORIZED_CLIENT_ERROR,
+            UNSUPPORTED_GRANT_TYPE_ERROR,
+            INVALID_SCOPE_ERROR,
+        ]:
+            raise BadRequestException(
+                f"Malformed request: {exception_type}: {error_message}"
+            )
+
+        # Unrecognized error code: fall back to the generic REST handling
+        # (its behavior is defined in RestErrorHandler, not shown here).
+        super().handle(error_response)
+
+
+# Module-level instance; DefaultOAuth2TokenProvider passes it to
+# HTTPClient.post_form as ``error_handler``.
+OAUTH_ERROR_HANDLER = OAuthErrorHandler()
diff --git a/clients/client-python/gravitino/utils/http_client.py 
b/clients/client-python/gravitino/utils/http_client.py
index 89b75d641..678942bb4 100644
--- a/clients/client-python/gravitino/utils/http_client.py
+++ b/clients/client-python/gravitino/utils/http_client.py
@@ -37,6 +37,7 @@ from gravitino.typing import JSONType
 from gravitino.constants.timeout import TIMEOUT
 
 from gravitino.dto.responses.error_response import ErrorResponse
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
 from gravitino.exceptions.base import RESTException, UnknownError
 from gravitino.exceptions.handlers.error_handler import ErrorHandler
 
@@ -145,11 +146,19 @@ class HTTPClient:
                     ErrorResponse.generate_error_response(RESTException, 
err.reason),
                 )
 
-            err_resp = ErrorResponse.from_json(err_body, infer_missing=True)
+            err_resp = self._parse_error_response(err_body)
             err_resp.validate()
 
             return (False, err_resp)
 
+    def _parse_error_response(self, err_body: bytes) -> ErrorResponse:
+        """Deserialize an HTTP error body into the matching error DTO.
+
+        A body whose top-level JSON contains a ``code`` key is parsed as an
+        ErrorResponse; any other body is parsed as an OAuth2ErrorResponse
+        (JSON keys ``error`` / ``error_description``).
+
+        Args:
+            err_body: the raw error response body.
+
+        Returns:
+            An ErrorResponse, or an OAuth2ErrorResponse (its subclass).
+        """
+        # NOTE(review): the body is decoded twice (here and inside from_json);
+        # the dict could be reused via from_dict. Also, if the JSON is a string
+        # or list rather than an object, `"code" in ...` does a substring /
+        # element test instead of a key test -- confirm bodies are always objects.
+        json_err_body = _json.loads(err_body)
+
+        if "code" in json_err_body:
+            return ErrorResponse.from_json(err_body, infer_missing=True)
+
+        return OAuth2ErrorResponse.from_json(err_body, infer_missing=True)
+
     # pylint: disable=too-many-locals
     def _request(
         self,
@@ -228,7 +237,7 @@ class HTTPClient:
 
     def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
+        """Issue a POST to ``endpoint`` with ``data`` via ``_request``.
+
+        ``error_handler`` and any extra ``kwargs`` are forwarded unchanged to
+        ``_request``. NOTE(review): form-encoding of ``data`` and invocation of
+        ``error_handler`` on failures presumably happen in ``_request`` -- confirm.
+        """
         return self._request(
-            "post", endpoint, data=data, error_handler=error_handler**kwargs
+            "post", endpoint, data=data, error_handler=error_handler, **kwargs
         )
 
     def close(self):
diff --git a/clients/client-python/tests/unittests/auth/mock_base.py 
b/clients/client-python/tests/unittests/auth/mock_base.py
index f7b66c6b3..2becd5457 100644
--- a/clients/client-python/tests/unittests/auth/mock_base.py
+++ b/clients/client-python/tests/unittests/auth/mock_base.py
@@ -28,6 +28,12 @@ from cryptography.hazmat.primitives import serialization as 
crypto_serialization
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.hazmat.backends import default_backend as 
crypto_default_backend
 
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
+from gravitino.exceptions.handlers.oauth_error_handler import (
+    INVALID_CLIENT_ERROR,
+    INVALID_GRANT_ERROR,
+)
+
 
 @dataclass
 class TestResponse:
@@ -61,6 +67,26 @@ JWT_PRIVATE_KEY = generate_private_key()
 GENERATED_TIME = int(time.time())
 
 
+def mock_authentication_invalid_client_error():
+    """Build a mocked failed response carrying an ``invalid_client`` error.
+
+    Returns a ``(False, OAuth2ErrorResponse)`` tuple, mirroring the
+    ``(False, err_resp)`` shape HTTPClient returns on its error path; the
+    tests use it as the patched ``HTTPClient._make_request`` return value.
+    """
+    return (
+        False,
+        OAuth2ErrorResponse.from_json(
+            json.dumps({"error": INVALID_CLIENT_ERROR, "error_description": 
"invalid"}),
+            infer_missing=True,
+        ),
+    )
+
+
+def mock_authentication_invalid_grant_error():
+    """Build a mocked failed response carrying an ``invalid_grant`` error.
+
+    Returns a ``(False, OAuth2ErrorResponse)`` tuple, mirroring the
+    ``(False, err_resp)`` shape HTTPClient returns on its error path; the
+    tests use it as the patched ``HTTPClient._make_request`` return value.
+    """
+    return (
+        False,
+        OAuth2ErrorResponse.from_json(
+            json.dumps({"error": INVALID_GRANT_ERROR, "error_description": 
"invalid"}),
+            infer_missing=True,
+        ),
+    )
+
+
 def mock_authentication_with_error_authentication_type():
     return TestResponse(
         body=json.dumps(
diff --git 
a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py 
b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
index b60efbf04..7d9ef9e25 100644
--- a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
+++ b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
@@ -22,6 +22,7 @@ from unittest.mock import patch
 
 from gravitino.auth.auth_constants import AuthConstants
 from gravitino.auth.default_oauth2_token_provider import 
DefaultOAuth2TokenProvider
+from gravitino.exceptions.base import BadRequestException, 
UnauthorizedException
 from tests.unittests.auth import mock_base
 
 OAUTH_PORT = 1082
@@ -40,8 +41,33 @@ class TestOAuth2TokenProvider(unittest.TestCase):
         with self.assertRaises(AssertionError):
             _ = DefaultOAuth2TokenProvider(uri="test", credential="xx", 
scope="test")
 
-    # TODO
-    # Error Test
+    @patch(
+        "gravitino.utils.http_client.HTTPClient._make_request",
+        return_value=mock_base.mock_authentication_invalid_client_error(),
+    )
+    def test_authertication_invalid_client_error(self, *mock_methods):
+        """An ``invalid_client`` token error surfaces as UnauthorizedException.
+
+        ``_make_request`` is patched to return the mocked OAuth2 error, and the
+        exception is expected while constructing the provider, i.e. the
+        constructor is what triggers the token request.
+        """
+        # NOTE(review): "authertication" in the method name is a typo for
+        # "authentication"; renaming is safe since unittest discovers by prefix.
+        # NOTE(review): the `";` after the URL below looks like mail-archive
+        # residue; presumably the committed line is
+        # uri=f"http://127.0.0.1:{OAUTH_PORT}",  -- confirm against the repo.
+
+        with self.assertRaises(UnauthorizedException):
+            _ = DefaultOAuth2TokenProvider(
+                uri=f"http://127.0.0.1:{OAUTH_PORT}";,
+                credential="yy:xx",
+                path="oauth/token",
+                scope="test",
+            )
+
+    @patch(
+        "gravitino.utils.http_client.HTTPClient._make_request",
+        return_value=mock_base.mock_authentication_invalid_grant_error(),
+    )
+    def test_authertication_invalid_grant_error(self, *mock_methods):
+        """An ``invalid_grant`` token error surfaces as BadRequestException.
+
+        ``_make_request`` is patched to return the mocked OAuth2 error, and the
+        exception is expected while constructing the provider, i.e. the
+        constructor is what triggers the token request.
+        """
+        # NOTE(review): "authertication" in the method name is a typo for
+        # "authentication"; renaming is safe since unittest discovers by prefix.
+        # NOTE(review): the `";` after the URL below looks like mail-archive
+        # residue; presumably the committed line is
+        # uri=f"http://127.0.0.1:{OAUTH_PORT}",  -- confirm against the repo.
+
+        with self.assertRaises(BadRequestException):
+            _ = DefaultOAuth2TokenProvider(
+                uri=f"http://127.0.0.1:{OAUTH_PORT}";,
+                credential="yy:xx",
+                path="oauth/token",
+                scope="test",
+            )
 
     @patch(
         "gravitino.utils.http_client.HTTPClient.post_form",

Reply via email to