This is an automated email from the ASF dual-hosted git repository.

tomaz pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/libcloud.git

commit 6d44bb33a046c123bea1c0266d70e83caf38198d
Author: Dan Peschman <[email protected]>
AuthorDate: Fri Feb 26 15:05:09 2021 -0700

    Add ex_auth_cache parameter to OpenStack drivers to cache authentication tokens and reuse them across processes
---
 libcloud/common/openstack.py                       |  39 ++-
 libcloud/common/openstack_identity.py              | 387 +++++++++++++++------
 libcloud/test/common/test_openstack_identity.py    | 138 +++++++-
 .../fixtures/openstack/_v3__auth_unauthorized.json |   1 +
 libcloud/test/compute/test_cloudwatt.py            |  11 +
 libcloud/test/compute/test_openstack.py            |  91 ++++-
 6 files changed, 543 insertions(+), 124 deletions(-)

diff --git a/libcloud/common/openstack.py b/libcloud/common/openstack.py
index 55d18aa..e41f11e 100644
--- a/libcloud/common/openstack.py
+++ b/libcloud/common/openstack.py
@@ -21,10 +21,12 @@ from libcloud.utils.py3 import ET
 from libcloud.utils.py3 import httplib
 
 from libcloud.common.base import ConnectionUserAndKey, Response
+from libcloud.common.exceptions import BaseHTTPError
 from libcloud.common.types import ProviderError
 from libcloud.compute.types import (LibcloudError, MalformedResponseError)
 from libcloud.compute.types import KeyPairDoesNotExistError
-from libcloud.common.openstack_identity import get_class_for_auth_version
+from libcloud.common.openstack_identity import (AUTH_TOKEN_HEADER,
+                                                get_class_for_auth_version)
 
 # Imports for backward compatibility reasons
 from libcloud.common.openstack_identity import (OpenStackServiceCatalog,
@@ -132,6 +134,13 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
                                     If not specified, a provider specific
                                     default will be used.
     :type ex_force_service_region: ``str``
+
+    :param ex_auth_cache: External cache where authentication tokens are
+                          stored for reuse by other processes. Tokens are
+                          always cached in memory on the driver instance. To
+                          share tokens among multiple drivers, processes, or
+                          systems, pass a cache here.
+    :type ex_auth_cache: :class:`OpenStackAuthenticationCache`
     """
 
     auth_url = None  # type: str
@@ -158,6 +167,7 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
                  ex_force_service_type=None,
                  ex_force_service_name=None,
                  ex_force_service_region=None,
+                 ex_auth_cache=None,
                  retry_delay=None, backoff=None):
         super(OpenStackBaseConnection, self).__init__(
             user_id, key, secure=secure, timeout=timeout,
@@ -177,6 +187,7 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
         self._ex_force_service_type = ex_force_service_type
         self._ex_force_service_name = ex_force_service_name
         self._ex_force_service_region = ex_force_service_region
+        self._ex_auth_cache = ex_auth_cache
         self._osa = None
 
         if ex_force_auth_token and not ex_force_base_url:
@@ -215,7 +226,8 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
                             token_scope=self._ex_token_scope,
                             timeout=self.timeout,
                             proxy_url=self.proxy_url,
-                            parent_conn=self)
+                            parent_conn=self,
+                            auth_cache=self._ex_auth_cache)
 
         return self._osa
 
@@ -229,12 +241,15 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
         if method.upper() in ['POST', 'PUT'] and default_content_type:
             headers = {'Content-Type': default_content_type}
 
-        return super(OpenStackBaseConnection, self).request(action=action,
-                                                            params=params,
-                                                            data=data,
-                                                            method=method,
-                                                            headers=headers,
-                                                            raw=raw)
+        try:
+            return super().request(action=action, params=params, data=data,
+                                   method=method, headers=headers, raw=raw)
+        except BaseHTTPError as ex:
+            # Evict cached auth token if we receive Unauthorized while using it
+            if (ex.code == httplib.UNAUTHORIZED
+                    and self._ex_force_auth_token is None):
+                self.get_auth_class().clear_cached_auth_context()
+            raise
 
     def _get_auth_url(self):
         """
@@ -296,7 +311,7 @@ class OpenStackBaseConnection(ConnectionUserAndKey):
         return url
 
     def add_default_headers(self, headers):
-        headers['X-Auth-Token'] = self.auth_token
+        headers[AUTH_TOKEN_HEADER] = self.auth_token
         headers['Accept'] = self.accept_format
         return headers
 
@@ -437,7 +452,8 @@ class OpenStackDriverMixin(object):
                  ex_tenant_domain_id='default',
                  ex_force_service_type=None,
                  ex_force_service_name=None,
-                 ex_force_service_region=None, *args, **kwargs):
+                 ex_force_service_region=None,
+                 ex_auth_cache=None, *args, **kwargs):
         self._ex_force_base_url = ex_force_base_url
         self._ex_force_auth_url = ex_force_auth_url
         self._ex_force_auth_version = ex_force_auth_version
@@ -449,6 +465,7 @@ class OpenStackDriverMixin(object):
         self._ex_force_service_type = ex_force_service_type
         self._ex_force_service_name = ex_force_service_name
         self._ex_force_service_region = ex_force_service_region
+        self._ex_auth_cache = ex_auth_cache
 
     def openstack_connection_kwargs(self):
         """
@@ -479,4 +496,6 @@ class OpenStackDriverMixin(object):
             rv['ex_force_service_name'] = self._ex_force_service_name
         if self._ex_force_service_region:
             rv['ex_force_service_region'] = self._ex_force_service_region
+        if self._ex_auth_cache is not None:
+            rv['ex_auth_cache'] = self._ex_auth_cache
         return rv
diff --git a/libcloud/common/openstack_identity.py b/libcloud/common/openstack_identity.py
index 307fa86..85d8198 100644
--- a/libcloud/common/openstack_identity.py
+++ b/libcloud/common/openstack_identity.py
@@ -18,6 +18,7 @@ Common / shared code for handling authentication against OpenStack identity
 service (Keystone).
 """
 
