The docker registry v1 and v2 versions have completely different
authentication methods that need handling. The v2 OAuth scheme
requires modifying request headers and retrying requests after
getting an auth token. Introduce a pluggable framework for auth
that can be hooked into the _get_url() method to deal with the
request and response objects, as well as errors.

Signed-off-by: Daniel P. Berrange <berra...@redhat.com>
---
 libvirt-sandbox/image/sources/docker.py | 143 +++++++++++++++++++++++++-------
 1 file changed, 113 insertions(+), 30 deletions(-)

diff --git a/libvirt-sandbox/image/sources/docker.py b/libvirt-sandbox/image/sources/docker.py
index 658d90a..a54f563 100644
--- a/libvirt-sandbox/image/sources/docker.py
+++ b/libvirt-sandbox/image/sources/docker.py
@@ -30,6 +30,7 @@ import subprocess
 import shutil
 import urlparse
 import hashlib
+from abc import ABCMeta, abstractmethod
 
 from . import base
 
@@ -82,8 +83,83 @@ class DockerImage():
                             template.path)
 
 
+class DockerAuth():
+
+    __metaclass__ = ABCMeta
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def prepare_req(self, req):
+        pass
+
+    @abstractmethod
+    def process_res(self, res):
+        pass
+
+    @abstractmethod
+    def process_err(self, err):
+        return False
+
+
+class DockerAuthNop(DockerAuth):
+
+    def prepare_req(self, req):
+        pass
+
+    def process_res(self, res):
+        pass
+
+    def process_err(self, err):
+        return False
+
+
+class DockerAuthBasic(DockerAuth):
+
+    def __init__(self, username, password):
+        self.username = username
+        self.password = password
+        self.token = None
+
+    def prepare_req(self, req):
+        if self.username is not None:
+            auth = base64.encodestring(
+                '%s:%s' % (self.username, self.password)).replace('\n', '')
+
+            req.add_header("Authorization", "Basic %s" % auth)
+
+        req.add_header("X-Docker-Token", "true")
+
+    def process_res(self, res):
+        self.token = res.info().getheader('X-Docker-Token')
+
+    def process_err(self, err):
+        return False
+
+
+class DockerAuthToken(DockerAuth):
+
+    def __init__(self, token):
+        self.token = token
+
+    def prepare_req(self, req):
+        req.add_header("Authorization", "Token %s" % self.token)
+
+    def process_res(self, res):
+        pass
+
+    def process_err(self, err):
+        return False
+
+
 class DockerSource(base.Source):
 
+    def __init__(self):
+        self.auth_handler = DockerAuthNop()
+
+    def set_auth_handler(self, auth_handler):
+        self.auth_handler = auth_handler
+
     def _check_cert_validate(self):
         major = sys.version_info.major
         SSL_WARNING = "SSL certificates couldn't be validated by default. You need to have 2.7.9/3.4.3 or higher"
@@ -113,28 +189,29 @@ class DockerSource(base.Source):
     def download_template(self, image, template, templatedir):
         self._check_cert_validate()
 
+        basicauth = DockerAuthBasic(template.username, template.password)
+        self.set_auth_handler(basicauth)
         try:
             (data, res) = self._get_json(template,
                                          None,
                                          "/v1/repositories/%s/%s/images" % (
                                              image.repo, image.name,
-                                         ),
-                                         {"X-Docker-Token": "true"})
+                                         ))
         except urllib2.HTTPError, e:
             raise ValueError(["Image '%s' does not exist" % template])
 
         registryendpoint = res.info().getheader('X-Docker-Endpoints')
-        token = res.info().getheader('X-Docker-Token')
 
-        headers = {}
-        if token is not None:
-            headers["Authorization"] = "Token " + token
+        if basicauth.token is not None:
+            self.set_auth_handler(DockerAuthToken(basicauth.token))
+        else:
+            self.set_auth_handler(DockerAuthNop())
+
         (data, res) = self._get_json(template,
                                      registryendpoint,
                                      "/v1/repositories/%s/%s/tags" %(
                                          image.repo, image.name
-                                     ),
-                                     headers)
+                                     ))
 
         if image.tag not in data:
             raise ValueError(["Tag '%s' does not exist for image '%s'" %
@@ -143,8 +220,7 @@ class DockerSource(base.Source):
 
         (data, res) = self._get_json(template,
                                      registryendpoint,
-                                     "/v1/images/" + imagetagid + "/ancestry",
-                                     headers)
+                                     "/v1/images/" + imagetagid + "/ancestry")
 
         if data[0] != imagetagid:
             raise ValueError(["Expected first layer id '%s' to match image id '%s'",
@@ -167,14 +243,12 @@ class DockerSource(base.Source):
                     res = self._save_data(template,
                                           registryendpoint,
                                           "/v1/images/" + layerid + "/json",
-                                          headers,
                                           jsonfile)
                     createdFiles.append(jsonfile)
 
                     self._save_data(template,
                                     registryendpoint,
                                     "/v1/images/" + layerid + "/layer",
-                                    headers,
                                     datafile)
                     createdFiles.append(datafile)
 
@@ -201,10 +275,10 @@ class DockerSource(base.Source):
                 except:
                     pass
 
-    def _save_data(self, template, server, path, headers,
+    def _save_data(self, template, server, path,
                    dest, checksum=None):
         try:
-            res = self._get_url(template, server, path, headers)
+            res = self._get_url(template, server, path)
 
             datalen = res.info().getheader("Content-Length")
             if datalen is not None:
@@ -247,7 +321,7 @@ class DockerSource(base.Source):
             debug("FAIL %s\n" % str(e))
             raise
 
-    def _get_url(self, template, server, path, headers):
+    def _get_url(self, template, server, path, headers=None):
         if template.protocol is None:
             protocol = "https"
         else:
@@ -266,25 +340,34 @@ class DockerSource(base.Source):
         debug("Fetching %s..." % url)
 
         req = urllib2.Request(url=url)
-        for h in headers.keys():
-            req.add_header(h, headers[h])
-
-        #www Auth header starts
-        if template.username and template.password:
-            base64string = base64.encodestring(
-                '%s:%s' % (template.username,
-                           template.password)).replace('\n', '')
-            req.add_header("Authorization", "Basic %s" % base64string)
-        #www Auth header finish
+        if headers is not None:
+            for h in headers.keys():
+                req.add_header(h, headers[h])
 
-        return urllib2.urlopen(req)
+        self.auth_handler.prepare_req(req)
 
-    def _get_json(self, template, server, path, headers):
         try:
-            if headers is None:
-                headers = {}
+            res = urllib2.urlopen(req)
+            self.auth_handler.process_res(res)
+            return res
+        except urllib2.HTTPError as e:
+            if e.code == 401:
+                retry = self.auth_handler.process_err(e)
+                if retry:
+                    debug("Re-Fetching %s..." % url)
+                    self.auth_handler.prepare_req(req)
+                    res = urllib2.urlopen(req)
+                    self.auth_handler.process_res(res)
+                    return res
+                else:
+                    debug("Not re-fetching")
+                    raise
             else:
-                headers = copy.copy(headers)
+                raise
+
+    def _get_json(self, template, server, path):
+        try:
+            headers = {}
             headers["Accept"] = "application/json"
             res = self._get_url(template, server, path, headers)
             data = json.loads(res.read())
-- 
2.7.4

--
libvir-list mailing list
libvir-list@redhat.com
https://www.redhat.com/mailman/listinfo/libvir-list

Reply via email to