As part of the work to resolve the issue:

https://github.com/autotest/autotest/issues/555

I started writing unittests for the Env class, and
those tests uncovered some small problems in that
class.

So, add a new set of unittests for Env, and modify
the Env class in the following ways:

1) Env._filename is now always set (to None when no
filename is passed to the constructor), so calling
save() without a filename on such an Env raises a
clear EnvSaveError instead of an AttributeError;
2) All set/get operations in the Env helper methods
are now performed directly on Env.data, for safety;
3) The get_all_vms method depended on a not quite
correct is_vm() function, which compared the
object's class name against "VM" instead of using
isinstance. Since all VM keys are guaranteed to
start with "vm__", select VMs by that key prefix
instead, making the method more correct and easier
to test. The remaining is_vm() user in
env_process.py now uses isinstance(vm, virt_vm.BaseVM).

Signed-off-by: Lucas Meneghel Rodrigues <[email protected]>
---
 virttest/env_process.py         |   2 +-
 virttest/utils_misc.py          |  39 +++++------
 virttest/utils_misc_unittest.py | 146 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 167 insertions(+), 20 deletions(-)

diff --git a/virttest/env_process.py b/virttest/env_process.py
index 6f9f7da..f309fe2 100644
--- a/virttest/env_process.py
+++ b/virttest/env_process.py
@@ -298,7 +298,7 @@ def preprocess(test, params, env):
     requested_vms = params.objects("vms")
     for key in env.keys():
         vm = env[key]
-        if not utils_misc.is_vm(vm):
+        if not isinstance(vm, virt_vm.BaseVM):  # skip non-VM env entries
             continue
         if not vm.name in requested_vms:
             vm.destroy()
diff --git a/virttest/utils_misc.py b/virttest/utils_misc.py
index 56be94c..d37cf87 100644
--- a/virttest/utils_misc.py
+++ b/virttest/utils_misc.py
@@ -63,15 +63,6 @@ def unlock_file(f):
     f.close()
 
 
-def is_vm(obj):
-    """
-    Tests whether a given object is a VM object.
-
-    @param obj: Python object.
-    """
-    return obj.__class__.__name__ == "VM"
-
-
 class NetError(Exception):
     pass
 
@@ -177,6 +168,10 @@ class DbNoLockError(NetError):
         return "Attempt made to access database with improper locking"
 
 
+class EnvSaveError(Exception):
+    """Raised by Env.save() when no filename is known for the env."""
+
+
 class Env(UserDict.IterableUserDict):
     """
     A dict-like object containing global objects used by tests.
@@ -194,8 +189,8 @@ class Env(UserDict.IterableUserDict):
         """
         UserDict.IterableUserDict.__init__(self)
         empty = {"version": version}
+        self._filename = filename  # always set, so save() can check it
         if filename:
-            self._filename = filename
             try:
                 if os.path.isfile(filename):
                     f = open(filename, "r")
@@ -230,6 +225,8 @@ class Env(UserDict.IterableUserDict):
                 use the filename from which the dict was loaded.
         """
         filename = filename or self._filename
+        if not filename:
+            raise EnvSaveError("No filename specified for this env file")
         f = open(filename, "w")
         cPickle.dump(self.data, f)
         f.close()
@@ -239,7 +236,11 @@ class Env(UserDict.IterableUserDict):
         """
         Return a list of all VM objects in this Env object.
         """
-        return [o for o in self.values() if is_vm(o)]
+        vm_list = []
+        for key, value in self.data.items():
+            if key.startswith("vm__"):
+                vm_list.append(value)
+        return vm_list
 
 
     def get_vm(self, name):
@@ -248,7 +249,7 @@ class Env(UserDict.IterableUserDict):
 
         @param name: VM name.
         """
-        return self.get("vm__%s" % name)
+        return self.data.get("vm__%s" % name)
 
 
     def register_vm(self, name, vm):
@@ -258,7 +259,7 @@ class Env(UserDict.IterableUserDict):
         @param name: VM name.
         @param vm: VM object.
         """
-        self["vm__%s" % name] = vm
+        self.data["vm__%s" % name] = vm
 
 
     def unregister_vm(self, name):
@@ -267,7 +268,7 @@ class Env(UserDict.IterableUserDict):
 
         @param name: VM name.
         """