+from collections import namedtuple
 import datetime
 
 from libcloud.utils.py3 import httplib
@@ -34,6 +35,7 @@ except ImportError:
     import json  # type: ignore
 
 AUTH_API_VERSION = '1.1'
+AUTH_TOKEN_HEADER = 'X-Auth-Token'
 
 # Auth versions which contain token expiration information.
 AUTH_VERSIONS_WITH_EXPIRES = [
@@ -56,6 +58,10 @@ AUTH_TOKEN_EXPIRES_GRACE_SECONDS = 5
 
 
 __all__ = [
+    'OpenStackAuthenticationCache',
+    'OpenStackAuthenticationCacheKey',
+    'OpenStackAuthenticationContext',
+
     'OpenStackIdentityVersion',
     'OpenStackIdentityDomain',
     'OpenStackIdentityProject',
@@ -80,6 +86,72 @@ __all__ = [
 ]
 
 
+class OpenStackAuthenticationCache:
+    """
+    Base class for external OpenStack authentication caches.
+
+    Authentication tokens are always cached in memory in
+    :class:`OpenStackIdentityConnection`.auth_token and related fields.  These
+    tokens are lost when the driver is garbage collected.  To share tokens
+    among multiple drivers, processes, or systems, use an
+    :class:`OpenStackAuthenticationCache` in
+    OpenStackIdentityConnection.auth_cache.
+
+    Cache implementors should inherit this class and define the methods below.
+    """
+    def get(self, key):
+        """
+        Get an authentication context from the cache.
+
+        :param key: Key to fetch.
+        :type key: :class:`.OpenStackAuthenticationCacheKey`
+
+        :return: The cached context for the given key, if present; None if not.
+        :rtype: :class:`OpenStackAuthenticationContext`
+        """
+        raise NotImplementedError
+
+    def put(self, key, context):
+        """
+        Put an authentication context into the cache.
+
+        :param key: Key where the context will be stored.
+        :type key: :class:`.OpenStackAuthenticationCacheKey`
+
+        :param context: The context to cache.
+        :type context: :class:`.OpenStackAuthenticationContext`
+        """
+        raise NotImplementedError
+
+    def clear(self, key):
+        """
+        Clear an authentication context from the cache.
+
+        :param key: Key to clear.
+        :type key: :class:`.OpenStackAuthenticationCacheKey`
+        """
+        raise NotImplementedError
+
+
+OpenStackAuthenticationCacheKey = namedtuple(
+    'OpenStackAuthenticationCacheKey',
+    ['auth_url', 'user_id', 'token_scope', 'tenant_name', 'domain_name',
+     'tenant_domain_id'])
+
+
+class OpenStackAuthenticationContext:
+    """
+    An authentication token and related context.
+    """
+    def __init__(self, token, expiration=None, user=None, roles=None,
+                 urls=None):
+        self.token = token
+        self.expiration = expiration
+        self.user = user
+        self.roles = roles
+        self.urls = urls
+
+
 class OpenStackIdentityEndpointType(object):
     """
     Enum class for openstack identity endpoint type.
@@ -580,7 +652,8 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
     def __init__(self, auth_url, user_id, key, tenant_name=None,
                  tenant_domain_id='default', domain_name='Default',
                  token_scope=OpenStackIdentityTokenScope.PROJECT,
-                 timeout=None, proxy_url=None, parent_conn=None):
+                 timeout=None, proxy_url=None, parent_conn=None,
+                 auth_cache=None):
         super(OpenStackIdentityConnection, self).__init__(user_id=user_id,
                                                           key=key,
                                                           url=auth_url,
@@ -599,13 +672,16 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
         self.auth_url = auth_url
         self.tenant_name = tenant_name
         self.domain_name = domain_name
+        self.tenant_domain_id = tenant_domain_id
         self.token_scope = token_scope
         self.timeout = timeout
+        self.auth_cache = auth_cache
 
         self.urls = {}
         self.auth_token = None
         self.auth_token_expires = None
         self.auth_user_info = None
+        self.auth_user_roles = None
 
     def authenticated_request(self, action, params=None, data=None,
                               headers=None, method='GET', raw=False):
@@ -613,13 +689,18 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
         Perform an authenticated request against the identity API.
         """
         if not self.auth_token:
-            raise ValueError('Not to be authenticated to perform this request')
+            raise ValueError(
+                'Need to be authenticated to perform this request')
 
         headers = headers or {}
-        headers['X-Auth-Token'] = self.auth_token
+        headers[AUTH_TOKEN_HEADER] = self.auth_token
 
-        return self.request(action=action, params=params, data=data,
-                            headers=headers, method=method, raw=raw)
+        response = self.request(action=action, params=params, data=data,
+                                headers=headers, method=method, raw=raw)
+        # Evict cached auth token if we receive Unauthorized while using it
+        if response.status == httplib.UNAUTHORIZED:
+            self.clear_cached_auth_context()
+        return response
 
     def morph_action_hook(self, action):
         (_, _, _, request_path) = self._tuple_from_url(self.auth_url)
@@ -671,6 +752,22 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
         """
         raise NotImplementedError('authenticate not implemented')
 
+    def clear_cached_auth_context(self):
+        """
+        Clear the cached authentication context.
+
+        The context is cleared from fields on this connection and from the
+        external cache, if one is configured.
+        """
+        self.auth_token = None
+        self.auth_token_expires = None
+        self.auth_user_info = None
+        self.auth_user_roles = None
+        self.urls = {}
+
+        if self.auth_cache is not None:
+            self.auth_cache.clear(self._cache_key)
+
     def list_supported_versions(self):
         """
         Retrieve a list of all the identity versions which are supported by
@@ -722,6 +819,14 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
         if self.is_token_valid():
             return False
 
+        # See if there's a new token in the cache
+        self._load_auth_context_from_cache()
+
+        # If there was a token in the cache, it is now stored in our local
+        # auth_token and related fields.  Ensure it is still valid.
+        if self.is_token_valid():
+            return False
+
         return True
 
     def _to_projects(self, data):
@@ -741,6 +846,53 @@ class OpenStackIdentityConnection(ConnectionUserAndKey):
                                                               None))
         return project
 
