Author: cito Date: Tue Nov 6 14:37:49 2012 New Revision: 464 Log: Implement context manager in pg similar to pgdb.
Modified: trunk/module/TEST_PyGreSQL_classic.py trunk/module/pg.py trunk/module/test_pg.py Modified: trunk/module/TEST_PyGreSQL_classic.py ============================================================================== --- trunk/module/TEST_PyGreSQL_classic.py Tue Nov 6 10:31:03 2012 (r463) +++ trunk/module/TEST_PyGreSQL_classic.py Tue Nov 6 14:37:49 2012 (r464) @@ -1,5 +1,7 @@ #!/usr/bin/env python +from __future__ import with_statement + import unittest from pg import * @@ -100,6 +102,31 @@ db.insert('_test_schema', _test=1235) self.assertEqual(d['dvar'], 999) + def test_context_manager(self): + t = '_test_schema' + d = dict(_test=1235) + with db: + db.insert(t, d) + d['_test'] += 1 + db.insert(t, d) + try: + with db: + d['_test'] += 1 + db.insert(t, d) + db.insert(t, d) + except ProgrammingError: + pass + with db: + d['_test'] += 1 + db.insert(t, d) + d['_test'] += 1 + db.insert(t, d) + self.assertTrue(db.get(t, 1235)) + self.assertTrue(db.get(t, 1236)) + self.assertRaises(DatabaseError, db.get, t, 1237) + self.assertTrue(db.get(t, 1238)) + self.assertTrue(db.get(t, 1239)) + def test_sqlstate(self): db.query("INSERT INTO _test_schema VALUES (1234)") try: Modified: trunk/module/pg.py ============================================================================== --- trunk/module/pg.py Tue Nov 6 10:31:03 2012 (r463) +++ trunk/module/pg.py Tue Nov 6 14:37:49 2012 (r464) @@ -165,7 +165,6 @@ self._attnames = {} self._pkeys = {} self._privileges = {} - self._transaction = False self._args = args, kw self.debug = None # For debugging scripts, this can be set # * to a string format specification (e.g. in CGI set to "%s<BR>"), @@ -183,19 +182,16 @@ # Context manager methods def __enter__(self): - if self._transaction: - self.begin() + """Enter the runtime context. This will start a transaction.""" + self.begin() return self - def __exit__(self, typ, val, tb): - if self._transaction: - self._transaction = False - if tb is None: - self.commit() - else: - self.rollback() + def __exit__(self, et, ev, tb): + """Exit the runtime context. This will end the transaction.""" + if et is None and ev is None and tb is None: + self.commit() else: - self.close() + self.rollback() # Auxiliary methods @@ -373,7 +369,7 @@ qstr = 'ROLLBACK' if name: qstr += ' TO ' + name - return self.query('ROLLBACK') + return self.query(qstr) def savepoint(self, name=None): """Define a new savepoint within the current transaction.""" @@ -386,12 +382,6 @@ """Destroy a previously defined savepoint.""" return self.query('RELEASE ' + name) - @property - def transaction(self): - """Return a context manager for running a transaction.""" - self._transaction = True - return self - def query(self, qstr, *args): """Executes a SQL command string. Modified: trunk/module/test_pg.py ============================================================================== --- trunk/module/test_pg.py Tue Nov 6 10:31:03 2012 (r463) +++ trunk/module/test_pg.py Tue Nov 6 14:37:49 2012 (r464) @@ -646,12 +646,13 @@ self.assertEqual(r, '5') def testPrint(self): + import os + import sys q = "select 1 as a, 'hello' as h, 'w' as world" \ " union select 2, 'xyz', 'uvw'" r = self.c.query(q) t = '~test_pg_testPrint_temp.tmp' s = open(t, 'w') - import os, sys stdout, sys.stdout = sys.stdout, s try: print r @@ -685,7 +686,7 @@ self.assertTrue(self.c.getnotify() is None) try: self.c.query("notify test_notify, 'test_payload'") - except pg.ProgrammingError: # PostgreSQL < 9.0 + except pg.ProgrammingError: # PostgreSQL < 9.0 pass else: r = self.c.getnotify() @@ -724,10 +725,10 @@ query = self.c.query self.assertEqual(query("select 1+1").getresult(), [(2,)]) self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)]) - self.assertEqual(query("select 1+$1", [1,]).getresult(), [(2,)]) + self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)]) self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)]) - self.assertEqual(query("select $1::text", (2,) ).getresult(), [('2',)]) - self.assertEqual(query("select 1+$1::numeric", [1,]).getresult(), + self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)]) + self.assertEqual(query("select 1+$1::numeric", [1]).getresult(), [(Decimal('2'),)]) self.assertEqual(query("select 1, $1::integer", (2,) ).getresult(), [(1, 2)]) @@ -852,7 +853,7 @@ def testInserttableMultipleCalls(self): num_rows = 10 data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")] - for i in range(num_rows): + for _i in range(num_rows): self.c.inserttable("test", data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) @@ -864,8 +865,8 @@ self.assertEqual(r, data) def testInserttableMaxValues(self): - data = [(2**15 - 1, int(2**31 - 1), long(2**31 - 1), - None, 1.0 + 1.0/32, 1.0 + 1.0/32, None, + data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), + None, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, "1234", "1234", "1234" * 10)] self.c.inserttable("test", data) r = self.c.query("select * from test").getresult() @@ -915,12 +916,14 @@ $$ language plpgsql''') try: received = {} + def notice_receiver(notice): for attr in dir(notice): value = getattr(notice, attr) if isinstance(value, str): value = value.replace('WARNUNG', 'WARNING') received[attr] = value + c.set_notice_receiver(notice_receiver) c.query('''select bilbo_notice()''') self.assertEqual(received, dict( @@ -945,14 +948,15 @@ self.db.close() def testAllDBAttributes(self): - attributes = '''cancel clear close db dbname debug delete endcopy - error escape_bytea escape_identifier escape_literal escape_string - fileno get get_attnames get_databases get_notice_receiver - get_relations get_tables getline getlo getnotify - has_table_privilege host insert inserttable locreate loimport - options parameter pkey port protocol_version putline query - reopen reset server_version set_notice_receiver source status - transaction tty unescape_bytea update use_regtypes user'''.split() + attributes = '''begin cancel clear close commit db dbname debug delete + end endcopy error escape_bytea escape_identifier escape_literal + escape_string fileno get get_attnames get_databases + get_notice_receiver get_relations get_tables getline getlo + getnotify has_table_privilege host insert inserttable locreate + loimport options parameter pkey port protocol_version putline query + release reopen reset rollback savepoint server_version + set_notice_receiver source start status transaction tty + unescape_bytea update use_regtypes user'''.split() db_attributes = [a for a in dir(self.db) if not a.startswith('_')] self.assertEqual(attributes, db_attributes) @@ -1145,8 +1149,10 @@ self.assertEqual(self.db.db, db.db) db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) + class DB2: pass + db2 = DB2() db2._cnx = self.db.db db = pg.DB(db2) @@ -1337,7 +1343,7 @@ def testQueryWithParams(self): smart_ddl(self.db, "drop table test_table") q = "create table test_table (n1 integer, n2 integer) with oids" - r = self.db.query(q) + self.db.query(q) q = "insert into test_table values ($1, $2)" r = self.db.query(q, (1, 2)) self.assertTrue(isinstance(r, int)) @@ -1486,7 +1492,7 @@ "n integer, t text)" % table) for n, t in enumerate('xyz'): self.db.query('insert into "%s" values(' - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.get, table, 2) r = self.db.get(table, 2, 'n') oid_table = table @@ -1526,7 +1532,7 @@ "n integer, t text, primary key (n))" % table) for n, t in enumerate('abc'): self.db.query("insert into %s values(" - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertEqual(self.db.get(table, 2)['t'], 'b') table = 'get_test_table_2' smart_ddl(self.db, "drop table %s" % table) @@ -1534,9 +1540,9 @@ "n integer, m integer, t text, primary key (n, m))" % table) for n in range(3): for m in range(2): - t = chr(ord('a') + 2*n + m) + t = chr(ord('a') + 2 * n + m) self.db.query("insert into %s values(" - "%d, %d, '%s')" % (table, n+1, m+1, t)) + "%d, %d, '%s')" % (table, n + 1, m + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.get, table, 2) self.assertEqual(self.db.get(table, dict(n=2, m=2))['t'], 'd') self.assertEqual(self.db.get(table, dict(n=1, m=2), @@ -1560,9 +1566,10 @@ "d numeric, f4 real, f8 double precision, m money, " "v4 varchar(4), c4 char(4), t text," "b boolean, ts timestamp)" % table) - data = dict(i2=2**15 - 1, i4=int(2**31 - 1), i8=long(2**31 - 1), + data = dict( + i2=2 ** 15 - 1, i4=int(2 ** 31 - 1), i8=long(2 ** 31 - 1), d=Decimal('123456789.9876543212345678987654321'), - f4=1.0 + 1.0/32, f8 = 1.0 + 1.0/32, + f4=1.0 + 1.0 / 32, f8=1.0 + 1.0 / 32, m="1234.56", v4="1234", c4="1234", t="1234" * 10, b=1, ts='2012-12-21') r = self.db.insert(table, data) @@ -1585,7 +1592,7 @@ "n integer, t text)" % table) for n, t in enumerate('xyz'): self.db.query('insert into "%s" values(' - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.get, table, 2) r = self.db.get(table, 2, 'n') r['t'] = 'u' @@ -1602,7 +1609,7 @@ "n integer, t text, primary key (n))" % table) for n, t in enumerate('abc'): self.db.query("insert into %s values(" - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.update, table, dict(t='b')) self.assertEqual(self.db.update(table, dict(n=2, t='d'))['t'], 'd') @@ -1615,9 +1622,9 @@ "n integer, m integer, t text, primary key (n, m))" % table) for n in range(3): for m in range(2): - t = chr(ord('a') + 2*n +m) + t = chr(ord('a') + 2 * n + m) self.db.query("insert into %s values(" - "%d, %d, '%s')" % (table, n+1, m+1, t)) + "%d, %d, '%s')" % (table, n + 1, m + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.update, table, dict(n=2, t='b')) self.assertEqual(self.db.update(table, @@ -1649,7 +1656,7 @@ "n integer, t text)" % table) for n, t in enumerate('xyz'): self.db.query('insert into "%s" values(' - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.get, table, 2) r = self.db.get(table, 1, 'n') s = self.db.delete(table, r) @@ -1678,7 +1685,7 @@ "n integer, t text, primary key (n))" % table) for n, t in enumerate('abc'): self.db.query("insert into %s values(" - "%d, '%s')" % (table, n+1, t)) + "%d, '%s')" % (table, n + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) @@ -1695,9 +1702,9 @@ "n integer, m integer, t text, primary key (n, m))" % table) for n in range(3): for m in range(2): - t = chr(ord('a') + 2*n +m) + t = chr(ord('a') + 2 * n + m) self.db.query("insert into %s values(" - "%d, %d, '%s')" % (table, n+1, m+1, t)) + "%d, %d, '%s')" % (table, n + 1, m + 1, t)) self.assertRaises(pg.ProgrammingError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1) @@ -1713,6 +1720,64 @@ ' order by m' % table).getresult()] self.assertEqual(r, ['f']) + def testTransaction(self): + smart_ddl(self.db, "drop table test_table") + self.db.query("create table test_table (n integer)") + self.db.begin() + self.db.query("insert into test_table values (1)") + self.db.query("insert into test_table values (2)") + self.db.commit() + self.db.begin() + self.db.query("insert into test_table values (3)") + self.db.query("insert into test_table values (4)") + self.db.rollback() + self.db.begin() + self.db.query("insert into test_table values (5)") + self.db.savepoint('before6') + self.db.query("insert into test_table values (6)") + self.db.rollback('before6') + self.db.query("insert into test_table values (7)") + self.db.commit() + self.db.begin() + self.db.savepoint('before8') + self.db.query("insert into test_table values (8)") + self.db.release('before8') + self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8') + self.db.commit() + self.db.start() + self.db.query("insert into test_table values (9)") + self.db.end() + r = [r[0] for r in self.db.query( + "select * from test_table order by 1").getresult()] + self.assertEqual(r, [1, 2, 5, 7, 9]) + + def testContextManager(self): + smart_ddl(self.db, "drop table test_table") + self.db.query("create table test_table (n integer check(n>0))") + with self.db: + self.db.query("insert into test_table values (1)") + self.db.query("insert into test_table values (2)") + try: + with self.db: + self.db.query("insert into test_table values (3)") + self.db.query("insert into test_table values (4)") + raise ValueError('test transaction should rollback') + except ValueError, error: + self.assertEqual(str(error), 'test transaction should rollback') + with self.db: + self.db.query("insert into test_table values (5)") + try: + with self.db: + self.db.query("insert into test_table values (6)") + self.db.query("insert into test_table values (-1)") + except pg.ProgrammingError, error: + self.assertTrue('check' in str(error)) + with self.db: + self.db.query("insert into test_table values (7)") + r = [r[0] for r in self.db.query( + "select * from test_table order by 1").getresult()] + self.assertEqual(r, [1, 2, 5, 7]) + def testBytea(self): smart_ddl(self.db, 'drop table bytea_test') smart_ddl(self.db, 'create table bytea_test (' @@ -1783,13 +1848,13 @@ self.db.query("set search_path to s2,s4") self.assertRaises(PrgError, self.db.get, "t1", 1, 'n') self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4) - self.assertRaises(pg.ProgrammingError, self.db.get, "t3", 1, 'n') + self.assertRaises(PrgError, self.db.get, "t3", 1, 'n') self.assertEqual(self.db.get("t", 1, 'n')['d'], 2) self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3) self.db.query("set search_path to s1,s3") self.assertRaises(PrgError, self.db.get, "t2", 1, 'n') self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3) - self.assertRaises(pg.ProgrammingError, self.db.get, "t4", 1, 'n') + self.assertRaises(PrgError, self.db.get, "t4", 1, 'n') self.assertEqual(self.db.get("t", 1, 'n')['d'], 1) self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4) _______________________________________________ PyGreSQL mailing list PyGreSQL@Vex.Net https://mail.vex.net/mailman/listinfo.cgi/pygresql