Author: cito
Date: Tue Nov 6 14:37:49 2012
New Revision: 464
Log:
Implement a context manager in pg, similar to the one in pgdb.
Modified:
trunk/module/TEST_PyGreSQL_classic.py
trunk/module/pg.py
trunk/module/test_pg.py
Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py Tue Nov 6 10:31:03 2012 (r463)
+++ trunk/module/TEST_PyGreSQL_classic.py Tue Nov 6 14:37:49 2012 (r464)
@@ -1,5 +1,7 @@
#!/usr/bin/env python
+from __future__ import with_statement
+
import unittest
from pg import *
@@ -100,6 +102,31 @@
db.insert('_test_schema', _test=1235)
self.assertEqual(d['dvar'], 999)
+ def test_context_manager(self):
+ t = '_test_schema'
+ d = dict(_test=1235)
+ with db:
+ db.insert(t, d)
+ d['_test'] += 1
+ db.insert(t, d)
+ try:
+ with db:
+ d['_test'] += 1
+ db.insert(t, d)
+ db.insert(t, d)
+ except ProgrammingError:
+ pass
+ with db:
+ d['_test'] += 1
+ db.insert(t, d)
+ d['_test'] += 1
+ db.insert(t, d)
+ self.assertTrue(db.get(t, 1235))
+ self.assertTrue(db.get(t, 1236))
+ self.assertRaises(DatabaseError, db.get, t, 1237)
+ self.assertTrue(db.get(t, 1238))
+ self.assertTrue(db.get(t, 1239))
+
def test_sqlstate(self):
db.query("INSERT INTO _test_schema VALUES (1234)")
try:
Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py Tue Nov 6 10:31:03 2012 (r463)
+++ trunk/module/pg.py Tue Nov 6 14:37:49 2012 (r464)
@@ -165,7 +165,6 @@
self._attnames = {}
self._pkeys = {}
self._privileges = {}
- self._transaction = False
self._args = args, kw
self.debug = None # For debugging scripts, this can be set
# * to a string format specification (e.g. in CGI set to "%s<BR>"),
@@ -183,19 +182,16 @@
# Context manager methods
def __enter__(self):
- if self._transaction:
- self.begin()
+ """Enter the runtime context. This will start a transaction."""
+ self.begin()
return self
- def __exit__(self, typ, val, tb):
- if self._transaction:
- self._transaction = False
- if tb is None:
- self.commit()
- else:
- self.rollback()
+ def __exit__(self, et, ev, tb):
+ """Exit the runtime context. This will end the transaction."""
+ if et is None and ev is None and tb is None:
+ self.commit()
else:
- self.close()
+ self.rollback()
# Auxiliary methods
@@ -373,7 +369,7 @@
qstr = 'ROLLBACK'
if name:
qstr += ' TO ' + name
- return self.query('ROLLBACK')
+ return self.query(qstr)
def savepoint(self, name=None):
"""Define a new savepoint within the current transaction."""
@@ -386,12 +382,6 @@
"""Destroy a previously defined savepoint."""
return self.query('RELEASE ' + name)
- @property
- def transaction(self):
- """Return a context manager for running a transaction."""
- self._transaction = True
- return self
-
def query(self, qstr, *args):
"""Executes a SQL command string.
Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py Tue Nov 6 10:31:03 2012 (r463)
+++ trunk/module/test_pg.py Tue Nov 6 14:37:49 2012 (r464)
@@ -646,12 +646,13 @@
self.assertEqual(r, '5')
def testPrint(self):
+ import os
+ import sys
q = "select 1 as a, 'hello' as h, 'w' as world" \
" union select 2, 'xyz', 'uvw'"
r = self.c.query(q)
t = '~test_pg_testPrint_temp.tmp'
s = open(t, 'w')
- import os, sys
stdout, sys.stdout = sys.stdout, s
try:
print r
@@ -685,7 +686,7 @@
self.assertTrue(self.c.getnotify() is None)
try:
self.c.query("notify test_notify, 'test_payload'")
- except pg.ProgrammingError: # PostgreSQL < 9.0
+ except pg.ProgrammingError: # PostgreSQL < 9.0
pass
else:
r = self.c.getnotify()
@@ -724,10 +725,10 @@
query = self.c.query
self.assertEqual(query("select 1+1").getresult(), [(2,)])
self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
- self.assertEqual(query("select 1+$1", [1,]).getresult(), [(2,)])
+ self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
- self.assertEqual(query("select $1::text", (2,) ).getresult(), [('2',)])
- self.assertEqual(query("select 1+$1::numeric", [1,]).getresult(),
+ self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
+ self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
[(Decimal('2'),)])
self.assertEqual(query("select 1, $1::integer", (2,)
).getresult(), [(1, 2)])
@@ -852,7 +853,7 @@
def testInserttableMultipleCalls(self):
num_rows = 10
data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")]
- for i in range(num_rows):
+ for _i in range(num_rows):
self.c.inserttable("test", data)
r = self.c.query("select count(*) from test").getresult()[0][0]
self.assertEqual(r, num_rows)
@@ -864,8 +865,8 @@
self.assertEqual(r, data)
def testInserttableMaxValues(self):
- data = [(2**15 - 1, int(2**31 - 1), long(2**31 - 1),
- None, 1.0 + 1.0/32, 1.0 + 1.0/32, None,
+ data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
+ None, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
"1234", "1234", "1234" * 10)]
self.c.inserttable("test", data)
r = self.c.query("select * from test").getresult()
@@ -915,12 +916,14 @@
$$ language plpgsql''')
try:
received = {}
+
def notice_receiver(notice):
for attr in dir(notice):
value = getattr(notice, attr)
if isinstance(value, str):
value = value.replace('WARNUNG', 'WARNING')
received[attr] = value
+
c.set_notice_receiver(notice_receiver)
c.query('''select bilbo_notice()''')
self.assertEqual(received, dict(
@@ -945,14 +948,15 @@
self.db.close()
def testAllDBAttributes(self):
- attributes = '''cancel clear close db dbname debug delete endcopy
- error escape_bytea escape_identifier escape_literal escape_string
- fileno get get_attnames get_databases get_notice_receiver
- get_relations get_tables getline getlo getnotify
- has_table_privilege host insert inserttable locreate loimport
- options parameter pkey port protocol_version putline query
- reopen reset server_version set_notice_receiver source status
- transaction tty unescape_bytea update use_regtypes user'''.split()
+ attributes = '''begin cancel clear close commit db dbname debug delete
+ end endcopy error escape_bytea escape_identifier escape_literal
+ escape_string fileno get get_attnames get_databases
+ get_notice_receiver get_relations get_tables getline getlo
+ getnotify has_table_privilege host insert inserttable locreate
+ loimport options parameter pkey port protocol_version putline query
+ release reopen reset rollback savepoint server_version
+ set_notice_receiver source start status transaction tty
+ unescape_bytea update use_regtypes user'''.split()
db_attributes = [a for a in dir(self.db)
if not a.startswith('_')]
self.assertEqual(attributes, db_attributes)
@@ -1145,8 +1149,10 @@
self.assertEqual(self.db.db, db.db)
db = pg.DB(db=self.db.db)
self.assertEqual(self.db.db, db.db)
+
class DB2:
pass
+
db2 = DB2()
db2._cnx = self.db.db
db = pg.DB(db2)
@@ -1337,7 +1343,7 @@
def testQueryWithParams(self):
smart_ddl(self.db, "drop table test_table")
q = "create table test_table (n1 integer, n2 integer) with oids"
- r = self.db.query(q)
+ self.db.query(q)
q = "insert into test_table values ($1, $2)"
r = self.db.query(q, (1, 2))
self.assertTrue(isinstance(r, int))
@@ -1486,7 +1492,7 @@
"n integer, t text)" % table)
for n, t in enumerate('xyz'):
self.db.query('insert into "%s" values('
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
r = self.db.get(table, 2, 'n')
oid_table = table
@@ -1526,7 +1532,7 @@
"n integer, t text, primary key (n))" % table)
for n, t in enumerate('abc'):
self.db.query("insert into %s values("
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertEqual(self.db.get(table, 2)['t'], 'b')
table = 'get_test_table_2'
smart_ddl(self.db, "drop table %s" % table)
@@ -1534,9 +1540,9 @@
"n integer, m integer, t text, primary key (n, m))" % table)
for n in range(3):
for m in range(2):
- t = chr(ord('a') + 2*n + m)
+ t = chr(ord('a') + 2 * n + m)
self.db.query("insert into %s values("
- "%d, %d, '%s')" % (table, n+1, m+1, t))
+ "%d, %d, '%s')" % (table, n + 1, m + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
self.assertEqual(self.db.get(table, dict(n=2, m=2))['t'], 'd')
self.assertEqual(self.db.get(table, dict(n=1, m=2),
@@ -1560,9 +1566,10 @@
"d numeric, f4 real, f8 double precision, m money, "
"v4 varchar(4), c4 char(4), t text,"
"b boolean, ts timestamp)" % table)
- data = dict(i2=2**15 - 1, i4=int(2**31 - 1), i8=long(2**31 - 1),
+ data = dict(
+ i2=2 ** 15 - 1, i4=int(2 ** 31 - 1), i8=long(2 ** 31 - 1),
d=Decimal('123456789.9876543212345678987654321'),
- f4=1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
+ f4=1.0 + 1.0 / 32, f8=1.0 + 1.0 / 32,
m="1234.56", v4="1234", c4="1234", t="1234" * 10,
b=1, ts='2012-12-21')
r = self.db.insert(table, data)
@@ -1585,7 +1592,7 @@
"n integer, t text)" % table)
for n, t in enumerate('xyz'):
self.db.query('insert into "%s" values('
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
r = self.db.get(table, 2, 'n')
r['t'] = 'u'
@@ -1602,7 +1609,7 @@
"n integer, t text, primary key (n))" % table)
for n, t in enumerate('abc'):
self.db.query("insert into %s values("
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.update,
table, dict(t='b'))
self.assertEqual(self.db.update(table, dict(n=2, t='d'))['t'], 'd')
@@ -1615,9 +1622,9 @@
"n integer, m integer, t text, primary key (n, m))" % table)
for n in range(3):
for m in range(2):
- t = chr(ord('a') + 2*n +m)
+ t = chr(ord('a') + 2 * n + m)
self.db.query("insert into %s values("
- "%d, %d, '%s')" % (table, n+1, m+1, t))
+ "%d, %d, '%s')" % (table, n + 1, m + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.update,
table, dict(n=2, t='b'))
self.assertEqual(self.db.update(table,
@@ -1649,7 +1656,7 @@
"n integer, t text)" % table)
for n, t in enumerate('xyz'):
self.db.query('insert into "%s" values('
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
r = self.db.get(table, 1, 'n')
s = self.db.delete(table, r)
@@ -1678,7 +1685,7 @@
"n integer, t text, primary key (n))" % table)
for n, t in enumerate('abc'):
self.db.query("insert into %s values("
- "%d, '%s')" % (table, n+1, t))
+ "%d, '%s')" % (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.delete,
table, dict(t='b'))
self.assertEqual(self.db.delete(table, dict(n=2)), 1)
@@ -1695,9 +1702,9 @@
"n integer, m integer, t text, primary key (n, m))" % table)
for n in range(3):
for m in range(2):
- t = chr(ord('a') + 2*n +m)
+ t = chr(ord('a') + 2 * n + m)
self.db.query("insert into %s values("
- "%d, %d, '%s')" % (table, n+1, m+1, t))
+ "%d, %d, '%s')" % (table, n + 1, m + 1, t))
self.assertRaises(pg.ProgrammingError, self.db.delete,
table, dict(n=2, t='b'))
self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
@@ -1713,6 +1720,64 @@
' order by m' % table).getresult()]
self.assertEqual(r, ['f'])
+ def testTransaction(self):
+ smart_ddl(self.db, "drop table test_table")
+ self.db.query("create table test_table (n integer)")
+ self.db.begin()
+ self.db.query("insert into test_table values (1)")
+ self.db.query("insert into test_table values (2)")
+ self.db.commit()
+ self.db.begin()
+ self.db.query("insert into test_table values (3)")
+ self.db.query("insert into test_table values (4)")
+ self.db.rollback()
+ self.db.begin()
+ self.db.query("insert into test_table values (5)")
+ self.db.savepoint('before6')
+ self.db.query("insert into test_table values (6)")
+ self.db.rollback('before6')
+ self.db.query("insert into test_table values (7)")
+ self.db.commit()
+ self.db.begin()
+ self.db.savepoint('before8')
+ self.db.query("insert into test_table values (8)")
+ self.db.release('before8')
+ self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
+ self.db.commit()
+ self.db.start()
+ self.db.query("insert into test_table values (9)")
+ self.db.end()
+ r = [r[0] for r in self.db.query(
+ "select * from test_table order by 1").getresult()]
+ self.assertEqual(r, [1, 2, 5, 7, 9])
+
+ def testContextManager(self):
+ smart_ddl(self.db, "drop table test_table")
+ self.db.query("create table test_table (n integer check(n>0))")
+ with self.db:
+ self.db.query("insert into test_table values (1)")
+ self.db.query("insert into test_table values (2)")
+ try:
+ with self.db:
+ self.db.query("insert into test_table values (3)")
+ self.db.query("insert into test_table values (4)")
+ raise ValueError('test transaction should rollback')
+ except ValueError, error:
+ self.assertEqual(str(error), 'test transaction should rollback')
+ with self.db:
+ self.db.query("insert into test_table values (5)")
+ try:
+ with self.db:
+ self.db.query("insert into test_table values (6)")
+ self.db.query("insert into test_table values (-1)")
+ except pg.ProgrammingError, error:
+ self.assertTrue('check' in str(error))
+ with self.db:
+ self.db.query("insert into test_table values (7)")
+ r = [r[0] for r in self.db.query(
+ "select * from test_table order by 1").getresult()]
+ self.assertEqual(r, [1, 2, 5, 7])
+
def testBytea(self):
smart_ddl(self.db, 'drop table bytea_test')
smart_ddl(self.db, 'create table bytea_test ('
@@ -1783,13 +1848,13 @@
self.db.query("set search_path to s2,s4")
self.assertRaises(PrgError, self.db.get, "t1", 1, 'n')
self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4)
- self.assertRaises(pg.ProgrammingError, self.db.get, "t3", 1, 'n')
+ self.assertRaises(PrgError, self.db.get, "t3", 1, 'n')
self.assertEqual(self.db.get("t", 1, 'n')['d'], 2)
self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3)
self.db.query("set search_path to s1,s3")
self.assertRaises(PrgError, self.db.get, "t2", 1, 'n')
self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3)
- self.assertRaises(pg.ProgrammingError, self.db.get, "t4", 1, 'n')
+ self.assertRaises(PrgError, self.db.get, "t4", 1, 'n')
self.assertEqual(self.db.get("t", 1, 'n')['d'], 1)
self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4)
_______________________________________________
PyGreSQL mailing list
PyGreSQL@Vex.Net
https://mail.vex.net/mailman/listinfo.cgi/pygresql