The ssh-keygen utility accepts only certain combinations of key type and bit size. Because an invalid combination would otherwise be detected only late in the renewal process, where many more things can go wrong, this patch introduces prerequisite checks that mimic those of ssh-keygen.
Signed-off-by: Hrvoje Ribicic <[email protected]>
---
 lib/client/gnt_cluster.py      |  6 ++----
 lib/cmdlib/cluster/__init__.py | 33 +++++++++++++++++++++------------
 lib/ssh.py                     | 42 +++++++++++++++++++++++++++++++++++++++++-
 test/py/ganeti.ssh_unittest.py | 32 ++++++++++++++++++++++++++++++++
 4 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/lib/client/gnt_cluster.py b/lib/client/gnt_cluster.py
index 59f1eb9..c47c916 100644
--- a/lib/client/gnt_cluster.py
+++ b/lib/client/gnt_cluster.py
@@ -304,10 +304,8 @@ def InitCluster(opts, args):
   else:
     ssh_key_type = constants.SSH_DEFAULT_KEY_TYPE
 
-  if opts.ssh_key_bits:
-    ssh_key_bits = opts.ssh_key_bits
-  else:
-    ssh_key_bits = constants.SSH_DEFAULT_KEY_BITS
+  ssh_key_bits = ssh.DetermineKeyBits(ssh_key_type, opts.ssh_key_bits, None,
+                                      None)
 
   bootstrap.InitCluster(cluster_name=args[0],
                         secondary_ip=opts.secondary_ip,
diff --git a/lib/cmdlib/cluster/__init__.py b/lib/cmdlib/cluster/__init__.py
index 5658646..3147f96 100644
--- a/lib/cmdlib/cluster/__init__.py
+++ b/lib/cmdlib/cluster/__init__.py
@@ -87,6 +87,23 @@ class LUClusterRenewCrypto(NoHooksLU):
     self.share_locks = ShareAll()
     self.share_locks[locking.LEVEL_NODE] = 0
 
+  def CheckPrereq(self):
+    """Check prerequisites.
+
+    Notably the compatibility of specified key bits and key type.
+
+    """
+    cluster_info = self.cfg.GetClusterInfo()
+
+    self.ssh_key_type = self.op.ssh_key_type
+    if self.ssh_key_type is None:
+      self.ssh_key_type = cluster_info.ssh_key_type
+
+    self.ssh_key_bits = ssh.DetermineKeyBits(self.ssh_key_type,
+                                             self.op.ssh_key_bits,
+                                             cluster_info.ssh_key_type,
+                                             cluster_info.ssh_key_bits)
+
   def _RenewNodeSslCertificates(self, feedback_fn):
     """Renews the nodes' SSL certificates.
 
@@ -167,28 +184,20 @@ class LUClusterRenewCrypto(NoHooksLU):
 
     cluster_info = self.cfg.GetClusterInfo()
 
-    new_ssh_key_type = self.op.ssh_key_type
-    if new_ssh_key_type is None:
-      new_ssh_key_type = cluster_info.ssh_key_type
-
-    new_ssh_key_bits = self.op.ssh_key_bits
-    if new_ssh_key_bits is None:
-      new_ssh_key_bits = cluster_info.ssh_key_bits
-
     result = self.rpc.call_node_ssh_keys_renew(
       [master_uuid],
       node_uuids, node_names,
       master_candidate_uuids,
       potential_master_candidates,
       cluster_info.ssh_key_type, # Old key type
-      new_ssh_key_type,          # New key type
-      new_ssh_key_bits)          # New key bits
+      self.ssh_key_type,         # New key type
+      self.ssh_key_bits)         # New key bits
     result[master_uuid].Raise("Could not renew the SSH keys of all nodes")
 
     # After the keys have been successfully swapped, time to commit the change
     # in key type
-    cluster_info.ssh_key_type = new_ssh_key_type
-    cluster_info.ssh_key_bits = new_ssh_key_bits
+    cluster_info.ssh_key_type = self.ssh_key_type
+    cluster_info.ssh_key_bits = self.ssh_key_bits
     self.cfg.Update(cluster_info, feedback_fn)
 
   def Exec(self, feedback_fn):
diff --git a/lib/ssh.py b/lib/ssh.py
index d2684fc..7b27214 100644
--- a/lib/ssh.py
+++ b/lib/ssh.py
@@ -37,6 +37,7 @@ import logging
 import os
 import tempfile
 
+from collections import namedtuple
 from functools import partial
 
 from ganeti import utils
@@ -1094,5 +1095,44 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
   if result.failed:
     raise errors.OpPrereqError("Could not fetch a public SSH key (%s) from node"
                                " '%s': ran command '%s', failure reason: '%s'."
-                               % (pub_key_file, node, cmd, result.fail_reason))
+                               % (pub_key_file, node, cmd, result.fail_reason),
+                               errors.ECODE_INVAL)
   return result.stdout
+
+
+KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
+SSH_KEY_VALID_BITS = {
+  constants.SSHK_DSA: KeyBitInfo(1024, lambda b: b == 1024),
+  constants.SSHK_RSA: KeyBitInfo(2048, lambda b: b >= 768),
+  constants.SSHK_ECDSA: KeyBitInfo(384, lambda b: b in [256, 384, 521]),
+}
+
+
+def DetermineKeyBits(key_type, key_bits, old_key_type, old_key_bits):
+  """Checks the key bits to be used for a given key type, or provides defaults.
+
+  @type key_type: one of L{constants.SSHK_ALL}
+  @param key_type: The key type to use.
+  @type key_bits: positive int or None
+  @param key_bits: The number of bits to use, if supplied by user.
+  @type old_key_type: one of L{constants.SSHK_ALL} or None
+  @param old_key_type: The previously used key type, if any.
+  @type old_key_bits: positive int or None
+  @param old_key_bits: The previously used number of bits, if any.
+
+  @rtype: positive int
+  @return: The number of bits to use.
+
+  """
+  if key_bits is None:
+    if old_key_type is not None and old_key_type == key_type:
+      key_bits = old_key_bits
+    else:
+      key_bits = SSH_KEY_VALID_BITS[key_type].default
+
+  if not SSH_KEY_VALID_BITS[key_type].validation_fn(key_bits):
+    raise errors.OpPrereqError("Invalid key type and bit size combination:"
+                               " %s with %s bits" % (key_type, key_bits),
+                               errors.ECODE_INVAL)
+
+  return key_bits
diff --git a/test/py/ganeti.ssh_unittest.py b/test/py/ganeti.ssh_unittest.py
index b13dda1..265adec 100755
--- a/test/py/ganeti.ssh_unittest.py
+++ b/test/py/ganeti.ssh_unittest.py
@@ -488,5 +488,37 @@ class TestGetUserFiles(testutils.GanetiTestCase):
       self.assertTrue(os.path.exists(self.priv_filename + suffix + ".pub"))
 
 
+class TestDetermineKeyBits(testutils.GanetiTestCase):
+  def testCompleteness(self):
+    self.assertEquals(set(constants.SSHK_ALL), set(ssh.SSH_KEY_VALID_BITS))
+
+  def testAdoptDefault(self):
+    self.assertEquals(2048, ssh.DetermineKeyBits("rsa", None, None, None))
+    self.assertEquals(1024, ssh.DetermineKeyBits("dsa", None, None, None))
+
+  def testAdoptOldKeySize(self):
+    self.assertEquals(4098, ssh.DetermineKeyBits("rsa", None, "rsa", 4098))
+    self.assertEquals(2048, ssh.DetermineKeyBits("rsa", None, "dsa", 1024))
+
+  def testDsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "dsa", 2048,
+                      None, None)
+    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "dsa", 512,
+                      None, None)
+    self.assertEquals(1024, ssh.DetermineKeyBits("dsa", None, None, None))
+
+  def testEcdsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "ecdsa",
+                      2048, None, None)
+    for b in [256, 384, 521]:
+      self.assertEquals(b, ssh.DetermineKeyBits("ecdsa", b, None, None))
+
+  def testRsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, ssh.DetermineKeyBits, "rsa", 767,
+                      None, None)
+    for b in [768, 769, 2048, 2049, 4096]:
+      self.assertEquals(b, ssh.DetermineKeyBits("rsa", b, None, None))
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
2.6.0.rc2.230.g3dd15c0
