Author: cito
Date: Tue Nov  6 14:37:49 2012
New Revision: 464

Log:
Implement context manager in pg similar to pgdb.

Modified:
   trunk/module/TEST_PyGreSQL_classic.py
   trunk/module/pg.py
   trunk/module/test_pg.py

Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py       Tue Nov  6 10:31:03 2012        (r463)
+++ trunk/module/TEST_PyGreSQL_classic.py       Tue Nov  6 14:37:49 2012        (r464)
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 
+from __future__ import with_statement
+
 import unittest
 from pg import *
 
@@ -100,6 +102,31 @@
         db.insert('_test_schema', _test=1235)
         self.assertEqual(d['dvar'], 999)
 
+    def test_context_manager(self):
+        t = '_test_schema'
+        d = dict(_test=1235)
+        with db:
+            db.insert(t, d)
+            d['_test'] += 1
+            db.insert(t, d)
+        try:
+            with db:
+                d['_test'] += 1
+                db.insert(t, d)
+                db.insert(t, d)
+        except ProgrammingError:
+            pass
+        with db:
+            d['_test'] += 1
+            db.insert(t, d)
+            d['_test'] += 1
+            db.insert(t, d)
+        self.assertTrue(db.get(t, 1235))
+        self.assertTrue(db.get(t, 1236))
+        self.assertRaises(DatabaseError, db.get, t, 1237)
+        self.assertTrue(db.get(t, 1238))
+        self.assertTrue(db.get(t, 1239))
+
     def test_sqlstate(self):
         db.query("INSERT INTO _test_schema VALUES (1234)")
         try:

Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py  Tue Nov  6 10:31:03 2012        (r463)
+++ trunk/module/pg.py  Tue Nov  6 14:37:49 2012        (r464)
@@ -165,7 +165,6 @@
         self._attnames = {}
         self._pkeys = {}
         self._privileges = {}
-        self._transaction = False
         self._args = args, kw
         self.debug = None  # For debugging scripts, this can be set
             # * to a string format specification (e.g. in CGI set to "%s<BR>"),
@@ -183,19 +182,16 @@
     # Context manager methods
 
     def __enter__(self):
-        if self._transaction:
-            self.begin()
+        """Enter the runtime context. This will start a transaction."""
+        self.begin()
         return self
 
-    def __exit__(self, typ, val, tb):
-        if self._transaction:
-            self._transaction = False
-            if tb is None:
-                self.commit()
-            else:
-                self.rollback()
+    def __exit__(self, et, ev, tb):
+        """Exit the runtime context. This will end the transaction."""
+        if et is None and ev is None and tb is None:
+            self.commit()
         else:
-            self.close()
+            self.rollback()
 
     # Auxiliary methods
 
@@ -373,7 +369,7 @@
         qstr = 'ROLLBACK'
         if name:
             qstr += ' TO ' + name
-        return self.query('ROLLBACK')
+        return self.query(qstr)
 
     def savepoint(self, name=None):
         """Define a new savepoint within the current transaction."""
@@ -386,12 +382,6 @@
         """Destroy a previously defined savepoint."""
         return self.query('RELEASE ' + name)
 
