This patch makes prepare_node_join use some of the functions
that were moved to tools/common.py. The respective unit tests
are removed from prepare_node_join_unittest.py, because these
functions are now tested in common_unittest.py.

Signed-off-by: Helga Velroyen <[email protected]>
---
 lib/tools/prepare_node_join.py                     | 41 ++---------------
 test/py/ganeti.tools.common_unittest.py            | 53 ++++++++++++++++++++++
 test/py/ganeti.tools.prepare_node_join_unittest.py | 48 --------------------
 3 files changed, 57 insertions(+), 85 deletions(-)

diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
index 7eb1e5a..4db335f 100644
--- a/lib/tools/prepare_node_join.py
+++ b/lib/tools/prepare_node_join.py
@@ -43,10 +43,9 @@ from ganeti import constants
 from ganeti import errors
 from ganeti import pathutils
 from ganeti import utils
-from ganeti import serializer
 from ganeti import ht
 from ganeti import ssh
-from ganeti import ssconf
+from ganeti.tools import common
 
 
 _SSH_KEY_LIST_ITEM = \
@@ -89,17 +88,7 @@ def ParseOptions():
 
   (opts, args) = parser.parse_args()
 
-  return VerifyOptions(parser, opts, args)
-
-
-def VerifyOptions(parser, opts, args):
-  """Verifies options and arguments for correctness.
-
-  """
-  if args:
-    parser.error("No arguments are expected")
-
-  return opts
+  return common.VerifyOptions(parser, opts, args)
 
 
 def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
@@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
     _verify_fn(cert)
 
 
-def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
-  """Verifies cluster name.
-
-  @type data: dict
-
-  """
-  name = data.get(constants.SSHS_CLUSTER_NAME)
-  if name:
-    _verify_fn(name)
-  else:
-    raise JoinError("Cluster name must be specified")
-
-
 def _UpdateKeyFiles(keys, dry_run, keyfiles):
   """Updates SSH key files.
 
@@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
       utils.AddAuthorizedKey(auth_keys_file, public_key)
 
 
-def LoadData(raw):
-  """Parses and verifies input data.
-
-  @rtype: dict
-
-  """
-  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
-
-
 def Main():
   """Main routine.
 
@@ -256,10 +223,10 @@ def Main():
   utils.SetupToolLogging(opts.debug, opts.verbose)
 
   try:
-    data = LoadData(sys.stdin.read())
+    data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
 
     # Check if input data is correct
-    VerifyClusterName(data)
+    common.VerifyClusterName(data, JoinError)
     VerifyCertificate(data)
 
     # Update SSH files
diff --git a/test/py/ganeti.tools.common_unittest.py b/test/py/ganeti.tools.common_unittest.py
index 1652088..427b851 100755
--- a/test/py/ganeti.tools.common_unittest.py
+++ b/test/py/ganeti.tools.common_unittest.py
@@ -38,6 +38,8 @@ import OpenSSL
 import time
 
 from ganeti import constants
+from ganeti import errors
+from ganeti import serializer
 from ganeti import utils
 from ganeti.tools import common
 
@@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase):
     self.assertEqual(client_cert.get_subject().CN, my_node_name)
 
 
+class TestLoadData(unittest.TestCase):
+
+  def testNoJson(self):
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, "")
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, "}")
+
+  def testInvalidDataStructure(self):
+    raw = serializer.DumpJson({
+      "some other thing": False,
+      })
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
+
+    raw = serializer.DumpJson([])
+    self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
+
+  def testValidData(self):
+    raw = serializer.DumpJson({})
+    self.assertEqual(common.LoadData(raw, Exception), {})
+
+
+class TestVerifyClusterName(unittest.TestCase):
+
+  class MyException(Exception):
+    pass
+
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    unittest.TestCase.tearDown(self)
+    shutil.rmtree(self.tmpdir)
+
+  def testNoName(self):
+    self.assertRaises(self.MyException, common.VerifyClusterName,
+                      {}, self.MyException, _verify_fn=NotImplemented)
+
+  @staticmethod
+  def _FailingVerify(name):
+    assert name == "cluster.example.com"
+    raise errors.GenericError()
+
+  def testFailingVerification(self):
+    data = {
+      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
+      }
+
+    self.assertRaises(errors.GenericError, common.VerifyClusterName,
+                      data, Exception, _verify_fn=self._FailingVerify)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
diff --git a/test/py/ganeti.tools.prepare_node_join_unittest.py b/test/py/ganeti.tools.prepare_node_join_unittest.py
index e0c60a4..92cb1de 100755
--- a/test/py/ganeti.tools.prepare_node_join_unittest.py
+++ b/test/py/ganeti.tools.prepare_node_join_unittest.py
@@ -34,11 +34,9 @@ import unittest
 import shutil
 import tempfile
 import os.path
-import OpenSSL
 
 from ganeti import errors
 from ganeti import constants
-from ganeti import serializer
 from ganeti import pathutils
 from ganeti import compat
 from ganeti import utils
@@ -50,25 +48,6 @@ import testutils
 _JoinError = prepare_node_join.JoinError
 
 
-class TestLoadData(unittest.TestCase):
-  def testNoJson(self):
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
-
-  def testInvalidDataStructure(self):
-    raw = serializer.DumpJson({
-      "some other thing": False,
-      })
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
-
-    raw = serializer.DumpJson([])
-    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
-
-  def testValidData(self):
-    raw = serializer.DumpJson({})
-    self.assertEqual(prepare_node_join.LoadData(raw), {})
-
-
 class TestVerifyCertificate(testutils.GanetiTestCase):
   def setUp(self):
     testutils.GanetiTestCase.setUp(self)
@@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
     prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
 
 
-class TestVerifyClusterName(unittest.TestCase):
-  def setUp(self):
-    unittest.TestCase.setUp(self)
-    self.tmpdir = tempfile.mkdtemp()
-
-  def tearDown(self):
-    unittest.TestCase.tearDown(self)
-    shutil.rmtree(self.tmpdir)
-
-  def testNoName(self):
-    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
-                      {}, _verify_fn=NotImplemented)
-
-  @staticmethod
-  def _FailingVerify(name):
-    assert name == "cluster.example.com"
-    raise errors.GenericError()
-
-  def testFailingVerification(self):
-    data = {
-      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
-      }
-
-    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
-                      data, _verify_fn=self._FailingVerify)
-
-
 class TestUpdateSshDaemon(unittest.TestCase):
   def setUp(self):
     unittest.TestCase.setUp(self)
-- 
2.4.3.573.g4eafbef

Reply via email to