+    @property
+    def _cache_key(self):
+        """
+        The key where this connection's authentication context will be cached.
+
+        :rtype: :class:`OpenStackAuthenticationCacheKey`
+        """
+        return OpenStackAuthenticationCacheKey(
+            self.auth_url, self.user_id, self.token_scope, self.tenant_name,
+            self.domain_name, self.tenant_domain_id)
+
+    def _cache_auth_context(self, context):
+        """
+        Store an authentication context in memory and the cache.
+
+        :param context: Authentication context to cache.
+        :type key: :class:`.OpenStackAuthenticationContext`
+        """
+        self.urls = context.urls
+        self.auth_token = context.token
+        self.auth_token_expires = context.expiration
+        self.auth_user_info = context.user
+        self.auth_user_roles = context.roles
+
+        if self.auth_cache is not None:
+            self.auth_cache.put(self._cache_key, context)
+
+    def _load_auth_context_from_cache(self):
+        """
+        Fetch an authentication context for this connection from the cache.
+
+        :rtype: :class:`OpenStackAuthenticationContext`
+        """
+        if self.auth_cache is None:
+            return None
+
+        context = self.auth_cache.get(self._cache_key)
+        if context is None:
+            return None
+
+        self.urls = context.urls
+        self.auth_token = context.token
+        self.auth_token_expires = context.expiration
+        self.auth_user_info = context.user
+        self.auth_user_roles = context.roles
+        return context
+
 
 class OpenStackIdentity_1_0_Connection(OpenStackIdentityConnection):
     """
@@ -824,11 +976,11 @@ class OpenStackIdentity_1_1_Connection(OpenStackIdentityConnection):
 
             try:
                 expires = body['auth']['token']['expires']
-
-                self.auth_token = body['auth']['token']['id']
-                self.auth_token_expires = parse_date(expires)
-                self.urls = body['auth']['serviceCatalog']
-                self.auth_user_info = None
+                self._cache_auth_context(
+                    OpenStackAuthenticationContext(
+                        body['auth']['token']['id'],
+                        expiration=parse_date(expires),
+                        urls=body['auth']['serviceCatalog']))
             except KeyError as e:
                 raise MalformedResponseError('Auth JSON response is \
                                              missing required elements', e)
@@ -897,11 +1049,12 @@ class OpenStackIdentity_2_0_Connection(OpenStackIdentityConnection):
             try:
                 access = body['access']
                 expires = access['token']['expires']
-
-                self.auth_token = access['token']['id']
-                self.auth_token_expires = parse_date(expires)
-                self.urls = access['serviceCatalog']
-                self.auth_user_info = access.get('user', {})
+                self._cache_auth_context(
+                    OpenStackAuthenticationContext(
+                        access['token']['id'],
+                        expiration=parse_date(expires),
+                        urls=access['serviceCatalog'],
+                        user=access.get('user', {})))
             except KeyError as e:
                 raise MalformedResponseError('Auth JSON response is \
                                              missing required elements', e)
@@ -935,14 +1088,15 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
     def __init__(self, auth_url, user_id, key, tenant_name=None,
                  domain_name='Default', tenant_domain_id='default',
                  token_scope=OpenStackIdentityTokenScope.PROJECT,
-                 timeout=None, proxy_url=None, parent_conn=None):
+                 timeout=None, proxy_url=None, parent_conn=None,
+                 auth_cache=None):
         """
         :param tenant_name: Name of the project this user belongs to. Note:
                             When token_scope is set to project, this argument
                             control to which project to scope the token to.
         :type tenant_name: ``str``
 
-        :param domain_name: Domain the user belongs to. Note: Then token_scope
+        :param domain_name: Domain the user belongs to. Note: When token_scope
                             is set to token, this argument controls to which
                             domain to scope the token to.
         :type domain_name: ``str``
@@ -950,6 +1104,9 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
         :param token_scope: Whether to scope a token to a "project", a
                             "domain" or "unscoped"
         :type token_scope: ``str``
+
+        :param auth_cache: Where to cache authentication tokens.
+        :type auth_cache: :class:`OpenStackAuthenticationCache`
         """
         super(OpenStackIdentity_3_0_Connection,
               self).__init__(auth_url=auth_url,
@@ -957,10 +1114,12 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
                              key=key,
                              tenant_name=tenant_name,
                              domain_name=domain_name,
+                             tenant_domain_id=tenant_domain_id,
                              token_scope=token_scope,
                              timeout=timeout,
                              proxy_url=proxy_url,
-                             parent_conn=parent_conn)
+                             parent_conn=parent_conn,
+                             auth_cache=auth_cache)
 
         if self.token_scope not in self.VALID_TOKEN_SCOPES:
             raise ValueError('Invalid value for "token_scope" argument: %s' %
@@ -974,9 +1133,6 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
                 not self.domain_name):
             raise ValueError('Must provide domain_name argument')
 
-        self.auth_user_roles = None
-        self.tenant_domain_id = tenant_domain_id
-
     def authenticate(self, force=False):
         """
         Perform authentication.