-        del self["vm__%s" % name]
+        del self.data["vm__%s" % name]
 
 
     def register_syncserver(self, port, server):
@@ -277,7 +278,7 @@ class Env(UserDict.IterableUserDict):
         @param port: Sync Server port.
         @param server: Sync Server object.
         """
-        self["sync__%s" % port] = server
+        self.data["sync__%s" % port] = server
 
 
     def unregister_syncserver(self, port):
@@ -286,7 +287,7 @@ class Env(UserDict.IterableUserDict):
 
         @param port: Sync Server port.
         """
-        del self["sync__%s" % port]
+        del self.data["sync__%s" % port]
 
 
     def get_syncserver(self, port):
@@ -295,7 +296,7 @@ class Env(UserDict.IterableUserDict):
 
         @param port: Sync Server port.
         """
-        return self.get("sync__%s" % port)
+        return self.data.get("sync__%s" % port)
 
 
     def register_installer(self, installer):
@@ -306,14 +307,14 @@ class Env(UserDict.IterableUserDict):
         information about the installed KVM modules and qemu-kvm can be used by
         them.
         """
-        self['last_installer'] = installer
+        self.data['last_installer'] = installer
 
 
     def previous_installer(self):
         """
         Return the last installer that was registered
         """
-        return self.get('last_installer')
+        return self.data.get('last_installer')
 
 
 class Params(UserDict.IterableUserDict):
 class Params(UserDict.IterableUserDict):
diff --git a/virttest/utils_misc_unittest.py b/virttest/utils_misc_unittest.py
index a4b5d7c..6175041 100755
--- a/virttest/utils_misc_unittest.py
+++ b/virttest/utils_misc_unittest.py
@@ -812,6 +812,152 @@ class FakeVm(object):
     def is_alive(self):
         logging.info("Fake VM %s (instance %s)", self.name, self.instance)
 
