This is an additional patch to the SSH patch series. It simplifies the handling of the public SSH key file by using the utility function utils.WriteFile wherever possible, instead of the hand-rolled temporary-file-and-rename logic. Since merging this change back into the individual patches of the series would be messy, I am sending it as an additional patch at the end of the series.
Signed-off-by: Helga Velroyen <hel...@google.com> --- lib/ssh.py | 153 ++++++++++++++++++++++++++----------------------------------- 1 file changed, 66 insertions(+), 87 deletions(-) diff --git a/lib/ssh.py b/lib/ssh.py index a33bf99..78f2120 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -27,7 +27,6 @@ import logging import os import tempfile -import stat from functools import partial @@ -259,8 +258,7 @@ def RemoveAuthorizedKey(file_name, key): RemoveAuthorizedKeys(file_name, [key]) -def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, tmp_file, - found): +def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, found): """Processes one line of the public key file when adding a key. This is a sub function that can be called within the @@ -276,25 +274,20 @@ def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, tmp_file, is processed in this function call @param line_key: the SSH key of the node whose line in the public key file is processed in this function call - @type tmp_file: file descriptor - @param tmp_file: the temporary file to which the manipulated public key - file is written to, before replacing the original public key file - automically @type found: boolean @param found: whether or not the (UUID, key) pair of the node whose key is being added was found in the public key file already. - @rtype: boolean - @return: a possibly updated value of C{found} + @rtype: (boolean, string) + @return: a possibly updated value of C{found} and the processed line """ if line_uuid == new_uuid and line_key == new_key: logging.debug("SSH key of node '%s' already in key file.", new_uuid) found = True - tmp_file.write("%s %s\n" % (line_uuid, line_key)) - return found + return (found, "%s %s\n" % (line_uuid, line_key)) -def _AddPublicKeyElse(new_uuid, new_key, tmp_file): +def _AddPublicKeyElse(new_uuid, new_key): """Adds a new SSH key to the key file if it did not exist already. 
This is an auxiliary function for C{_ManipulatePublicKeyFile} which @@ -306,16 +299,16 @@ def _AddPublicKeyElse(new_uuid, new_key, tmp_file): @param new_uuid: the UUID of the node whose key is added @type new_key: string @param new_key: the SSH key to be added - @type tmp_file: file descriptor - @param tmp_file: the file where the key is appended + @rtype: string + @return: a new line to be added to the file """ - tmp_file.write("%s %s\n" % (new_uuid, new_key)) + return "%s %s\n" % (new_uuid, new_key) def _RemovePublicKeyProcessLine( target_uuid, _target_key, - line_uuid, line_key, tmp_file, found): + line_uuid, line_key, found): """Processes a line in the public key file when aiming for removing a key. This is an auxiliary function for C{_ManipulatePublicKeyFile} when we @@ -331,22 +324,21 @@ def _RemovePublicKeyProcessLine( @param line_uuid: UUID of the node whose line is processed in this call @type line_key: string @param line_key: SSH key of the nodes whose line is processed in this call - @type tmp_file: file descriptor - @param tmp_file: temporary file which eventually replaces the ganeti public - key file @type found: boolean @param found: whether or not the UUID was already found. + @rtype: (boolean, string) + @return: a tuple, indicating if the target line was found and the processed + line; the line is 'None', if the original line is removed """ if line_uuid != target_uuid: - tmp_file.write("%s %s\n" % (line_uuid, line_key)) - return found + return (found, "%s %s\n" % (line_uuid, line_key)) else: - return True + return (True, None) def _RemovePublicKeyElse( - target_uuid, _target_key, _tmp_file): + target_uuid, _target_key): """Logs when we tried to remove a key that does not exist. 
This is an auxiliary function for C{_ManipulatePublicKeyFile} which is @@ -358,18 +350,17 @@ def _RemovePublicKeyElse( @type _target_key: string @param _target_key: the key of the node which was supposed to be removed (not used) - @type _tmp_file: file descriptor - @param _tmp_file: the temporary file which eventually will replace the public - key file (not used) + @rtype: string + @return: in this case, always None """ logging.debug("Trying to remove key of node '%s' which is not in list" " of public keys.", target_uuid) + return None def _ReplaceNameByUuidProcessLine( - node_name, _key, line_identifier, line_key, tmp_file, found, - node_uuid=None): + node_name, _key, line_identifier, line_key, found, node_uuid=None): """Replaces a node's name with its UUID on a matching line in the key file. This is an auxiliary function for C{_ManipulatePublicKeyFile} which processes @@ -386,25 +377,23 @@ def _ReplaceNameByUuidProcessLine( got replaced already or not. @type line_key: string @param line_key: SSH key of the node whose line is processed - @type tmp_file: file descriptor - @param tmp_file: temporary file which will eventually replace the public - key file @type found: boolean @param found: whether or not the line matches the node's name @type node_uuid: string @param node_uuid: the node's UUID which will replace the node name + @rtype: (boolean, string) + @return: a tuple indicating whether the target line was found and the + processed line """ if node_name == line_identifier: - found = True - tmp_file.write("%s %s\n" % (node_uuid, line_key)) + return (True, "%s %s\n" % (node_uuid, line_key)) else: - tmp_file.write("%s %s\n" % (line_identifier, line_key)) - return found + return (found, "%s %s\n" % (line_identifier, line_key)) def _ReplaceNameByUuidElse( - node_uuid, node_name, _key, _tmp_file): + node_uuid, node_name, _key): """Logs a debug message when we try to replace a key that is not there. 
This is an implementation of the auxiliary C{process_else_fn} function for @@ -418,13 +407,13 @@ def _ReplaceNameByUuidElse( @param node_name: the node's UUID @type _key: string (not used) @param _key: the node's SSH key (not used) - @type _tmp_file: file descriptor - @param _tmp_file: temporary file for manipulating the public key file - (not used) + @rtype: string + @return: in this case, always None """ logging.debug("Trying to replace node name '%s' with UUID '%s', but" " no line with that name was found.", node_name, node_uuid) + return None def _ParseKeyLine(line, error_fn): @@ -459,18 +448,16 @@ def _ManipulatePubKeyFile(target_identifier, target_key, This is a general function to manipulate the public key file. It needs two auxiliary functions C{process_line_fn} and C{process_else_fn} to work. Generally, the public key file is processed as follows: - 1) A temporary file is opened to write the content of the ganeti public key - file to (possibly with changes). 2) The function processes each line of the original ganeti public key file, - applies the C{process_line_fn} function on it, which possibly writes the - original line, a changed line or no line to the temporary file. If - the return value of the C{process_line_fn} function is True, it will - be recorded in the 'found' variable for later use. + applies the C{process_line_fn} function on it, which returns a possibly + manipulated line and an indicator whether the line in question was found. + If a line is returned, it is added to a list of lines for later writing + to the file. 3) If all lines are processed and the 'found' variable is False, the seconds auxiliary function C{process_else_fn} is called to possibly - add more lines to the temporary file. - 4) Finally, the temporary file is written to disk and moved to the original - files name to ensure atomic writing. + add more lines to the list of lines. 
+ 4) Finally, the list of lines is assembled to a string and written + atomically to the public key file, thereby overriding it. @type target_identifier: str @param target_identifier: identifier of the node whose key is added; in most @@ -494,31 +481,31 @@ def _ManipulatePubKeyFile(target_identifier, target_key, assert process_else_fn is not None assert process_line_fn is not None - fd_tmp, tmpname = tempfile.mkstemp(dir=os.path.dirname(key_file)) + old_lines = [] try: - f_tmp = os.fdopen(fd_tmp, "w") - try: - f_orig = open(key_file, "r") - try: - found = False - for line in f_orig: - (uuid, key) = _ParseKeyLine(line, error_fn) - if not uuid: - continue - if process_line_fn(target_identifier, target_key, uuid, - key, f_tmp, found): - found = True - if not found: - process_else_fn(target_identifier, target_key, f_tmp) - f_tmp.flush() - os.rename(tmpname, key_file) - finally: - f_orig.close() - finally: - f_tmp.close() - except: - utils.RemoveFile(tmpname) - raise + f_orig = open(key_file, "r") + old_lines = f_orig.readlines() + finally: + f_orig.close() + + found = False + new_lines = [] + for line in old_lines: + (uuid, key) = _ParseKeyLine(line, error_fn) + if not uuid: + continue + (new_found, new_line) = process_line_fn(target_identifier, target_key, + uuid, key, found) + if new_found: + found = True + if new_line is not None: + new_lines.append(new_line) + if not found: + new_line = process_else_fn(target_identifier, target_key) + if new_line is not None: + new_lines.append(new_line) + new_file_content = "".join(new_lines) + utils.WriteFile(key_file, data=new_file_content) def AddPublicKey(new_uuid, new_key, key_file=pathutils.SSH_PUB_KEYS, @@ -578,27 +565,19 @@ def ClearPubKeyFile(key_file=pathutils.SSH_PUB_KEYS, mode=0600): utils.WriteFile(key_file, data="", mode=mode) -def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS, - error_fn=errors.ProgrammerError): +def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS): """Overrides the public 
key file with a list of given keys. @type key_map: dict from str to list of str @param key_map: dictionary mapping uuids to lists of SSH keys """ - try: - fd_tmp, tmpname = tempfile.mkstemp(dir=os.path.dirname(key_file)) - f_tmp = os.fdopen(fd_tmp, "w") - for (uuid, keys) in key_map.items(): - for key in keys: - f_tmp.write("%s %s\n" % (uuid, key)) - f_tmp.flush() - os.rename(tmpname, key_file) - os.chmod(key_file, stat.S_IRUSR | stat.S_IWUSR) - except IOError, e: - raise error_fn("Cannot override key file due to error '%s'" % e) - finally: - f_tmp.close() + new_lines = [] + for (uuid, keys) in key_map.items(): + for key in keys: + new_lines.append("%s %s\n" % (uuid, key)) + new_file_content = "".join(new_lines) + utils.WriteFile(key_file, data=new_file_content) def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS, -- 2.1.0.rc2.206.gedb03e5