@@ -989,47 +1145,7 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
         response = self.request('/v3/auth/tokens', data=data,
                                 headers={'Content-Type': 'application/json'},
                                 method='POST')
-
-        if response.status == httplib.UNAUTHORIZED:
-            # Invalid credentials
-            raise InvalidCredsError()
-        elif response.status in [httplib.OK, httplib.CREATED]:
-            headers = response.headers
-
-            try:
-                body = json.loads(response.body)
-            except Exception as e:
-                raise MalformedResponseError('Failed to parse JSON', e)
-
-            try:
-                roles = self._to_roles(body['token']['roles'])
-            except Exception:
-                roles = []
-
-            try:
-                expires = body['token']['expires_at']
-
-                self.auth_token = headers['x-subject-token']
-                self.auth_token_expires = parse_date(expires)
-                # Note: catalog is not returned for unscoped tokens
-                self.urls = body['token'].get('catalog', None)
-                self.auth_user_info = body['token'].get('user', None)
-                self.auth_user_roles = roles
-            except KeyError as e:
-                raise MalformedResponseError('Auth JSON response is \
-                                             missing required elements', e)
-            body = 'code: %s body:%s' % (response.status, response.body)
-        elif response.status == 300:
-            # ambiguous version request
-            raise LibcloudError(
-                'Auth request returned ambiguous version error, try'
-                'using the version specific URL to connect,'
-                ' e.g. identity/v3/auth/tokens')
-        else:
-            body = 'code: %s body:%s' % (response.status, response.body)
-            raise MalformedResponseError('Malformed response', body=body,
-                                         driver=self.driver)
-
+        self._parse_token_response(response, cache_it=True)
         return self
 
     def list_domains(self):
@@ -1356,6 +1472,94 @@ class OpenStackIdentity_3_0_Connection(OpenStackIdentityConnection):
 
         return data
 
+    def _load_auth_context_from_cache(self):
+        context = super()._load_auth_context_from_cache()
+        if context is None:
+            return None
+
+        # Since v3 only caches the token and expiration, fetch the
+        # service catalog and other bits of the authentication context
+        # from Keystone.
+        try:
+            self._fetch_auth_token()
+        except InvalidCredsError:
+            # Unauthorized; cached auth context was cleared as part of
+            # _fetch_auth_token
+            return None
+
+        # Local auth context variables set in _fetch_auth_token
+        return context
+
+    def _parse_token_response(self, response, cache_it=False,
+                              raise_ambiguous_version_error=True):
+        """
+        Parse a response from /v3/auth/tokens.
+
+        :param cache_it: Should we cache the authentication context?
+        :type cache_it: ``bool``
+
+        :param raise_ambiguous_version_error: Should an ambiguous version
+            error be raised on a 300 response?
+        :type raise_ambiguous_version_error: ``bool``
+        """
+        if response.status == httplib.UNAUTHORIZED:
+            raise InvalidCredsError()
+        elif response.status in [httplib.OK, httplib.CREATED]:
+            headers = response.headers
+
+            try:
+                body = json.loads(response.body)
+            except Exception as e:
+                raise MalformedResponseError('Failed to parse JSON', e)
+
+            try:
+                roles = self._to_roles(body['token']['roles'])
+            except Exception:
+                roles = []
+
+            try:
+                expires = parse_date(body['token']['expires_at'])
+                token = headers['x-subject-token']
+
+                # Cache the fewest fields required for token reuse to minimize
+                # cache size. Other fields, especially the service catalog, can
+                # be quite large. Fetch these from Keystone when the token is
+                # first loaded from cache.
+                if cache_it:
+                    self._cache_auth_context(
+                        OpenStackAuthenticationContext(
+                            token, expiration=expires))
+
+                self.auth_token = token
+                self.auth_token_expires = expires
+                # Note: catalog is not returned for unscoped tokens
+                self.urls = body['token'].get('catalog', None)
+                self.auth_user_info = body['token'].get('user', None)
+                self.auth_user_roles = roles
+            except KeyError as e:
+                raise MalformedResponseError('Auth JSON response is \
+                                             missing required elements', e)
+        elif raise_ambiguous_version_error and response.status == 300:
+            # ambiguous version request
+            raise LibcloudError(
+                'Auth request returned ambiguous version error, try'
+                'using the version specific URL to connect,'
+                ' e.g. identity/v3/auth/tokens')
+        else:
+            body = 'code: %s body:%s' % (response.status, response.body)
+            raise MalformedResponseError('Malformed response', body=body,
+                                         driver=self.driver)
+
+    def _fetch_auth_token(self):
+        """
+        Fetch our authentication token and service catalog.
+        """
+        headers = {'X-Subject-Token': self.auth_token}
+        response = self.authenticated_request('/v3/auth/tokens',
+                                              headers=headers)
+        self._parse_token_response(response)
+        return self
+
     def _to_domains(self, data):
         result = []
         for item in data:
