This patch adds a couple of new SSH utility functions to
the ssh module:
- clearing the whole 'ganeti_pub_keys' file
- overriding the whole 'ganeti_pub_keys' file
- retrieving all keys from the file at once

These functions will be used in later patches. Unit tests
are included.

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/ssh.py                     | 37 +++++++++++++++++++++++++++++++++++--
 src/Ganeti/Constants.hs        |  3 +++
 test/py/ganeti.ssh_unittest.py | 22 ++++++++++++++++++++++
 3 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/lib/ssh.py b/lib/ssh.py
index 7cdf0b7..82a394d 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -27,6 +27,7 @@
 import logging
 import os
 import tempfile
+import stat
 
 from functools import partial
 
@@ -37,6 +38,7 @@ from ganeti import netutils
 from ganeti import pathutils
 from ganeti import vcluster
 from ganeti import compat
+from ganeti import serializer
 
 
 def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
@@ -528,6 +530,36 @@ def ReplaceNameByUuid(node_uuid, node_name, 
key_file=pathutils.SSH_PUB_KEYS,
                         error_fn=error_fn)
 
 
+def ClearPubKeyFile(key_file=pathutils.SSH_PUB_KEYS, mode=0600):
+  """Resets the public key file to empty content, written with C{mode}.
+
+  """
+  utils.WriteFile(key_file, data="", mode=mode)
+
+
+def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS,
+                       error_fn=errors.ProgrammerError):
+  """Overrides the public key file with a list of given keys.
+
+  Writes through L{utils.WriteFile} (as L{ClearPubKeyFile} does) instead of
+  hand-rolled temp files; keys are emitted sorted by UUID for stable output.
+
+  @type key_map: dict from str to list of str
+  @param key_map: dictionary mapping uuids to lists of SSH keys
+  @type key_file: str
+  @param key_file: path of the public key file to override
+  @raise error_fn: if the key file cannot be written
+
+  """
+  data = "".join("%s %s\n" % (uuid, key)
+                 for uuid in sorted(key_map)
+                 for key in key_map[uuid])
+  try:
+    utils.WriteFile(key_file, data=data, mode=0600)
+  except EnvironmentError, err:
+    raise error_fn("Cannot override key file due to error '%s'" % err)
+
+
 def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
                     error_fn=errors.ProgrammerError):
   """Retrieves a map of keys for the requested node UUIDs.
@@ -545,6 +577,7 @@ def QueryPubKeyFile(target_uuids, 
key_file=pathutils.SSH_PUB_KEYS,
   @return: dictionary mapping node uuids to their ssh keys
 
   """
+  all_keys = target_uuids is None
   if isinstance(target_uuids, str):
     target_uuids = [target_uuids]
   result = {}
@@ -554,7 +587,7 @@ def QueryPubKeyFile(target_uuids, 
key_file=pathutils.SSH_PUB_KEYS,
       (uuid, key) = _ParseKeyLine(line, error_fn)
       if not uuid:
         continue
-      if uuid in target_uuids:
+      if all_keys or (uuid in target_uuids):
         if uuid not in result:
           result[uuid] = []
         result[uuid].append(key)
@@ -597,7 +630,7 @@ def InitPubKeyFile(master_uuid, 
key_file=pathutils.SSH_PUB_KEYS):
 
   """
   _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER)
-  utils.WriteFile(key_file, data="", mode=0600)
+  ClearPubKeyFile(key_file=key_file)
   key = utils.ReadFile(pub_key)
   AddPublicKey(master_uuid, key, key_file=key_file)
 
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index 4d3c4b7..d018b20 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -4475,6 +4475,9 @@ sshsSshRootKey = "ssh_root_key"
 sshsSshAuthorizedKeys :: String
 sshsSshAuthorizedKeys = "authorized_keys"
 
+sshsSshPublicKeys :: String
+sshsSshPublicKeys = "public_keys"
+
 sshsNodeDaemonCertificate :: String
 sshsNodeDaemonCertificate = "node_daemon_certificate"
 
diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py
index 84e96cf..9abc629 100755
--- a/test/py/ganeti.ssh_unittest.py
+++ b/test/py/ganeti.ssh_unittest.py
@@ -343,6 +343,12 @@ class TestPublicSshKeys(testutils.GanetiTestCase):
     self.assertEquals([self.KEY_B], result[self.UUID_2])
     self.assertEquals(2, len(result))
 
+    # Query all keys
+    target_uuids = None
+    result = ssh.QueryPubKeyFile(target_uuids, key_file=pub_key_file)
+    self.assertEquals([self.KEY_A], result[self.UUID_1])
+    self.assertEquals([self.KEY_B], result[self.UUID_2])
+
   def testReplaceNameByUuid(self):
     pub_key_file = self._CreateTempFile()
     name = "my.precious.node"
@@ -379,6 +385,22 @@ class TestPublicSshKeys(testutils.GanetiTestCase):
     result = ssh.QueryPubKeyFile(self.UUID_1, key_file=pub_key_file)
     self.assertEquals([self.KEY_A], result[self.UUID_1])
 
+  def testClearPubKeyFile(self):
+    pub_key_file = self._CreateTempFile()
+    ssh.AddPublicKey(self.UUID_2, self.KEY_A, key_file=pub_key_file)
+    ssh.ClearPubKeyFile(key_file=pub_key_file)
+    self.assertFileContent(pub_key_file, "")  # previously added key is gone
+
+  def testOverridePubKeyFile(self):
+    pub_key_file = self._CreateTempFile()
+    key_map = {self.UUID_1: [self.KEY_A, self.KEY_B],
+               self.UUID_2: [self.KEY_A]}
+    ssh.OverridePubKeyFile(key_map, key_file=pub_key_file)
+    self.assertFileContent(pub_key_file,  # one "<uuid> <key>" line per key
+      "123-456 ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
+      "123-456 ssh-dss BAasjkakfa234SFSFDA345462AAAB root@key-b\n"
+      "789-ABC ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n")
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
2.1.0.rc2.206.gedb03e5

Reply via email to