-    @property
-    def transaction(self):
-        """Return a context manager for running a transaction."""
-        self._transaction = True
-        return self
-
     def query(self, qstr, *args):
         """Executes a SQL command string.
 

Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py     Tue Nov  6 10:31:03 2012        (r463)
+++ trunk/module/test_pg.py     Tue Nov  6 14:37:49 2012        (r464)
@@ -646,12 +646,13 @@
         self.assertEqual(r, '5')
 
     def testPrint(self):
+        import os
+        import sys
         q = "select 1 as a, 'hello' as h, 'w' as world" \
             " union select 2, 'xyz', 'uvw'"
         r = self.c.query(q)
         t = '~test_pg_testPrint_temp.tmp'
         s = open(t, 'w')
-        import os, sys
         stdout, sys.stdout = sys.stdout, s
         try:
             print r
@@ -685,7 +686,7 @@
             self.assertTrue(self.c.getnotify() is None)
             try:
                 self.c.query("notify test_notify, 'test_payload'")
-            except pg.ProgrammingError: # PostgreSQL < 9.0
+            except pg.ProgrammingError:  # PostgreSQL < 9.0
                 pass
             else:
                 r = self.c.getnotify()
@@ -724,10 +725,10 @@
         query = self.c.query
         self.assertEqual(query("select 1+1").getresult(), [(2,)])
         self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
-        self.assertEqual(query("select 1+$1", [1,]).getresult(), [(2,)])
+        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
         self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
-        self.assertEqual(query("select $1::text", (2,) ).getresult(), [('2',)])
-        self.assertEqual(query("select 1+$1::numeric", [1,]).getresult(),
+        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
+        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
             [(Decimal('2'),)])
         self.assertEqual(query("select 1, $1::integer", (2,)
             ).getresult(), [(1, 2)])
@@ -852,7 +853,7 @@
     def testInserttableMultipleCalls(self):
         num_rows = 10
         data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")]
-        for i in range(num_rows):
+        for _i in range(num_rows):
             self.c.inserttable("test", data)
         r = self.c.query("select count(*) from test").getresult()[0][0]
         self.assertEqual(r, num_rows)
@@ -864,8 +865,8 @@
         self.assertEqual(r, data)
 
     def testInserttableMaxValues(self):
-        data = [(2**15 - 1, int(2**31 - 1), long(2**31 - 1),
-            None, 1.0 + 1.0/32, 1.0 + 1.0/32, None,
+        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
+            None, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
             "1234", "1234", "1234" * 10)]
         self.c.inserttable("test", data)
         r = self.c.query("select * from test").getresult()
@@ -915,12 +916,14 @@
                 $$ language plpgsql''')
             try:
                 received = {}
+
                 def notice_receiver(notice):
                     for attr in dir(notice):
                         value = getattr(notice, attr)
                         if isinstance(value, str):
                             value = value.replace('WARNUNG', 'WARNING')
                         received[attr] = value
+
                 c.set_notice_receiver(notice_receiver)
                 c.query('''select bilbo_notice()''')
                 self.assertEqual(received, dict(
@@ -945,14 +948,15 @@
         self.db.close()
 
     def testAllDBAttributes(self):
-        attributes = '''cancel clear close db dbname debug delete endcopy
-            error escape_bytea escape_identifier escape_literal escape_string
-            fileno get get_attnames get_databases get_notice_receiver
-            get_relations get_tables getline getlo getnotify
-            has_table_privilege host insert inserttable locreate loimport
-            options parameter pkey port protocol_version putline query
-            reopen reset server_version set_notice_receiver source status
-            transaction tty unescape_bytea update use_regtypes user'''.split()
+        attributes = '''begin cancel clear close commit db dbname debug delete
+            end endcopy error escape_bytea escape_identifier escape_literal
+            escape_string fileno get get_attnames get_databases
+            get_notice_receiver get_relations get_tables getline getlo
+            getnotify has_table_privilege host insert inserttable locreate
+            loimport options parameter pkey port protocol_version putline query
+            release reopen reset rollback savepoint server_version
+            set_notice_receiver source start status transaction tty
+            unescape_bytea update use_regtypes user'''.split()
         db_attributes = [a for a in dir(self.db)
             if not a.startswith('_')]
         self.assertEqual(attributes, db_attributes)
@@ -1145,8 +1149,10 @@
         self.assertEqual(self.db.db, db.db)
         db = pg.DB(db=self.db.db)
         self.assertEqual(self.db.db, db.db)
+
         class DB2:
             pass
+
         db2 = DB2()
         db2._cnx = self.db.db
         db = pg.DB(db2)
@@ -1337,7 +1343,7 @@
     def testQueryWithParams(self):
         smart_ddl(self.db, "drop table test_table")
         q = "create table test_table (n1 integer, n2 integer) with oids"
-        r = self.db.query(q)
+        self.db.query(q)
         q = "insert into test_table values ($1, $2)"
         r = self.db.query(q, (1, 2))
         self.assertTrue(isinstance(r, int))
@@ -1486,7 +1492,7 @@
                 "n integer, t text)" % table)
             for n, t in enumerate('xyz'):
                 self.db.query('insert into "%s" values('
-                    "%d, '%s')" % (table, n+1, t))
+                    "%d, '%s')" % (table, n + 1, t))
             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
             r = self.db.get(table, 2, 'n')
             oid_table = table
@@ -1526,7 +1532,7 @@
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
             self.db.query("insert into %s values("
-                "%d, '%s')" % (table, n+1, t))
+                "%d, '%s')" % (table, n + 1, t))
         self.assertEqual(self.db.get(table, 2)['t'], 'b')
         table = 'get_test_table_2'
         smart_ddl(self.db, "drop table %s" % table)
@@ -1534,9 +1540,9 @@
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
-                t = chr(ord('a') + 2*n + m)
+                t = chr(ord('a') + 2 * n + m)
                 self.db.query("insert into %s values("
-                    "%d, %d, '%s')" % (table, n+1, m+1, t))
+                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
         self.assertEqual(self.db.get(table, dict(n=2, m=2))['t'], 'd')
         self.assertEqual(self.db.get(table, dict(n=1, m=2),
@@ -1560,9 +1566,10 @@
                 "d numeric, f4 real, f8 double precision, m money, "
                 "v4 varchar(4), c4 char(4), t text,"
                 "b boolean, ts timestamp)" % table)
-            data = dict(i2=2**15 - 1, i4=int(2**31 - 1), i8=long(2**31 - 1),
+            data = dict(
+                i2=2 ** 15 - 1, i4=int(2 ** 31 - 1), i8=long(2 ** 31 - 1),
                 d=Decimal('123456789.9876543212345678987654321'),
-                f4=1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
+                f4=1.0 + 1.0 / 32, f8=1.0 + 1.0 / 32,
                 m="1234.56", v4="1234", c4="1234",  t="1234" * 10,
                 b=1, ts='2012-12-21')
             r = self.db.insert(table, data)
@@ -1585,7 +1592,7 @@
                 "n integer, t text)" % table)
             for n, t in enumerate('xyz'):
                 self.db.query('insert into "%s" values('
-                    "%d, '%s')" % (table, n+1, t))
+                    "%d, '%s')" % (table, n + 1, t))
             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
             r = self.db.get(table, 2, 'n')
             r['t'] = 'u'
@@ -1602,7 +1609,7 @@
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
             self.db.query("insert into %s values("
-                "%d, '%s')" % (table, n+1, t))
+                "%d, '%s')" % (table, n + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.update,
             table, dict(t='b'))
         self.assertEqual(self.db.update(table, dict(n=2, t='d'))['t'], 'd')
@@ -1615,9 +1622,9 @@
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
-                t = chr(ord('a') + 2*n +m)
+                t = chr(ord('a') + 2 * n + m)
                 self.db.query("insert into %s values("
-                    "%d, %d, '%s')" % (table, n+1, m+1, t))
+                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.update,
             table, dict(n=2, t='b'))
         self.assertEqual(self.db.update(table,
@@ -1649,7 +1656,7 @@
                 "n integer, t text)" % table)
             for n, t in enumerate('xyz'):
                 self.db.query('insert into "%s" values('
-                    "%d, '%s')" % (table, n+1, t))
+                    "%d, '%s')" % (table, n + 1, t))
             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
             r = self.db.get(table, 1, 'n')
             s = self.db.delete(table, r)
@@ -1678,7 +1685,7 @@
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
             self.db.query("insert into %s values("
-                "%d, '%s')" % (table, n+1, t))
+                "%d, '%s')" % (table, n + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.delete,
             table, dict(t='b'))
         self.assertEqual(self.db.delete(table, dict(n=2)), 1)
@@ -1695,9 +1702,9 @@
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
-                t = chr(ord('a') + 2*n +m)
+                t = chr(ord('a') + 2 * n + m)
                 self.db.query("insert into %s values("
-                    "%d, %d, '%s')" % (table, n+1, m+1, t))
+                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.delete,
             table, dict(n=2, t='b'))
         self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
@@ -1713,6 +1720,64 @@
             ' order by m' % table).getresult()]
         self.assertEqual(r, ['f'])
 
+    def testTransaction(self):
+        smart_ddl(self.db, "drop table test_table")
+        self.db.query("create table test_table (n integer)")
+        self.db.begin()
+        self.db.query("insert into test_table values (1)")
+        self.db.query("insert into test_table values (2)")
+        self.db.commit()
+        self.db.begin()
+        self.db.query("insert into test_table values (3)")
+        self.db.query("insert into test_table values (4)")
+        self.db.rollback()
+        self.db.begin()
+        self.db.query("insert into test_table values (5)")
+        self.db.savepoint('before6')
+        self.db.query("insert into test_table values (6)")
+        self.db.rollback('before6')
+        self.db.query("insert into test_table values (7)")
+        self.db.commit()
+        self.db.begin()
+        self.db.savepoint('before8')
+        self.db.query("insert into test_table values (8)")
+        self.db.release('before8')
+        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
+        self.db.commit()
+        self.db.start()
+        self.db.query("insert into test_table values (9)")
+        self.db.end()
+        r = [r[0] for r in self.db.query(
+            "select * from test_table order by 1").getresult()]
+        self.assertEqual(r, [1, 2, 5, 7, 9])
+
+    def testContextManager(self):
+        smart_ddl(self.db, "drop table test_table")
+        self.db.query("create table test_table (n integer check(n>0))")
+        with self.db:
+            self.db.query("insert into test_table values (1)")
+            self.db.query("insert into test_table values (2)")
+        try:
+            with self.db:
+                self.db.query("insert into test_table values (3)")
+                self.db.query("insert into test_table values (4)")
+                raise ValueError('test transaction should rollback')
+        except ValueError, error:
+            self.assertEqual(str(error), 'test transaction should rollback')
+        with self.db:
+            self.db.query("insert into test_table values (5)")
+        try:
+            with self.db:
+                self.db.query("insert into test_table values (6)")
+                self.db.query("insert into test_table values (-1)")
+        except pg.ProgrammingError, error:
+            self.assertTrue('check' in str(error))
+        with self.db:
+            self.db.query("insert into test_table values (7)")
+        r = [r[0] for r in self.db.query(
+            "select * from test_table order by 1").getresult()]
+        self.assertEqual(r, [1, 2, 5, 7])
+
     def testBytea(self):
         smart_ddl(self.db, 'drop table bytea_test')
         smart_ddl(self.db, 'create table bytea_test ('
@@ -1783,13 +1848,13 @@
         self.db.query("set search_path to s2,s4")
         self.assertRaises(PrgError, self.db.get, "t1", 1, 'n')
         self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4)
-        self.assertRaises(pg.ProgrammingError, self.db.get, "t3", 1, 'n')
+        self.assertRaises(PrgError, self.db.get, "t3", 1, 'n')
         self.assertEqual(self.db.get("t", 1, 'n')['d'], 2)
         self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3)
         self.db.query("set search_path to s1,s3")
         self.assertRaises(PrgError, self.db.get, "t2", 1, 'n')
         self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3)
-        self.assertRaises(pg.ProgrammingError, self.db.get, "t4", 1, 'n')
+        self.assertRaises(PrgError, self.db.get, "t4", 1, 'n')
         self.assertEqual(self.db.get("t", 1, 'n')['d'], 1)
         self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4)
 
_______________________________________________
PyGreSQL mailing list
PyGreSQL@Vex.Net
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to