This patch implements the handling of SSH keys when a new node is added. It introduces the new RPC call 'ssh_add_key', which is called on the master's noded when a new node is added. In the backend implementation, noded takes care of distributing the new node's SSH key information to all other nodes in the cluster that are supposed to have this information.
Note: It was rather tedious to test the backend function, because it was calling many other functions which would have needed to be mocked. Instead I added the public key file as a parameter, because this way I could at least reduce the complexity of the test setup and at the same time have direct access to the file that gets manipulated. Also Note: Up till now, there is still only the common cluster SSH key around. I wanted to have some infrastructure in place, before actually individual keys are generated. Signed-off-by: Helga Velroyen <hel...@google.com> --- lib/backend.py | 106 ++++++++++++++++++++++ lib/bootstrap.py | 102 ++-------------------- lib/client/gnt_node.py | 13 ++- lib/cmdlib/node.py | 25 ++++++ lib/errors.py | 6 ++ lib/rpc_defs.py | 12 +++ lib/server/noded.py | 12 +++ lib/ssh.py | 131 ++++++++++++++++++++++++++++ lib/tools/ssh_update.py | 43 +++++++-- test/py/ganeti.backend_unittest.py | 174 +++++++++++++++++++++++++++++++++++++ 10 files changed, 514 insertions(+), 110 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 69548e0..92838f3 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -61,6 +61,7 @@ import stat import tempfile import time import zlib +import copy from ganeti import errors from ganeti import http @@ -1294,6 +1295,111 @@ def EnsureDaemon(daemon_name, run): return fn(daemon_name) +def _InitSshUpdateData(data, noded_cert_file, ssconf_store): + (_, noded_cert) = \ + utils.ExtractX509Certificate(utils.ReadFile(noded_cert_file)) + data[constants.SSHS_NODE_DAEMON_CERTIFICATE] = noded_cert + + cluster_name = ssconf_store.GetClusterName() + data[constants.SSHS_CLUSTER_NAME] = cluster_name + + +def AddNodeSshKey(node_uuid, node_name, + to_authorized_keys, to_public_keys, + get_pub_keys, ssh_port_map, + potential_master_candidates, + pub_key_file=pathutils.SSH_PUB_KEYS, + ssconf_store=None, + noded_cert_file=pathutils.NODED_CERT_FILE, + run_cmd_fn=ssh.RunSshCmdWithStdin): + """Distributes a node's public SSH key across the 
cluster. + + Note that this function should only be executed on the master node, which + then will copy the new node's key to all nodes in the cluster via SSH. + + @type node_uuid: str + @param node_uuid: the UUID of the node whose key is added + @type node_name: str + @param node_name: the name of the node whose key is added + @type to_authorized_keys: boolean + @param to_authorized_keys: whether the key should be added to the + C{authorized_keys} file of all nodes + @type to_public_keys: boolean + @param to_public_keys: whether the keys should be added to the public key file + @type get_pub_keys: boolean + @param get_pub_keys: whether the node should add the clusters' public keys + to its {ganeti_pub_keys} file + @type ssh_port_map: dict from str to int + @param ssh_port_map: a mapping from node names to SSH port numbers + @type potential_master_candidates: list of str + @param potential_master_candidates: list of node names of potential master + candidates; this should match the list of uuids in the public key file + + """ + if not ssconf_store: + ssconf_store = ssconf.SimpleStore() + + # Check and fix sanity of key file + keys_by_name = ssh.QueryPubKeyFile([node_name], key_file=pub_key_file) + keys_by_uuid = {} + if not keys_by_name or node_name not in keys_by_name: + raise errors.SshUpdateError("No keys found for the new node '%s' in the" + " list of public SSH keys." 
% node_name) + else: + # Replace the name by UUID in the file as the name should only be used + # temporarily + ssh.ReplaceNameByUuid(node_uuid, node_name, error_fn=errors.SshUpdateError, + key_file=pub_key_file) + keys_by_uuid[node_uuid] = keys_by_name[node_name] + + # Update the master node's key files + if to_authorized_keys: + (auth_key_file, _) = \ + ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False) + ssh.AddAuthorizedKeys(auth_key_file, keys_by_uuid[node_uuid]) + + base_data = {} + _InitSshUpdateData(base_data, noded_cert_file, ssconf_store) + cluster_name = base_data[constants.SSHS_CLUSTER_NAME] + + # Update all nodes except master and the target node + if to_authorized_keys: + base_data[constants.SSHS_SSH_AUTHORIZED_KEYS] = \ + (constants.SSHS_ADD, keys_by_uuid) + + pot_mc_data = copy.deepcopy(base_data) + if to_public_keys: + pot_mc_data[constants.SSHS_SSH_PUBLIC_KEYS] = \ + (constants.SSHS_ADD, keys_by_uuid) + + all_nodes = ssconf_store.GetNodeList() + master_node = ssconf_store.GetMasterNode() + + for node in all_nodes: + if node in [master_node, node_name]: + continue + if node in potential_master_candidates: + run_cmd_fn(cluster_name, node, pathutils.SSH_UPDATE, + True, True, False, False, False, + ssh_port_map.get(node), pot_mc_data, ssconf_store) + else: + if to_authorized_keys: + run_cmd_fn(cluster_name, node, pathutils.SSH_UPDATE, + True, True, False, False, False, + ssh_port_map.get(node), base_data, ssconf_store) + + # Update the target node itself + if get_pub_keys: + node_data = {} + _InitSshUpdateData(node_data, noded_cert_file, ssconf_store) + all_keys = ssh.QueryPubKeyFile(None, key_file=pub_key_file) + node_data[constants.SSHS_SSH_PUBLIC_KEYS] = \ + (constants.SSHS_OVERRIDE, all_keys) + run_cmd_fn(cluster_name, node_name, pathutils.SSH_UPDATE, + True, True, False, False, False, + ssh_port_map.get(node_name), node_data, ssconf_store) + + def GetBlockDevSizes(devices): """Return the size of the given block devices 
diff --git a/lib/bootstrap.py b/lib/bootstrap.py index dba7dd4..e474b1d 100644 --- a/lib/bootstrap.py +++ b/lib/bootstrap.py @@ -36,9 +36,7 @@ import os import os.path import re import logging -import tempfile import time -import tempfile from ganeti.cmdlib import cluster import ganeti.rpc.node as rpc @@ -260,95 +258,6 @@ def _WaitForSshDaemon(hostname, port, family): (hostname, port, hostip, _DAEMON_READY_TIMEOUT)) -def RunNodeSetupCmd(cluster_name, node, basecmd, debug, verbose, - use_cluster_key, ask_key, strict_host_check, - port, data): - """Runs a command to configure something on a remote machine. - - @type cluster_name: string - @param cluster_name: Cluster name - @type node: string - @param node: Node name - @type basecmd: string - @param basecmd: Base command (path on the remote machine) - @type debug: bool - @param debug: Enable debug output - @type verbose: bool - @param verbose: Enable verbose output - @type use_cluster_key: bool - @param use_cluster_key: See L{ssh.SshRunner.BuildCmd} - @type ask_key: bool - @param ask_key: See L{ssh.SshRunner.BuildCmd} - @type strict_host_check: bool - @param strict_host_check: See L{ssh.SshRunner.BuildCmd} - @type port: int - @param port: The SSH port of the remote machine or None for the default - @param data: JSON-serializable input data for script (passed to stdin) - - """ - cmd = [basecmd] - - # Pass --debug/--verbose to the external script if set on our invocation - if debug: - cmd.append("--debug") - - if verbose: - cmd.append("--verbose") - - logging.debug("Node setup command: %s", cmd) - - version = constants.DIR_VERSION - all_cmds = [["test", "-d", os.path.join(pathutils.PKGLIBDIR, version)]] - if constants.HAS_GNU_LN: - all_cmds.extend([["ln", "-s", "-f", "-T", - os.path.join(pathutils.PKGLIBDIR, version), - os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], - ["ln", "-s", "-f", "-T", - os.path.join(pathutils.SHAREDIR, version), - os.path.join(pathutils.SYSCONFDIR, "ganeti/share")]]) - else: - 
all_cmds.extend([["rm", "-f", - os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], - ["ln", "-s", "-f", - os.path.join(pathutils.PKGLIBDIR, version), - os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], - ["rm", "-f", - os.path.join(pathutils.SYSCONFDIR, "ganeti/share")], - ["ln", "-s", "-f", - os.path.join(pathutils.SHAREDIR, version), - os.path.join(pathutils.SYSCONFDIR, "ganeti/share")]]) - all_cmds.append(cmd) - - if port is None: - port = netutils.GetDaemonPort(constants.SSH) - - family = ssconf.SimpleStore().GetPrimaryIPFamily() - srun = ssh.SshRunner(cluster_name, - ipv6=(family == netutils.IP6Address.family)) - scmd = srun.BuildCmd(node, constants.SSH_LOGIN_USER, - utils.ShellQuoteArgs( - utils.ShellCombineCommands(all_cmds)), - batch=False, ask_key=ask_key, quiet=False, - strict_host_check=strict_host_check, - use_cluster_key=use_cluster_key, - port=port) - - tempfh = tempfile.TemporaryFile() - try: - tempfh.write(serializer.DumpJson(data)) - tempfh.seek(0) - - result = utils.RunCmd(scmd, interactive=True, input_fd=tempfh) - finally: - tempfh.close() - - if result.failed: - raise errors.OpExecError("Command '%s' failed: %s" % - (result.cmd, result.fail_reason)) - - _WaitForSshDaemon(node, port, family) - - def _InitFileStorageDir(file_storage_dir): """Initialize if needed the file storage. 
@@ -986,11 +895,14 @@ def SetupNodeDaemon(opts, cluster_name, node, ssh_port): constants.NDS_START_NODE_DAEMON: True, } - RunNodeSetupCmd(cluster_name, node, pathutils.NODE_DAEMON_SETUP, - opts.debug, opts.verbose, - True, opts.ssh_key_check, opts.ssh_key_check, - ssh_port, data) + ssconf_store = ssconf.SimpleStore() + family = ssconf_store.GetPrimaryIPFamily() + ssh.RunSshCmdWithStdin(cluster_name, node, pathutils.NODE_DAEMON_SETUP, + opts.debug, opts.verbose, + True, opts.ssh_key_check, opts.ssh_key_check, + ssh_port, data, ssconf_store, ensure_version=True) + _WaitForSshDaemon(node, ssh_port, family) _WaitForNodeDaemon(node) diff --git a/lib/client/gnt_node.py b/lib/client/gnt_node.py index c633bf2..fb70a78 100644 --- a/lib/client/gnt_node.py +++ b/lib/client/gnt_node.py @@ -251,9 +251,8 @@ def _SetupSSH(options, cluster_name, node, ssh_port, cl): candidate_filter = ["|", ["=", "role", "M"], ["=", "role", "C"]] result = cl.Query(constants.QR_NODE, ["uuid"], candidate_filter) if len(result.data) < 1: - raise errors.OpPrereqError("No master or master candidate nodes are" - " found.") - candidates = [uuid for (_, uuid) in result.data[0]] + raise errors.OpPrereqError("No master or master candidate node is found.") + candidates = [uuid for ((_, uuid),) in result.data] candidate_keys = ssh.QueryPubKeyFile(candidates) if options.force_join: @@ -278,10 +277,10 @@ def _SetupSSH(options, cluster_name, node, ssh_port, cl): constants.SSHS_SSH_AUTHORIZED_KEYS: candidate_keys, } - bootstrap.RunNodeSetupCmd(cluster_name, node, pathutils.PREPARE_NODE_JOIN, - options.debug, options.verbose, False, - options.ssh_key_check, options.ssh_key_check, - ssh_port, data) + ssh.RunSshCmdWithStdin(cluster_name, node, pathutils.PREPARE_NODE_JOIN, + options.debug, options.verbose, False, + options.ssh_key_check, options.ssh_key_check, + ssh_port, data, ssconf.SimpleStore()) fetched_keys = _ReadRemoteSshPubKeys(root_keyfiles, node, cluster_name, ssh_port, options.ssh_key_check, diff --git 
a/lib/cmdlib/node.py b/lib/cmdlib/node.py index 275785b..3cf2fd3 100644 --- a/lib/cmdlib/node.py +++ b/lib/cmdlib/node.py @@ -54,6 +54,7 @@ from ganeti.cmdlib.common import CheckParamsNotGlobal, \ FindFaultyInstanceDisks, CheckStorageTypeEnabled, CreateNewClientCert, \ AddNodeCertToCandidateCerts, RemoveNodeCertFromCandidateCerts, \ EnsureKvmdOnNodes +from ganeti.ssh import GetSshPortMap def _DecideSelfPromotion(lu, exceptions=None): @@ -337,6 +338,22 @@ class LUNodeAdd(LogicalUnit): self.new_node.name, ovs_name, ovs_link) result.Raise("Failed to initialize OpenVSwitch on new node") + def _SshUpdate(self, new_node_uuid, new_node_name, is_master_candidate, + is_potential_master_candidate, rpcrunner): + """Update the SSH setup of all nodes after adding a new node. + + """ + potential_master_candidates = self.cfg.GetPotentialMasterCandidates() + master_node = self.cfg.GetMasterNode() + port_map = GetSshPortMap(potential_master_candidates, self.cfg) + + result = rpcrunner.call_node_ssh_key_add( + [master_node], new_node_uuid, new_node_name, + is_master_candidate, is_potential_master_candidate, + is_potential_master_candidate, port_map, + potential_master_candidates) + result[master_node].Raise("Could not update the node's SSH setup.") + def Exec(self, feedback_fn): """Adds the new node to the cluster. @@ -439,6 +456,14 @@ class LUNodeAdd(LogicalUnit): EnsureKvmdOnNodes(self, feedback_fn, nodes=[self.new_node.uuid]) + # Update SSH setup of all nodes + modify_ssh_setup = self.cfg.GetClusterInfo().modify_ssh_setup + if modify_ssh_setup: + # FIXME: so far, all nodes are considered potential master candidates + self._SshUpdate(self.new_node.uuid, self.new_node.name, + self.new_node.master_candidate, True, + self.rpc) + class LUNodeSetParams(LogicalUnit): """Modifies the parameters of a node. 
diff --git a/lib/errors.py b/lib/errors.py index 03bb862..b53ced7 100644 --- a/lib/errors.py +++ b/lib/errors.py @@ -454,6 +454,12 @@ class FileStoragePathError(GenericError): """ +class SshUpdateError(GenericError): + """Error from updating the SSH setup. + + """ + + # errors should be added above diff --git a/lib/rpc_defs.py b/lib/rpc_defs.py index dd9503a..8baf5dd 100644 --- a/lib/rpc_defs.py +++ b/lib/rpc_defs.py @@ -533,6 +533,18 @@ _NODE_CALLS = [ ("daemon", None, "Daemon name"), ("run", None, "Whether the daemon should be running or stopped"), ], None, None, "Ensure daemon is running on the node."), + ("node_ssh_key_add", MULTI, None, constants.RPC_TMO_URGENT, [ + ("node_uuid", None, "UUID of the node whose key is distributed"), + ("node_name", None, "Name of the node whose key is distributed"), + ("to_authorized_keys", None, "Whether the node's key should be added" + " to all nodes' 'authorized_keys' file"), + ("to_public_keys", None, "Whether the node's key should be added" + " to all nodes' public key file"), + ("get_public_keys", None, "Whether the node should get the other nodes'" + " public keys"), + ("ssh_port_map", None, "Map of nodes' SSH ports to be used for transfers"), + ("potential_master_candidates", None, "Potential master candidates")], + None, None, "Distribute a new node's public SSH key on the cluster."), ] _MISC_CALLS = [ diff --git a/lib/server/noded.py b/lib/server/noded.py index 6836361..2a89b2e 100644 --- a/lib/server/noded.py +++ b/lib/server/noded.py @@ -918,6 +918,18 @@ class NodeRequestHandler(http.server.HttpServerHandler): (daemon_name, run) = params return backend.EnsureDaemon(daemon_name, run) + @staticmethod + def perspective_node_ssh_key_add(params): + """Distributes a new node's SSH key if authorized. 
+ + """ + (node_uuid, node_name, to_authorized_keys, + to_public_keys, get_public_keys, ssh_port_map, + potential_master_candidates) = params + return backend.AddNodeSshKey(node_uuid, node_name, to_authorized_keys, + to_public_keys, get_public_keys, + ssh_port_map, potential_master_candidates) + # cluster -------------------------- @staticmethod diff --git a/lib/ssh.py b/lib/ssh.py index 435e06f..663ee86 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -880,3 +880,134 @@ def WriteKnownHostsFile(cfg, file_name): data += "%s ssh-dss %s\n" % (cfg.GetClusterName(), cfg.GetDsaHostKey()) utils.WriteFile(file_name, mode=0600, data=data) + + +def _EnsureCorrectGanetiVersion(cmd): + """Ensured the correct Ganeti version before running a command via SSH. + + Before a command is run on a node via SSH, it makes sense in some + situations to ensure that this node is indeed running the correct + version of Ganeti like the rest of the cluster. + + @type cmd: string + @param cmd: string + @rtype: list of strings + @return: a list of commands with the newly added ones at the beginning + + """ + logging.debug("Ensure correct Ganeti version: %s", cmd) + + version = constants.DIR_VERSION + all_cmds = [["test", "-d", os.path.join(pathutils.PKGLIBDIR, version)]] + if constants.HAS_GNU_LN: + all_cmds.extend([["ln", "-s", "-f", "-T", + os.path.join(pathutils.PKGLIBDIR, version), + os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], + ["ln", "-s", "-f", "-T", + os.path.join(pathutils.SHAREDIR, version), + os.path.join(pathutils.SYSCONFDIR, "ganeti/share")]]) + else: + all_cmds.extend([["rm", "-f", + os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], + ["ln", "-s", "-f", + os.path.join(pathutils.PKGLIBDIR, version), + os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")], + ["rm", "-f", + os.path.join(pathutils.SYSCONFDIR, "ganeti/share")], + ["ln", "-s", "-f", + os.path.join(pathutils.SHAREDIR, version), + os.path.join(pathutils.SYSCONFDIR, "ganeti/share")]]) + all_cmds.append(cmd) + return all_cmds 
+ + +def RunSshCmdWithStdin(cluster_name, node, basecmd, debug, verbose, + use_cluster_key, ask_key, strict_host_check, + port, data, ssconf_store, ensure_version=False): + """Runs a command on a remote machine via SSH and provides input in stdin. + + @type cluster_name: string + @param cluster_name: Cluster name + @type node: string + @param node: Node name + @type basecmd: string + @param basecmd: Base command (path on the remote machine) + @type debug: bool + @param debug: Enable debug output + @type verbose: bool + @param verbose: Enable verbose output + @type use_cluster_key: bool + @param use_cluster_key: See L{ssh.SshRunner.BuildCmd} + @type ask_key: bool + @param ask_key: See L{ssh.SshRunner.BuildCmd} + @type strict_host_check: bool + @param strict_host_check: See L{ssh.SshRunner.BuildCmd} + @type port: int + @param port: The SSH port of the remote machine or None for the default + @param data: JSON-serializable input data for script (passed to stdin) + @type ssconf_store: C{ssconf.SimpleStore} + @param ssconf_store: a SimpleStore object to be queries for ssconf values + + """ + cmd = [basecmd] + + # Pass --debug/--verbose to the external script if set on our invocation + if debug: + cmd.append("--debug") + + if verbose: + cmd.append("--verbose") + + if ensure_version: + all_cmds = _EnsureCorrectGanetiVersion(cmd) + else: + all_cmds = [cmd] + + if port is None: + port = netutils.GetDaemonPort(constants.SSH) + + family = ssconf_store.GetPrimaryIPFamily() + srun = SshRunner(cluster_name, + ipv6=(family == netutils.IP6Address.family)) + scmd = srun.BuildCmd(node, constants.SSH_LOGIN_USER, + utils.ShellQuoteArgs( + utils.ShellCombineCommands(all_cmds)), + batch=False, ask_key=ask_key, quiet=False, + strict_host_check=strict_host_check, + use_cluster_key=use_cluster_key, + port=port) + + tempfh = tempfile.TemporaryFile() + try: + tempfh.write(serializer.DumpJson(data)) + tempfh.seek(0) + + result = utils.RunCmd(scmd, interactive=True, input_fd=tempfh) + finally: 
+ tempfh.close() + + if result.failed: + raise errors.OpExecError("Command '%s' failed: %s" % + (result.cmd, result.fail_reason)) + + +def GetSshPortMap(nodes, cfg): + """Retrieves SSH ports of given nodes from the config. + + @param nodes: the names of nodes + @type nodes: a list of strings + @param cfg: a configuration object + @type cfg: L{ConfigWriter} + @return: a map from node names to ssh ports + @rtype: a dict from str to int + + """ + node_port_map = {} + node_groups = dict(map(lambda n: (n.name, n.group), + cfg.GetAllNodesInfo().values())) + group_port_map = cfg.GetGroupSshPorts() + for node in nodes: + group_uuid = node_groups.get(node) + ssh_port = group_port_map.get(group_uuid) + node_port_map[node] = ssh_port + return node_port_map diff --git a/lib/tools/ssh_update.py b/lib/tools/ssh_update.py index db0f189..169d733 100644 --- a/lib/tools/ssh_update.py +++ b/lib/tools/ssh_update.py @@ -46,7 +46,9 @@ _DATA_CHECK = ht.TStrictDict(False, True, { constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString, constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString, constants.SSHS_SSH_PUBLIC_KEYS: - ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)), + ht.TItems( + [ht.TElemOf(constants.SSHS_ACTIONS), + ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]), constants.SSHS_SSH_AUTHORIZED_KEYS: ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)), }) @@ -113,14 +115,39 @@ def UpdatePubKeyFile(data, dry_run, key_file=pathutils.SSH_PUB_KEYS): @param dry_run: Whether to perform a dry run """ - public_keys = data.get(constants.SSHS_SSH_PUBLIC_KEYS) - if not public_keys: - logging.info("No public keys received. Not modifying" - " the public key file at all.") + instructions = data.get(constants.SSHS_SSH_PUBLIC_KEYS) + if not instructions: + logging.info("No instructions to modify public keys received." 
+ " Not modifying the public key file at all.") return - if dry_run: - logging.info("This is a dry run, not modifying %s", key_file) - ssh.OverridePubKeyFile(public_keys, key_file=key_file) + (action, public_keys) = instructions + + if action == constants.SSHS_OVERRIDE: + if dry_run: + logging.info("This is a dry run, not overriding %s", key_file) + else: + ssh.OverridePubKeyFile(public_keys, key_file=key_file) + elif action == constants.SSHS_ADD: + if dry_run: + logging.info("This is a dry run, not adding a key to %s", key_file) + else: + for uuid, keys in public_keys.items(): + for key in keys: + ssh.AddPublicKey(uuid, key, key_file=key_file) + elif action == constants.SSHS_REMOVE: + if dry_run: + logging.info("This is a dry run, not removing keys from %s", key_file) + else: + for uuid in public_keys.keys(): + ssh.RemovePublicKey(uuid, key_file=key_file) + elif action == constants.SSHS_CLEAR: + if dry_run: + logging.info("This is a dry run, not clearing file %s", key_file) + else: + ssh.ClearPubKeyFile(key_file=key_file) + else: + raise SshUpdateError("Action '%s' not implemented for public keys." 
+ % action) def Main(): diff --git a/test/py/ganeti.backend_unittest.py b/test/py/ganeti.backend_unittest.py index 1f05197..e2dc1a7 100755 --- a/test/py/ganeti.backend_unittest.py +++ b/test/py/ganeti.backend_unittest.py @@ -44,6 +44,7 @@ from ganeti import hypervisor from ganeti import netutils from ganeti import objects from ganeti import pathutils +from ganeti import ssh from ganeti import utils @@ -948,5 +949,178 @@ class TestSpaceReportingConstants(unittest.TestCase): self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type]) +class TestAddNodeSshKey(testutils.GanetiTestCase): + + _CLUSTER_NAME = "mycluster" + _SSH_PORT = 22 + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self._ssh_add_authorized_patcher = testutils \ + .patch_object(ssh, "AddAuthorizedKeys") + self._ssh_add_authorized_mock = self._ssh_add_authorized_patcher.start() + + self._ssconf_mock = mock.Mock() + self._ssconf_mock.GetNodeList = mock.Mock() + self._ssconf_mock.GetMasterNode = mock.Mock() + self._ssconf_mock.GetClusterName = mock.Mock() + + self._run_cmd_mock = mock.Mock() + + self.noded_cert_file = testutils.TestDataFilename("cert1.pem") + + def tearDown(self): + super(testutils.GanetiTestCase, self).tearDown() + self._ssh_add_authorized_patcher.stop() + + def _SetupTestData(self, number_of_nodes=15, number_of_pot_mcs=5, + number_of_mcs=5): + """Sets up consistent test data for a cluster with a couple of nodes. 
+ + """ + self._pub_key_file = self._CreateTempFile() + self._potential_master_candidates = [] + self._ssh_port_map = {} + + self._ssconf_mock.reset_mock() + self._ssconf_mock.GetNodeList.reset_mock() + self._ssconf_mock.GetMasterNode.reset_mock() + self._ssconf_mock.GetClusterName.reset_mock() + self._run_cmd_mock.reset_mock() + + for i in range(number_of_nodes): + node_name = "node_name_%s" % i + self._potential_master_candidates.append(node_name) + self._ssh_port_map[node_name] = self._SSH_PORT + + self._all_nodes = self._potential_master_candidates[:] + for j in range(number_of_pot_mcs, number_of_nodes): + node_name = "node_name_%s" + self._all_nodes.append(node_name) + self._ssh_port_map[node_name] = self._SSH_PORT + + self._master_node = "node_name_%s" % (number_of_pot_mcs / 2) + + self._ssconf_store = self._MySsconfStore( + self._CLUSTER_NAME, self._all_nodes, self._master_node) + self._command_runner = self._MyCommandRunner( + self._CLUSTER_NAME, self._master_node, self._all_nodes, + self._potential_master_candidates, + new_node_master_candidate) + + def _TearDownTestData(self): + os.remove(self._pub_key_file) + + def _KeyReceived(self, key_data, node_name, expected_type, + expected_key): + if not node_name in key_data: + return False + for data in key_data[node_name]: + if expected_type in data: + (action, key_dict) = data[expected_type] + if action in [constants.SSHS_ADD, constants.SSHS_OVERRIDE]: + for key_list in key_dict.values(): + if expected_key in key_list: + return True + return False + + def testAddNodeSshKeyValid(self): + new_node_name = "new_node_name" + new_node_uuid = "new_node_uuid" + new_node_key1 = "new_node_key1" + new_node_key2 = "new_node_key2" + + for (to_authorized_keys, to_public_keys, get_public_keys) in \ + [(True, True, False), (False, True, False), + (True, True, True), (False, True, True)]: + + self._SetupTestData() + + # set up public key file, ssconf store, and node lists + if to_public_keys: + for key in [new_node_key1, 
new_node_key2]: + ssh.AddPublicKey(new_node_name, key, key_file=self._pub_key_file) + self._potential_master_candidates.append(new_node_name) + + self._ssh_port_map[new_node_name] = self._SSH_PORT + + backend.AddNodeSshKey(new_node_uuid, new_node_name, + to_authorized_keys, + to_public_keys, + get_public_keys, + self._ssh_port_map, + self._potential_master_candidates, + pub_key_file=self._pub_key_file, + ssconf_store=self._ssconf_mock, + noded_cert_file=self.noded_cert_file, + run_cmd_fn=self._run_cmd_mock) + + calls_per_node = {} + for (pos, keyword) in self._run_cmd_mock.call_args_list: + (cluster_name, node, _, _, _, _, _, _, _, data, _) = pos + if not node in calls_per_node: + calls_per_node[node] = [] + calls_per_node[node].append(data) + + # one sample node per type (master candidate, potential master candidate, + # normal node) + mc_idx = 3 + pot_mc_idx = 7 + normal_idx = 12 + sample_nodes = [mc_idx, pot_mc_idx, normal_idx] + pot_sample_nodes = [mc_idx, pot_mc_idx] + + if to_authorized_keys: + for node_idx in sample_nodes: + self.assertTrue(self._KeyReceived( + calls_per_node, "node_name_%i" % node_idx, + constants.SSHS_SSH_AUTHORIZED_KEYS, new_node_key1), + "Node %i did not receive authorized key '%s' although it should" + " have." % (node_idx, new_node_key1)) + else: + for node_idx in sample_nodes: + self.assertFalse(self._KeyReceived( + calls_per_node, "node_name_%i" % node_idx, + constants.SSHS_SSH_AUTHORIZED_KEYS, new_node_key1), + "Node %i received authorized key '%s', although it should not have." + % (node_idx, new_node_key1)) + + if to_public_keys: + for node_idx in pot_sample_nodes: + self.assertTrue(self._KeyReceived( + calls_per_node, "node_name_%i" % node_idx, + constants.SSHS_SSH_PUBLIC_KEYS, new_node_key1), + "Node %i did not receive public key '%s', although it should have." 
+ % (node_idx, new_node_key1)) + else: + for node_idx in sample_nodes: + self.assertFalse(self._KeyReceived( + calls_per_node, "node_name_%i" % node_idx, + constants.SSHS_SSH_PUBLIC_KEYS, new_node_key1), + "Node %i did receive public key '%s', although it should have." + % (node_idx, new_node_key1)) + + if get_public_keys: + for node_idx in sample_nodes: + if node_idx in pot_sample_nodes: + self.assertTrue(self._KeyReceived( + calls_per_node, new_node_name, + constants.SSHS_SSH_PUBLIC_KEYS, "key%s" % node_idx), + "The new node '%s' did not receive public key of node %i," + " although it should have." % + (new_node_name, node_idx)) + else: + self.assertFalse(self._KeyReceived( + calls_per_node, new_node_name, + constants.SSHS_SSH_PUBLIC_KEYS, "key%s" % node_idx), + "The new node '%s' did receive public key of node %i," + " although it should not have." % + (new_node_name, node_idx)) + else: + new_node_name not in calls_per_node + + self._TearDownTestData() + + if __name__ == "__main__": testutils.GanetiTestProgram() -- 2.1.0.rc2.206.gedb03e5