@@ -1515,41 +1719,8 @@ class OpenStackIdentity_3_0_Connection_OIDC_access_token(
         response = self.request('/v3/auth/tokens', data=data,
                                 headers={'Content-Type': 'application/json'},
                                 method='POST')
-
-        if response.status == httplib.UNAUTHORIZED:
-            # Invalid credentials
-            raise InvalidCredsError()
-        elif response.status in [httplib.OK, httplib.CREATED]:
-            headers = response.headers
-
-            try:
-                body = json.loads(response.body)
-            except Exception as e:
-                raise MalformedResponseError('Failed to parse JSON', e)
-
-            try:
-                roles = self._to_roles(body['token']['roles'])
-            except Exception:
-                roles = []
-
-            try:
-                expires = body['token']['expires_at']
-
-                self.auth_token = headers['x-subject-token']
-                self.auth_token_expires = parse_date(expires)
-                # Note: catalog is not returned for unscoped tokens
-                self.urls = body['token'].get('catalog', None)
-                self.auth_user_info = body['token'].get('user', None)
-                self.auth_user_roles = roles
-            except KeyError as e:
-                raise MalformedResponseError('Auth JSON response is \
-                                             missing required elements', e)
-            body = 'code: %s body:%s' % (response.status, response.body)
-        else:
-            body = 'code: %s body:%s' % (response.status, response.body)
-            raise MalformedResponseError('Malformed response', body=body,
-                                         driver=self.driver)
-
+        self._parse_token_response(response, cache_it=True,
+                                   raise_ambiguous_version_error=False)
         return self
 
     def _get_unscoped_token_from_oidc_token(self):
@@ -1586,7 +1757,7 @@ class OpenStackIdentity_3_0_Connection_OIDC_access_token(
         path = '/v3/auth/projects'
         response = self.request(path,
                                 headers={'Content-Type': 'application/json',
-                                         'X-Auth-Token': token},
+                                         AUTH_TOKEN_HEADER: token},
                                 method='GET')
 
         if response.status not in [httplib.UNAUTHORIZED, httplib.OK,
@@ -1596,7 +1767,7 @@ class OpenStackIdentity_3_0_Connection_OIDC_access_token(
             response = self.request(path,
                                     headers={'Content-Type':
                                              'application/json',
-                                             'X-Auth-Token': token},
+                                             AUTH_TOKEN_HEADER: token},
                                     method='GET')
 
         if response.status == httplib.UNAUTHORIZED:
@@ -1640,7 +1811,8 @@ class OpenStackIdentity_2_0_Connection_VOMS(OpenStackIdentityConnection,
     def __init__(self, auth_url, user_id, key, tenant_name=None,
                  domain_name='Default',
                  token_scope=OpenStackIdentityTokenScope.PROJECT,
-                 timeout=None, proxy_url=None, parent_conn=None):
+                 timeout=None, proxy_url=None, parent_conn=None,
+                 auth_cache=None):
         CertificateConnection.__init__(self, cert_file=key,
                                        url=auth_url,
                                        proxy_url=proxy_url,
@@ -1661,6 +1833,7 @@ class OpenStackIdentity_2_0_Connection_VOMS(OpenStackIdentityConnection,
         self.token_scope = token_scope
         self.timeout = timeout
         self.proxy_url = proxy_url
+        self.auth_cache = auth_cache
 
         self.urls = {}
         self.auth_token = None
@@ -1713,7 +1886,7 @@ class OpenStackIdentity_2_0_Connection_VOMS(OpenStackIdentityConnection,
         """
         headers = {'Accept': 'application/json',
                    'Content-Type': 'application/json',
-                   'X-Auth-Token': token}
+                   AUTH_TOKEN_HEADER: token}
         response = self.request('/v2.0/tenants', headers=headers, method='GET')
 
         if response.status == httplib.UNAUTHORIZED:
@@ -1748,15 +1921,15 @@ class OpenStackIdentity_2_0_Connection_VOMS(OpenStackIdentityConnection,
             try:
                 access = body['access']
                 expires = access['token']['expires']
-
-                self.auth_token = access['token']['id']
-                self.auth_token_expires = parse_date(expires)
-                self.urls = access['serviceCatalog']
-                self.auth_user_info = access.get('user', {})
+                self._cache_auth_context(
+                    OpenStackAuthenticationContext(
+                        access['token']['id'],
+                        expiration=parse_date(expires),
+                        urls=access['serviceCatalog'],
+                        user=access.get('user', {})))
             except KeyError as e:
                 raise MalformedResponseError('Auth JSON response is \
                                              missing required elements', e)
-
         return self
 
 
diff --git a/libcloud/test/common/test_openstack_identity.py b/libcloud/test/common/test_openstack_identity.py
index 888a2b0..6080506 100644
--- a/libcloud/test/common/test_openstack_identity.py
+++ b/libcloud/test/common/test_openstack_identity.py
@@ -41,14 +41,20 @@ from libcloud.test import unittest
 from libcloud.test import MockHttp
 from libcloud.test.secrets import OPENSTACK_PARAMS
 from libcloud.test.file_fixtures import ComputeFileFixtures
+from libcloud.test.compute.test_openstack import OpenStackMockAuthCache
 from libcloud.test.compute.test_openstack import OpenStackMockHttp
 from libcloud.test.compute.test_openstack import OpenStack_2_0_MockHttp
 
+TOMORROW = datetime.datetime.today() + datetime.timedelta(1)
+YESTERDAY = datetime.datetime.today() - datetime.timedelta(1)
+
 
 class OpenStackIdentityConnectionTestCase(unittest.TestCase):
     def setUp(self):
         OpenStackBaseConnection.auth_url = None
         OpenStackBaseConnection.conn_class = OpenStackMockHttp
+        OpenStack_2_0_MockHttp.type = None
+        OpenStackIdentity_3_0_MockHttp.type = None
 
     def test_auth_url_is_correctly_assembled(self):
         tuples = [
@@ -166,9 +172,6 @@ class 
OpenStackIdentityConnectionTestCase(unittest.TestCase):
         connection = self._get_mock_connection(OpenStack_2_0_MockHttp)
         auth_url = connection.auth_url
 
-        yesterday = datetime.datetime.today() - datetime.timedelta(1)
-        tomorrow = datetime.datetime.today() + datetime.timedelta(1)
-
         osa = OpenStackIdentity_2_0_Connection(auth_url=auth_url,
                                                user_id=user_id,
                                                key=key,
@@ -179,7 +182,7 @@ class 
OpenStackIdentityConnectionTestCase(unittest.TestCase):
 
         # Force re-auth, expired token
         osa.auth_token = None
-        osa.auth_token_expires = yesterday
+        osa.auth_token_expires = YESTERDAY
         count = 5
 
         for i in range(0, count):
@@ -189,7 +192,7 @@ class 
OpenStackIdentityConnectionTestCase(unittest.TestCase):
 
         # No force reauth, expired token
         osa.auth_token = None
-        osa.auth_token_expires = yesterday
+        osa.auth_token_expires = YESTERDAY
 
         mocked_auth_method.call_count = 0
         self.assertEqual(mocked_auth_method.call_count, 0)
@@ -209,7 +212,7 @@ class 
OpenStackIdentityConnectionTestCase(unittest.TestCase):
             osa.authenticate(force=False)
 
             if i == 0:
-                osa.auth_token_expires = tomorrow
+                osa.auth_token_expires = TOMORROW
 
         self.assertEqual(mocked_auth_method.call_count, 1)
 
@@ -230,6 +233,104 @@ class 
OpenStackIdentityConnectionTestCase(unittest.TestCase):
 
         self.assertEqual(mocked_auth_method.call_count, 1)
 
+    def test_authentication_cache(self):
+        tuples = [
+            # 1.0 does not provide token expiration, so it always
+            # re-authenticates and never uses the cache.
+            # ('1.0', OpenStackMockHttp, {}),
+            ('1.1', OpenStackMockHttp, {}),
+            ('2.0', OpenStack_2_0_MockHttp, {}),
+            ('2.0_apikey', OpenStack_2_0_MockHttp, {}),
+            ('2.0_password', OpenStack_2_0_MockHttp, {}),
+            ('3.x_password', OpenStackIdentity_3_0_MockHttp, {'user_id': 
'test_user_id', 'key': 'test_key',
+                                                              'token_scope': 
'project', 'tenant_name': 'test_tenant',
+                                                              
'tenant_domain_id': 'test_tenant_domain_id',
+                                                              'domain_name': 
'test_domain'}),
+            ('3.x_oidc_access_token', OpenStackIdentity_3_0_MockHttp, 
{'user_id': 'test_user_id', 'key': 'test_key',
+                                                              'token_scope': 
'domain', 'tenant_name': 'test_tenant',
+                                                              
'tenant_domain_id': 'test_tenant_domain_id',
+                                                              'domain_name': 
'test_domain'})
+        ]
+
+        user_id = OPENSTACK_PARAMS[0]
+        key = OPENSTACK_PARAMS[1]
+
+        for (auth_version, mock_http_class, kwargs) in tuples:
+            mock_http_class.type = None
+            connection = \
+                self._get_mock_connection(mock_http_class=mock_http_class)
+            auth_url = connection.auth_url
+
+            if not kwargs:
+                kwargs['user_id'] = user_id
+                kwargs['key'] = key
+
+            auth_cache = OpenStackMockAuthCache()
+            self.assertEqual(len(auth_cache), 0)
+            kwargs['auth_cache'] = auth_cache
+
+            cls = get_class_for_auth_version(auth_version=auth_version)
+            osa = cls(auth_url=auth_url, parent_conn=connection, **kwargs)
+            osa = osa.authenticate()
+
+            # Token is cached
+            self.assertEqual(len(auth_cache), 1)
+
+            # New client, token from cache is re-used
+            osa = cls(auth_url=auth_url, parent_conn=connection, **kwargs)
+            osa.request = Mock(wraps=osa.request)
+            osa = osa.authenticate()
+
+            # No auth API call
+            if auth_version in ('1.1', '2.0', '2.0_apikey', '2.0_password'):
+                self.assertEqual(osa.request.call_count, 0)
+            elif auth_version in ('3.x_password', '3.x_oidc_access_token'):
+                # v3 only caches token and expiration; service catalog URLs
+                # and the rest of the auth context are fetched from Keystone
+                osa.request.assert_called_once_with(
+                    action='/v3/auth/tokens', params=None, data=None,
+                    headers={'X-Subject-Token': 
'00000000000000000000000000000000',
+                             'X-Auth-Token': 
'00000000000000000000000000000000'},
+                    method='GET', raw=False)
+
+            # Cache size unchanged
+            self.assertEqual(len(auth_cache), 1)
+
+            # Authenticates if cached token expired
+            cache_key = list(auth_cache.store.keys())[0]
+            auth_context = auth_cache.get(cache_key)
+            auth_context.expiration = YESTERDAY
+            auth_cache.put(cache_key, auth_context)
+
+            osa = cls(auth_url=auth_url, parent_conn=connection, **kwargs)
+            osa.request = Mock(wraps=osa.request)
+            osa._get_unscoped_token_from_oidc_token = Mock(return_value='000')
+            OpenStackIdentity_3_0_MockHttp.type = 'GET_UNAUTHORIZED_POST_OK'
+            osa = osa.authenticate()
+
+            if auth_version in ('1.1', '2.0', '2.0_apikey', '2.0_password'):
+                self.assertEqual(osa.request.call_count, 1)
+                self.assertTrue(osa.request.call_args[1]['method'], 'POST')
+            elif auth_version in ('3.x_password', '3.x_oidc_access_token'):
+                self.assertTrue(osa.request.call_args[0][0], '/v3/auth/tokens')
+                self.assertTrue(osa.request.call_args[1]['method'], 'POST')
+
+            # Token evicted from cache if 401 received on another call
+            if hasattr(osa, 'list_projects'):
+                mock_http_class.type = None
+                auth_cache.reset()
+
+                osa = cls(auth_url=auth_url, parent_conn=connection, **kwargs)
+                osa.request = Mock(wraps=osa.request)
+                osa = osa.authenticate()
+                self.assertEqual(len(auth_cache), 1)
+                mock_http_class.type = 'UNAUTHORIZED'
+                try:
+                    osa.list_projects()
+                except:  # These methods don't handle 401s
+                    pass
+                self.assertEqual(len(auth_cache), 0)
+
     def _get_mock_connection(self, mock_http_class, auth_url=None):
         OpenStackBaseConnection.conn_class = mock_http_class
 
@@ -767,6 +868,13 @@ class OpenStackIdentity_3_0_MockHttp(MockHttp):
             return (httplib.OK, body, self.json_content_headers, 
httplib.responses[httplib.OK])
         raise NotImplementedError()
 
+    def _v3_projects_UNAUTHORIZED(self, method, url, body, headers):
+        if method == 'GET':
+            body = ComputeFileFixtures('openstack').load('_v3__auth.json')
+            return (httplib.UNAUTHORIZED, body, self.json_content_headers,
+                    httplib.responses[httplib.UNAUTHORIZED])
+        raise NotImplementedError()
+
     def 
_v3_OS_FEDERATION_identity_providers_test_user_id_protocols_test_tenant_auth(self,
 method, url, body, headers):
         if method == 'GET':
             if 'Authorization' not in headers:
@@ -784,6 +892,14 @@ class OpenStackIdentity_3_0_MockHttp(MockHttp):
         raise NotImplementedError()
 
     def _v3_auth_tokens(self, method, url, body, headers):
+        if method == 'GET':
+            body = json.loads(
+                ComputeFileFixtures('openstack').load('_v3__auth.json'))
+            body['token']['expires_at'] = TOMORROW.isoformat()
+            headers = self.json_content_headers.copy()
+            headers['x-subject-token'] = '00000000000000000000000000000000'
+            return (httplib.OK, json.dumps(body), headers,
+                    httplib.responses[httplib.OK])
         if method == 'POST':
             status = httplib.OK
             data = json.loads(body)
@@ -798,6 +914,16 @@ class OpenStackIdentity_3_0_MockHttp(MockHttp):
             return (status, body, headers, httplib.responses[httplib.OK])
         raise NotImplementedError()
 
+    def _v3_auth_tokens_GET_UNAUTHORIZED_POST_OK(self, method, url, body, 
headers):
+        if method == 'GET':
+            body = ComputeFileFixtures('openstack').load(
+                '_v3__auth_unauthorized.json')
+            return (httplib.UNAUTHORIZED, body, self.json_content_headers,
+                    httplib.responses[httplib.UNAUTHORIZED])
+        if method == 'POST':
+            return self._v3_auth_tokens(method, url, body, headers)
+        raise NotImplementedError()
+
     def _v3_users(self, method, url, body, headers):
         if method == 'GET':
             # list users
diff --git 
a/libcloud/test/compute/fixtures/openstack/_v3__auth_unauthorized.json 
b/libcloud/test/compute/fixtures/openstack/_v3__auth_unauthorized.json
new file mode 100644
index 0000000..24a51c1
--- /dev/null
+++ b/libcloud/test/compute/fixtures/openstack/_v3__auth_unauthorized.json
@@ -0,0 +1 @@
+{"error": {"message": "The request you have made requires authentication.", 
"code": 401, "title": "Unauthorized"}}
diff --git a/libcloud/test/compute/test_cloudwatt.py 
b/libcloud/test/compute/test_cloudwatt.py
index cc98abb..980c3d5 100644
--- a/libcloud/test/compute/test_cloudwatt.py
+++ b/libcloud/test/compute/test_cloudwatt.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
 import unittest
 
 from libcloud.compute.drivers.cloudwatt import CloudwattNodeDriver
@@ -30,3 +31,13 @@ class CloudwattNodeDriverTests(OpenStack_1_1_Tests, 
unittest.TestCase):
 
     def test_auth_token_without_base_url_raises_exception(self):
         pass
+
+    def test_ex_auth_cache_passed_to_identity_connection(self):
+        pass
+
+    def test_unauthorized_clears_cached_auth_context(self):
+        pass
+
+
+if __name__ == '__main__':
+    sys.exit(unittest.main())
diff --git a/libcloud/test/compute/test_openstack.py 
b/libcloud/test/compute/test_openstack.py
index 8c7bd45..64da36f 100644
--- a/libcloud/test/compute/test_openstack.py
+++ b/libcloud/test/compute/test_openstack.py
@@ -37,6 +37,8 @@ from libcloud.utils.py3 import method_type
 from libcloud.utils.py3 import u
 
 from libcloud.common.base import LibcloudConnection
+from libcloud.common.exceptions import BaseHTTPError
+from libcloud.common.openstack_identity import OpenStackAuthenticationCache
 from libcloud.common.types import InvalidCredsError, MalformedResponseError, \
     LibcloudError
 from libcloud.compute.types import Provider, KeyPairDoesNotExistError, 
StorageVolumeState, \
@@ -219,6 +221,34 @@ class OpenStack_1_0_Tests(TestCaseMixin, 
unittest.TestCase):
         else:
             self.fail('test should have thrown')
 
+    def test_ex_auth_cache_passed_to_identity_connection(self):
+        kwargs = self.driver_kwargs.copy()
+        kwargs['ex_auth_cache'] = OpenStackMockAuthCache()
+        driver = self.driver_type(*self.driver_args, **kwargs)
+        driver.list_nodes()
+        self.assertEqual(kwargs['ex_auth_cache'],
+                         driver.connection.get_auth_class().auth_cache)
+
+    def test_unauthorized_clears_cached_auth_context(self):
+        auth_cache = OpenStackMockAuthCache()
+        self.assertEqual(len(auth_cache), 0)
+
+        kwargs = self.driver_kwargs.copy()
+        kwargs['ex_auth_cache'] = auth_cache
+        driver = self.driver_type(*self.driver_args, **kwargs)
+        driver.list_nodes()
+
+        # Token was cached
+        self.assertEqual(len(auth_cache), 1)
+
+        # Simulate token being revoked
+        self.driver_klass.connectionCls.conn_class.type = 'UNAUTHORIZED'
+        with pytest.raises(BaseHTTPError) as ex:
+            driver.list_nodes()
+
+        # Token was evicted
+        self.assertEqual(len(auth_cache), 0)
+
     def test_error_parsing_when_body_is_missing_message(self):
         OpenStackMockHttp.type = 'NO_MESSAGE_IN_ERROR_BODY'
         try:
@@ -516,6 +546,9 @@ class OpenStackMockHttp(MockHttp, unittest.TestCase):
         body = self.fixtures.load('v1_slug_servers_detail_metadata.xml')
         return (httplib.OK, body, XML_HEADERS, httplib.responses[httplib.OK])
 
+    def _v1_0_slug_servers_detail_UNAUTHORIZED(self, method, url, body, 
headers):
+        return (httplib.UNAUTHORIZED, "", {}, 
httplib.responses[httplib.UNAUTHORIZED])
+
     def _v1_0_slug_images_333111(self, method, url, body, headers):
         if method != "DELETE":
             raise NotImplementedError()
@@ -675,7 +708,7 @@ class OpenStack_1_1_Tests(unittest.TestCase, TestCaseMixin):
 
     def _force_reauthentication(self):
         """
-        Trash current auth token so driver will be forced to re-authentication
+        Trash current auth token so driver will be forced to re-authenticate
         on next request.
         """
         self.driver.connection._ex_force_base_url = 
'http://ex_force_base_url.com:666/forced_url'
@@ -766,6 +799,35 @@ class OpenStack_1_1_Tests(unittest.TestCase, 
TestCaseMixin):
         self.assertEqual('/v1.1/slug', driver.connection.request_path)
         self.assertEqual(443, driver.connection.port)
 
+    def test_ex_auth_cache_passed_to_identity_connection(self):
+        kwargs = self.driver_kwargs.copy()
+        kwargs['ex_auth_cache'] = OpenStackMockAuthCache()
+        driver = self.driver_type(*self.driver_args, **kwargs)
+        osa = driver.connection.get_auth_class()
+        driver.list_nodes()
+        self.assertEqual(kwargs['ex_auth_cache'],
+                         driver.connection.get_auth_class().auth_cache)
+
+    def test_unauthorized_clears_cached_auth_context(self):
+        auth_cache = OpenStackMockAuthCache()
+        self.assertEqual(len(auth_cache), 0)
+
+        kwargs = self.driver_kwargs.copy()
+        kwargs['ex_auth_cache'] = auth_cache
+        driver = self.driver_type(*self.driver_args, **kwargs)
+        driver.list_nodes()
+
+        # Token was cached
+        self.assertEqual(len(auth_cache), 1)
+
+        # Simulate token being revoked
+        self.driver_klass.connectionCls.conn_class.type = 'UNAUTHORIZED'
+        with pytest.raises(BaseHTTPError) as ex:
+            driver.list_nodes()
+
+        # Token was evicted
+        self.assertEqual(len(auth_cache), 0)
+
     def test_list_nodes(self):
         nodes = self.driver.list_nodes()
         self.assertEqual(len(nodes), 2)
@@ -2208,6 +2270,9 @@ class OpenStack_1_1_MockHttp(MockHttp, unittest.TestCase):
         body = self.fixtures.load('_servers_detail_ERROR_STATE.json')
         return (httplib.OK, body, self.json_content_headers, 
httplib.responses[httplib.OK])
 
+    def _v2_1337_servers_detail_UNAUTHORIZED(self, method, url, body, headers):
+        return (httplib.UNAUTHORIZED, "", {}, 
httplib.responses[httplib.UNAUTHORIZED])
+
     def _v2_1337_servers_does_not_exist(self, *args, **kwargs):
         return httplib.NOT_FOUND, None, {}, 
httplib.responses[httplib.NOT_FOUND]
 
@@ -2879,6 +2944,9 @@ class OpenStack_2_0_MockHttp(OpenStack_1_1_MockHttp):
             setattr(self, new_name, method_type(method, self,
                                                 OpenStack_2_0_MockHttp))
 
+    def _v2_0_tenants_UNAUTHORIZED(self, method, url, body, headers):
+        return (httplib.UNAUTHORIZED, "", {}, 
httplib.responses[httplib.UNAUTHORIZED])
+
 
 class OpenStack_1_1_Auth_2_0_Tests(OpenStack_1_1_Tests):
     driver_args = OPENSTACK_PARAMS + ('1.1',)
@@ -2906,5 +2974,26 @@ class OpenStack_1_1_Auth_2_0_Tests(OpenStack_1_1_Tests):
                        'name': 'identity:default'}]})
 
 
+class OpenStackMockAuthCache(OpenStackAuthenticationCache):
+    def __init__(self):
+        self.reset()
+
+    def get(self, key):
+        return self.store.get(key)
+
+    def put(self, key, context):
+        self.store[key] = context
+
+    def clear(self, key):
+        if key in self.store:
+            del self.store[key]
+
+    def reset(self):
+        self.store = {}
+
+    def __len__(self):
+        return len(self.store)
+
+
 if __name__ == '__main__':
     sys.exit(unittest.main())

Reply via email to