+class FakeSyncListenServer(object):
+    def __init__(self, address='', port=123, tmpdir=None):
+        stamp = time.strftime("%Y%m%d-%H%M%S")
+        suffix = utils_misc.generate_random_string(16)
+        self.instance = "%s-%s" % (stamp, suffix)
+        self.port = port  # address and tmpdir are accepted but unused
+
+    def close(self):
+        logging.info("Closing sync server (instance %s)", self.instance)
+
+
+class TestEnv(unittest.TestCase):
+    def test_save(self):
+        """
+        1) Verify that calling env.save() with no filename, on an env created
+           without a filename, raises EnvSaveError.
+        2) Register a VM in environment, save env to a file, recover env from
+           that file, get the vm and verify that the instance attribute of the
+           2 objects is the same.
+        3) Register a SyncListenServer and don't save env. Restore env from
+           file and try to get the syncserver, verify it doesn't work.
+        4) Now save env to a file, restore env from file and verify that
+           the syncserver can be found there, and that the sync server
+           instance attribute is equal to the initial sync server instance.
+        """
+        fname = "/dev/shm/EnvUnittest"
+        env = utils_misc.Env()
+        self.assertRaises(utils_misc.EnvSaveError, env.save)
+        try:
+            params = utils_misc.Params({"main_vm": 'rhel7-migration'})
+            vm1 = FakeVm(params['main_vm'], params)
+            vm1.is_alive()
+            env.register_vm(params['main_vm'], vm1)
+            env.save(filename=fname)
+            env2 = utils_misc.Env(filename=fname)
+            vm2 = env2.get_vm(params['main_vm'])
+            vm2.is_alive()
+            self.assertEqual(vm1.instance, vm2.instance)
+
+            sync1 = FakeSyncListenServer(port=222)
+            env.register_syncserver(222, sync1)
+            env3 = utils_misc.Env(filename=fname)
+            syncnone = env3.get_syncserver(222)
+            self.assertTrue(syncnone is None)
+
+            env.save(filename=fname)
+            env4 = utils_misc.Env(filename=fname)
+            sync2 = env4.get_syncserver(222)
+            self.assertEqual(sync2.instance, sync1.instance)
+        finally:
+            if os.path.isfile(fname):
+                os.unlink(fname)
+
+    def test_register_vm(self):
+        """
+        1) Create an env object.
+        2) Create a VM and register it in the env.
+        3) Get the vm back from the env.
+        4) Verify that the 2 objects are the same.
+        """
+        env = utils_misc.Env()
+        params = utils_misc.Params({"main_vm": 'rhel7-migration'})
+        vm1 = FakeVm(params['main_vm'], params)
+        vm1.is_alive()
+        env.register_vm(params['main_vm'], vm1)
+        vm2 = env.get_vm(params['main_vm'])
+        vm2.is_alive()
+        self.assertTrue(vm1 is vm2)
+
+    def test_unregister_vm(self):
+        """
+        1) Create an env object.
+        2) Register 2 vms in the env.
+        3) Verify both vms are in the env.
+        4) Remove one of those vms.
+        5) Verify that the removed vm is no longer in env.
+        """
+        env = utils_misc.Env()
+        params = utils_misc.Params({"main_vm": 'rhel7-migration'})
+        vm1 = FakeVm(params['main_vm'], params)
+        vm1.is_alive()
+        vm2 = FakeVm('vm2', params)
+        vm2.is_alive()
+        env.register_vm(params['main_vm'], vm1)
+        env.register_vm('vm2', vm2)
+        self.assertTrue(vm1 in env.get_all_vms())
+        self.assertTrue(vm2 in env.get_all_vms())
+        env.unregister_vm('vm2')
+        self.assertTrue(vm1 in env.get_all_vms())
+        self.assertFalse(vm2 in env.get_all_vms())
+
+    def test_get_all_vms(self):
+        """
+        1) Create an env object.
+        2) Create 2 vms and register them in the env.
+        3) Create a SyncListenServer and register it in the env.
+        4) Verify that the 2 vms are in the output of get_all_vms.
+        5) Verify that the sync server is not in the output of get_all_vms.
+        """
+        env = utils_misc.Env()
+        params = utils_misc.Params({"main_vm": 'rhel7-migration'})
+        vm1 = FakeVm(params['main_vm'], params)
+        vm1.is_alive()
+        vm2 = FakeVm('vm2', params)
+        vm2.is_alive()
+        env.register_vm(params['main_vm'], vm1)
+        env.register_vm('vm2', vm2)
+        sync1 = FakeSyncListenServer(port=333)
+        env.register_syncserver(333, sync1)
+        self.assertTrue(vm1 in env.get_all_vms())
+        self.assertTrue(vm2 in env.get_all_vms())
+        self.assertFalse(sync1 in env.get_all_vms())
+
+    def test_register_syncserver(self):
+        """
+        1) Create an env object.
+        2) Create a SyncListenServer object and register it in the env.
+        3) Get that SyncListenServer with get_syncserver.
+        4) Verify that both objects are the same.
+        """
+        env = utils_misc.Env()
+        sync1 = FakeSyncListenServer(port=333)
+        env.register_syncserver(333, sync1)
+        sync2 = env.get_syncserver(333)
+        self.assertTrue(sync1 is sync2)
+
+    def test_unregister_syncserver(self):
+        """
+        1) Create an env object.
+        2) Create and register 2 SyncListenServers in the env.
+        3) Get one of the SyncListenServers in the env.
+        4) Unregister the other SyncListenServer.
+        5) Verify that the SyncListenServer unregistered can't be retrieved
+           anymore with get_syncserver().
+        """
+        env = utils_misc.Env()
+        sync1 = FakeSyncListenServer(port=333)
+        env.register_syncserver(333, sync1)
+        sync2 = FakeSyncListenServer(port=444)
+        env.register_syncserver(444, sync2)
+        sync3 = env.get_syncserver(333)
+        self.assertTrue(sync1 is sync3)
+        env.unregister_syncserver(444)
+        sync4 = env.get_syncserver(444)
+        self.assertTrue(sync4 is None)
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
1.7.11.4

_______________________________________________
Autotest-kernel mailing list
[email protected]
https://www.redhat.com/mailman/listinfo/autotest-kernel

Reply via email to