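This patch creates the "ganeti_pub_keys" file on cluster
initialization and adds the master's public key to it.
On node-add, the public keys of the master and of all master
candidates are looked up in this file, transferred to the
new node and added to its "authorized_keys" file. For this,
the new function ssh.AddAuthorizedKeys adds a list of keys in
one pass; ssh.AddAuthorizedKey becomes a wrapper around it.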

Signed-off-by: Helga Velroyen <hel...@google.com>
---
 lib/bootstrap.py                                   |  4 ++
 lib/client/gnt_node.py                             | 15 ++++++-
 lib/ssh.py                                         | 51 ++++++++++++++++++----
 lib/tools/prepare_node_join.py                     | 40 +++++++++++------
 src/Ganeti/Constants.hs                            |  3 ++
 test/py/ganeti.ssh_unittest.py                     | 20 +++++++++
 test/py/ganeti.tools.prepare_node_join_unittest.py | 14 +++++-
 7 files changed, 122 insertions(+), 25 deletions(-)

diff --git a/lib/bootstrap.py b/lib/bootstrap.py
index 9f7e681..0160f6b 100644
--- a/lib/bootstrap.py
+++ b/lib/bootstrap.py
@@ -836,6 +836,10 @@ def InitCluster(cluster_name, mac_prefix, # pylint: disable=R0913, R0914
   cfg.Update(cfg.GetClusterInfo(), logging.error)
   ssconf.WriteSsconfFiles(cfg.GetSsconfValues())
 
+  if modify_ssh_setup:
+    # Create the public key file, initially containing only the master's key
+    ssh.InitPubKeyFile(cfg.GetMasterNode())
+
   # set up the inter-node password and certificate
   _InitGanetiServerSetup(hostname.name)
 
diff --git a/lib/client/gnt_node.py b/lib/client/gnt_node.py
index d10469f..d77d031 100644
--- a/lib/client/gnt_node.py
+++ b/lib/client/gnt_node.py
@@ -183,7 +183,7 @@ def _ReadSshKeys(keyfiles, _tostderr_fn=ToStderr):
   return result
 
 
-def _SetupSSH(options, cluster_name, node, ssh_port):
+def _SetupSSH(options, cluster_name, node, ssh_port, cl):
   """Configures a destination node's SSH daemon.
 
   @param options: Command line options
@@ -193,8 +193,19 @@ def _SetupSSH(options, cluster_name, node, ssh_port):
   @param node: Destination node name
   @type ssh_port: int
   @param ssh_port: Destination node ssh port
+  @param cl: luxi client
 
   """
+  # Retrieve the list of master and master candidates
+  candidate_filter = ["|", ["=", "role", "M"], ["=", "role", "C"]]
+  result = cl.Query(constants.QR_NODE, ["uuid"], candidate_filter)
+  if not result.data:
+    raise errors.OpPrereqError("No master or master candidate nodes are"
+                               " found.")
+  # One row per node, each holding a (status, value) pair for field "uuid"
+  candidates = [uuid for ((_, uuid), ) in result.data]
+  candidate_keys = ssh.QueryPubKeyFile(candidates)
+
   if options.force_join:
     ToStderr("The \"--force-join\" option is no longer supported and will be"
              " ignored.")
@@ -214,6 +225,7 @@ def _SetupSSH(options, cluster_name, node, ssh_port):
     constants.SSHS_NODE_DAEMON_CERTIFICATE: cert_pem,
     constants.SSHS_SSH_HOST_KEY: host_keys,
     constants.SSHS_SSH_ROOT_KEY: root_keys,
+    constants.SSHS_SSH_AUTHORIZED_KEYS: candidate_keys,
     }
 
   bootstrap.RunNodeSetupCmd(cluster_name, node, pathutils.PREPARE_NODE_JOIN,
@@ -289,7 +301,7 @@ def AddNode(opts, args):
              "and grant full intra-cluster ssh root access to/from it\n", node)
 
   if opts.node_setup:
-    _SetupSSH(opts, cluster_name, node, ssh_port)
+    _SetupSSH(opts, cluster_name, node, ssh_port, cl)
 
   bootstrap.SetupNodeDaemon(opts, cluster_name, node, ssh_port)
 
diff --git a/lib/ssh.py b/lib/ssh.py
index 61c7d2e..34d709d 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -132,16 +132,24 @@ def _SplitSshKey(key):
     return (True, parts)
 
 
-def AddAuthorizedKey(file_obj, key):
-  """Adds an SSH public key to an authorized_keys file.
+def AddAuthorizedKeys(file_obj, keys):
+  """Adds a list of SSH public keys to an authorized_keys file.
+
+  Keys already present in the file (ignoring whitespace changes) are not
+  added again; a key given more than once in C{keys} is written only once.
 
   @type file_obj: str or file handle
   @param file_obj: path to authorized_keys file
-  @type key: str
-  @param key: string containing key
+  @type keys: list of str
+  @param keys: list of strings containing keys
 
   """
-  key_fields = _SplitSshKey(key)
+  key_field_list = []
+  for key in keys:
+    split_key = _SplitSshKey(key)
+    # Skip keys which were requested more than once
+    if split_key not in [known for (_, known) in key_field_list]:
+      key_field_list.append((key, split_key))
 
   if isinstance(file_obj, basestring):
     f = open(file_obj, "a+")
@@ -152,19 +160,38 @@ def AddAuthorizedKey(file_obj, key):
     nl = True
     for line in f:
       # Ignore whitespace changes
-      if _SplitSshKey(line) == key_fields:
-        break
+      line_key = _SplitSshKey(line)
+      # Drop every requested key which is already present, but keep scanning
+      # the file: the remaining keys may still appear further down
+      key_field_list = [(key, split_key)
+                        for (key, split_key) in key_field_list
+                        if split_key != line_key]
       nl = line.endswith("\n")
-    else:
+
+    # Append only the keys which were not found in the file
+    if key_field_list:
       if not nl:
         f.write("\n")
-      f.write(key.rstrip("\r\n"))
-      f.write("\n")
+      for (key, _) in key_field_list:
+        f.write(key.rstrip("\r\n"))
+        f.write("\n")
       f.flush()
   finally:
     f.close()
 
 
+def AddAuthorizedKey(file_obj, key):
+  """Adds an SSH public key to an authorized_keys file.
+
+  @type file_obj: str or file handle
+  @param file_obj: path to authorized_keys file
+  @type key: str
+  @param key: string containing key
+
+  """
+  AddAuthorizedKeys(file_obj, [key])
+
+
 def RemoveAuthorizedKey(file_name, key):
   """Removes an SSH public key from an authorized_keys file.
 
@@ -557,6 +584,21 @@ def InitSSHSetup(error_fn=errors.OpPrereqError):
   AddAuthorizedKey(auth_keys, utils.ReadFile(pub_key))
 
 
+def InitPubKeyFile(master_uuid, key_file=pathutils.SSH_PUB_KEYS):
+  """Creates the public key file and adds the master node's SSH key.
+
+  @type master_uuid: str
+  @param master_uuid: the master node's UUID
+  @type key_file: str
+  @param key_file: name of the file containing the public keys
+
+  """
+  _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER)
+  utils.WriteFile(key_file, data="", mode=0600)
+  key = utils.ReadFile(pub_key)
+  AddPublicKey(master_uuid, key, key_file=key_file)
+
+
 class SshRunner:
   """Wrapper for SSH commands.
 
diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
index 789c872..7d96c6a 100644
--- a/lib/tools/prepare_node_join.py
+++ b/lib/tools/prepare_node_join.py
@@ -55,6 +55,8 @@ _DATA_CHECK = ht.TStrictDict(False, True, {
   constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
   constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
   constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
+  constants.SSHS_SSH_AUTHORIZED_KEYS:
+    ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
   })
 
 
@@ -213,20 +215,30 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
 
   """
   keys = data.get(constants.SSHS_SSH_ROOT_KEY)
-  if not keys:
+  authorized_keys = data.get(constants.SSHS_SSH_AUTHORIZED_KEYS)
+  if not (keys or authorized_keys):
     return
 
   (auth_keys_file, keyfiles) = \
     ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
                         _homedir_fn=_homedir_fn)
 
-  _UpdateKeyFiles(keys, dry_run, keyfiles)
+  if keys:
+    _UpdateKeyFiles(keys, dry_run, keyfiles)
 
   if dry_run:
     logging.info("This is a dry run, not modifying %s", auth_keys_file)
   else:
-    for (_, _, public_key) in keys:
+    for (_, _, public_key) in keys or []:
       ssh.AddAuthorizedKey(auth_keys_file, public_key)
+
+    if authorized_keys:
+      # Iterate in sorted order so that the resulting file content does not
+      # depend on the dictionary's iteration order
+      all_authorized_keys = []
+      for (_, node_keys) in sorted(authorized_keys.items()):
+        all_authorized_keys.extend(node_keys)
+      ssh.AddAuthorizedKeys(auth_keys_file, all_authorized_keys)
 
 
 def LoadData(raw):
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index 2da4858..e29526b 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -4450,6 +4450,10 @@ sshsSshHostKey = "ssh_host_key"
 sshsSshRootKey :: String
 sshsSshRootKey = "ssh_root_key"
 
+-- | Public SSH keys to be added to the new node's authorized_keys file
+sshsSshAuthorizedKeys :: String
+sshsSshAuthorizedKeys = "ssh_authorized_keys"
+
 sshsNodeDaemonCertificate :: String
 sshsNodeDaemonCertificate = "node_daemon_certificate"
 
diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py
index cd8da19..09d3169 100755
--- a/test/py/ganeti.ssh_unittest.py
+++ b/test/py/ganeti.ssh_unittest.py
@@ -215,6 +215,27 @@ class TestSshKeys(testutils.GanetiTestCase):
       'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
 
+  def testAddingNewKeys(self):
+    """Checks that AddAuthorizedKeys appends keys missing from the file."""
+    ssh.AddAuthorizedKeys(self.tmpname,
+                          ["ssh-dss AAAAB3NzaC1kc3MAAACB root@test"])
+    self.assertFileContent(self.tmpname,
+      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
+      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
+      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
+      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
+
+    ssh.AddAuthorizedKeys(self.tmpname,
+                          ["ssh-dss AAAAB3asdfasdfaYTUCB laracroft@test",
+                           "ssh-dss AasdfliuobaosfMAAACB frodo@test"])
+    self.assertFileContent(self.tmpname,
+      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
+      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
+      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
+      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n"
+      "ssh-dss AAAAB3asdfasdfaYTUCB laracroft@test\n"
+      "ssh-dss AasdfliuobaosfMAAACB frodo@test\n")
+
 
 class TestPublicSshKeys(testutils.GanetiTestCase):
   """Test case for the handling of the list of public ssh keys."""
diff --git a/test/py/ganeti.tools.prepare_node_join_unittest.py b/test/py/ganeti.tools.prepare_node_join_unittest.py
index 54d1d62..fe4ff26 100755
--- a/test/py/ganeti.tools.prepare_node_join_unittest.py
+++ b/test/py/ganeti.tools.prepare_node_join_unittest.py
@@ -55,10 +55,23 @@ class TestLoadData(unittest.TestCase):
     raw = serializer.DumpJson([])
     self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
 
-  def testValidData(self):
+  def testEmptyDict(self):
     raw = serializer.DumpJson({})
     self.assertEqual(prepare_node_join.LoadData(raw), {})
 
+  def testValidData(self):
+    """Checks that data with root, host and authorized keys loads unchanged."""
+    key_list = [[constants.SSHK_DSA, "private foo", "public bar"]]
+    data_dict = {
+      constants.SSHS_CLUSTER_NAME: "Skynet",
+      constants.SSHS_SSH_HOST_KEY: key_list,
+      constants.SSHS_SSH_ROOT_KEY: key_list,
+      constants.SSHS_SSH_AUTHORIZED_KEYS:
+        {"nodeuuid01234": ["foo"],
+         "nodeuuid56789": ["bar"]}}
+    raw = serializer.DumpJson(data_dict)
+    self.assertEqual(prepare_node_join.LoadData(raw), data_dict)
+
 
 class TestVerifyCertificate(testutils.GanetiTestCase):
   def setUp(self):
-- 
2.0.0.526.g5318336

Reply via email to