Author: darcy
Date: Sat Jan  5 10:19:07 2013
New Revision: 486

Log:
Open a new database connection in every test to ensure that tests cannot interact with each other.
Add tests for pgnotify.

Modified:
   trunk/module/TEST_PyGreSQL_classic.py

Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py       Sat Jan  5 10:07:57 2013        (r485)
+++ trunk/module/TEST_PyGreSQL_classic.py       Sat Jan  5 10:19:07 2013        (r486)
@@ -2,6 +2,7 @@
 
 from __future__ import with_statement
 
+import sys, thread, time
 import unittest
 from pg import *
 
@@ -16,17 +17,27 @@
 except ImportError:
     pass
 
-db = DB(dbname, dbhost, dbport)
-db.query("SET DATESTYLE TO 'ISO'")
-db.query("SET TIME ZONE 'EST5EDT'")
-db.query("SET DEFAULT_WITH_OIDS=TRUE")
-db.query("SET STANDARD_CONFORMING_STRINGS=FALSE")
-
+def opendb():
+    db = DB(dbname, dbhost, dbport)
+    db.query("SET DATESTYLE TO 'ISO'")
+    db.query("SET TIME ZONE 'EST5EDT'")
+    db.query("SET DEFAULT_WITH_OIDS=TRUE")
+    db.query("SET STANDARD_CONFORMING_STRINGS=FALSE")
+    return db
+
+def cb1(arg_dict):
+    global cb1_return
+    if arg_dict is None:
+        cb1_return = 'timed out'
+    else:
+        cb1_return = arg_dict
 
 class UtilityTest(unittest.TestCase):
 
     def setUp(self):
         """Setup test tables or empty them if they already exist."""
+        db = opendb()
+
         for t in ('_test1', '_test2'):
             try:
                 db.query("CREATE SCHEMA " + t)
@@ -50,10 +61,12 @@
 
     def test_invalidname(self):
         """Make sure that invalid table names are caught"""
+        db = opendb()
         self.failUnlessRaises(ProgrammingError, db.get_attnames, 'x.y.z')
 
     def test_schema(self):
         """Does it differentiate the same table name in different schemas"""
+        db = opendb()
         # see if they differentiate the table names properly
         self.assertEqual(
             db.get_attnames('_test_schema'),
@@ -73,6 +86,7 @@
         )
 
     def test_pkey(self):
+        db = opendb()
         self.assertEqual(db.pkey('_test_schema'), '_test')
         self.assertEqual(db.pkey('public._test_schema'), '_test')
         self.assertEqual(db.pkey('_test1._test_schema'), '_test1')
@@ -84,6 +98,7 @@
         self.assertEqual(db.pkey('public.test1'), 'a')
 
     def test_get(self):
+        db = opendb()
         db.query("INSERT INTO _test_schema VALUES (1234)")
         db.get('_test_schema', 1234)
         db.get('_test_schema', 1234, keyname='_test')
@@ -91,11 +106,13 @@
         db.get('_test_vschema', 1234, keyname='_test')
 
     def test_params(self):
+        db = opendb()
         db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
         d = db.get('_test_schema', 12)
         self.assertEqual(d['dvar'], 34)
 
     def test_insert(self):
+        db = opendb()
         d = dict(_test=1234)
         db.insert('_test_schema', d)
         self.assertEqual(d['dvar'], 999)
@@ -103,6 +120,7 @@
         self.assertEqual(d['dvar'], 999)
 
     def test_context_manager(self):
+        db = opendb()
         t = '_test_schema'
         d = dict(_test=1235)
         with db:
@@ -128,6 +146,7 @@
         self.assertTrue(db.get(t, 1239))
 
     def test_sqlstate(self):
+        db = opendb()
         db.query("INSERT INTO _test_schema VALUES (1234)")
         try:
             db.query("INSERT INTO _test_schema VALUES (1234)")
@@ -138,6 +157,7 @@
             self.assertEqual(error.sqlstate, '23505')
 
     def test_mixed_case(self):
+        db = opendb()
         try:
             db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)')
         except Error:
@@ -146,6 +166,7 @@
         db.insert('_test_mc', d)
 
     def test_update(self):
+        db = opendb()
         db.query("INSERT INTO _test_schema VALUES (1234)")
 
         r = db.get('_test_schema', 1234)
@@ -165,6 +186,7 @@
         self.assertEqual(r['dvar'], 456)
 
     def test_quote(self):
+        db = opendb()
         q = db._quote
         self.assertEqual(q(0, 'int'), "0")
         self.assertEqual(q(0, 'num'), "0")
@@ -204,6 +226,79 @@
         self.assertEqual(q("'", 'text'), "''''")
         self.assertEqual(q("\\", 'text'), "'\\\\'")
 
+    # note that notify can be created as part of the DB class or
+    # independently.
+
+    def test_notify_DB(self):
+        global cb1_return
+        
+        db = opendb()
+        db2 = opendb()
+        # Listen for 'event_1'
+        pgn = db2.pgnotify('event_1', cb1)
+        thread.start_new_thread(pgn, ())
+        time.sleep(1)
+        # Generate notification from the other connection.
+        db.query('notify event_1')
+        time.sleep(1)
+        # Check that callback has been invoked.
+        self.assertEquals(cb1_return['event'], 'event_1')
+
+    def test_notify_timeout_DB(self):
+        db = opendb()
+        db2 = opendb()
+        global cb1_return
+        # Listen for 'event_1'
+        pgn = db2.pgnotify('event_1', cb1, {}, 1)
+        thread.start_new_thread(pgn, ())
+        # Sleep long enough to time out.
+        time.sleep(2)
+        # Verify that we've indeed timed out.
+        self.assertEquals(cb1_return, 'timed out')
+
+    def test_notify(self):
+        db = opendb()
+        db2 = opendb()
+        global cb1_return
+        # Listen for 'event_1'
+        pgn = pgnotify(db2, 'event_1', cb1)
+        thread.start_new_thread(pgn, ())
+        time.sleep(1)
+        # Generate notification from the other connection.
+        db.query('notify event_1')
+        time.sleep(1)
+        # Check that callback has been invoked.
+        self.assertEquals(cb1_return['event'], 'event_1')
+
+    def test_notify_timeout(self):
+        db = opendb()
+        db2 = opendb()
+        global cb1_return
+        # Listen for 'event_1'
+        pgn = pgnotify(db2, 'event_1', cb1, {}, 1)
+        thread.start_new_thread(pgn, ())
+        # Sleep long enough to time out.
+        time.sleep(2)
+        # Verify that we've indeed timed out.
+        self.assertEquals(cb1_return, 'timed out')
 
 if __name__ == '__main__':
-    unittest.main()
+    suite = unittest.TestSuite()
+    
+    if len(sys.argv) > 1: test_list = sys.argv[1:]
+    else: test_list = unittest.getTestCaseNames(UtilityTest, 'test_')
+
+    if len(sys.argv) == 2 and sys.argv[1] == '-l':
+        print '\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_'))
+        sys.exit(1)
+
+    for test_name in test_list:
+        try:
+            suite.addTest(UtilityTest(test_name))
+        except:
+            print "\n ERROR: %s.\n" % sys.exc_value
+            sys.exit(1)
+
+    rc = unittest.TextTestRunner(verbosity=1).run(suite)
+    sys.exit(len(rc.errors+rc.failures) != 0)
+
_______________________________________________
PyGreSQL mailing list
PyGreSQL@Vex.Net
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to