"""
DBPool.py

Implements a pool of cached connections to a database. This should result in
a speedup for persistent apps. The pool of connections is threadsafe
regardless of whether the DB API module in question has a
threadsafety of 1 or 2.

For more information on the DB API (version 2.0), see PEP 249:
https://peps.python.org/pep-0249/

The idea behind DBPool is that it's completely seamless, so once you have
established your connection, use it just as you would any other DB-API
compliant module. For example:

dbPool = DBPool(MySQLdb, 5, host=xxx, user=xxx, ...)
db = dbPool.getConnection()

Now use "db" exactly as if it were a MySQLdb connection. It's really
just a proxy class.

db.close() will return the connection to the pool, not actually
close it. This is so your existing code works nicely.


FUTURE

* If in the presence of WebKit, register ourselves as a Can.


CREDIT

* Contributed by Dan Green
* thread safety bug found by Tom Schwaller
* Fixes by Geoff Talvola (thread safety in _threadsafe_getConnection()).
* Clean up by Chuck Esterbrook.
* Fix unthreadsafe functions which were leaking, Jay Love
* Eli Green's webware-discuss comments were lifted for additional docs.
"""


import threading, time
from Queue import Queue, Empty, Full

class Stack(Queue):
    """Queue variant that hands items out last-in, first-out.

    Only the internal ``_get`` hook is overridden, so Queue still does all of
    the locking and size bookkeeping; ``get()`` simply yields the most recently
    ``put()`` item instead of the oldest one.
    """

    def _get(self):
        # Remove and return the newest item (the end of the underlying sequence).
        return self.queue.pop()


class DBPoolError(Exception):
    """Base class for all errors raised by this module."""


class UnsupportedError(DBPoolError):
    """Raised when the DB-API module's declared thread safety cannot be pooled."""


class PooledConnection:
    """A wrapper for database connections to help with DBPool.

    You don't normally deal with this class directly, but use DBPool to get
    new connections. Attribute access falls through to the wrapped DB-API
    connection, so the wrapper can be used like the raw connection. close()
    hands the wrapper back to its pool instead of closing the raw connection;
    only prune() actually closes it.
    """

    def __init__(self, pool, con):
        self._con = con                # raw DB-API connection; None once closed
        self._pool = pool              # owner, notified via returnConnection()
        self._timestamp = time.time()  # creation time, refreshed by close()
        self._out = 0                  # number of callers currently holding it

    def out(self):
        '''Return how many times this connection is currently given out.'''
        return self._out

    def ref(self):
        '''Record that the connection was given out one more time.'''
        self._out = self._out + 1

    def deref(self):
        '''Record that one holder returned the connection (never below 0).'''
        self._out = self._out - 1
        if self._out < 0:
            self._out = 0

    def close(self):
        '''Return the connection to its pool; the raw connection stays open.

        Only a connection that is actually given out (out() > 0) is returned,
        so calling close() twice cannot put the same connection back into the
        pool twice (which would hand it to two callers at once).
        '''
        self._timestamp = time.time()
        if self._con is not None and self._out > 0:
            self._pool.returnConnection(self)

    def prune(self, timeout):
        '''Close the raw connection if nobody holds it and it has not been
        closed/created for more than ``timeout`` seconds.

        Returns 1 if the connection is closed (now or earlier) and should be
        dropped from the pool, 0 if it should be kept.
        '''
        if self._con is None:
            # Already closed earlier; there is nothing left to keep.
            return 1
        if self._timestamp + timeout < time.time() and self._out <= 0:
            con = self._con
            # Mark as closed first so the wrapper is consistent even if the
            # driver's close() raises.
            self._con = None
            con.close()
            return 1
        return 0

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If one of our own private
        # attributes is missing (half-initialized instance, e.g. created via
        # __new__/copy/pickle), fail fast instead of recursing through
        # self._con forever.
        if name in ('_con', '_pool', '_timestamp', '_out'):
            raise AttributeError(name)
        return getattr(self._con, name)

    def __del__(self):
        # A half-initialized instance has no connection or pool to return to.
        if '_con' in self.__dict__:
            self.close()

# Module-level reentrant lock, apparently meant to serialize writes to a debug
# log file. NOTE(review): no live code in this module acquires it; presumably
# kept for ad-hoc debugging or external importers -- confirm before removing.
_loglock = threading.RLock()
    
class DBPool:
    """Pool of PooledConnection objects for a DB-API module.

    The strategy depends on the thread safety the module declares
    (``threadlevel`` for mx.ODBC, otherwise the DB-API ``threadsafety``):

    * 1 -- threads may share the module but not connections. Idle connections
      sit on a Stack (LIFO, so load is NOT spread across all connections and
      rarely needed ones can expire) and each connection is handed to one
      caller at a time.
    * 2 or more -- connections may be shared between threads. All connections
      live in a list; a caller gets an idle one if possible, otherwise a new
      one (up to maxConnections), otherwise the least-shared one.

    getConnection, returnConnection, addConnection and usedConnections are
    bound in __init__ to the matching _threadsafe_* or _unthreadsafe_*
    implementation. "threadsafe"/"unthreadsafe" refers to the database
    _module_, not to this class.
    """

    def __init__(self, dbModule, maxConnections, con_timeout=10, con_prune=60, *args, **kwargs):
        '''
            inputs:
                dbModule: module with method "connect" for creating connections
                maxConnections: maximum connections to be created (on demand)
                con_timeout: minimum inactivity timeout (in seconds) before
                    connections are closed.
                con_prune: minimum time (in seconds) between pruning runs (not
                    guaranteed, as it is only checked when a connection is
                    requested).
                *args: any arguments to be passed to dbModule.connect. Because
                    con_timeout and con_prune come first, the first two extra
                    positional arguments land there instead; prefer passing
                    connect arguments by keyword.
                **kwargs: any named arguments to be passed to dbModule.connect

            raises:
                UnsupportedError: if the module declares thread safety 0.
        '''
        # @@ 2002-02-14 lo: modification for mx.ODBC's not-quite compliant interface.
        if hasattr(dbModule, 'threadlevel'):
            threadsafe = dbModule.threadlevel
        else:
            threadsafe = dbModule.threadsafety
        if threadsafe == 0:
            raise UnsupportedError("Database module does not support any level of threading.")
        elif threadsafe == 1:
            self._lock = threading.Lock()
            # Stack so that load is NOT distributed across all connections:
            # we want some to expire if possible!
            self._queue = Stack(maxConnections)
            self.addConnection = self._unthreadsafe_addConnection
            self.getConnection = self._unthreadsafe_getConnection
            self.returnConnection = self._unthreadsafe_returnConnection
            self.usedConnections = self._unthreadsafe_usedConnections
            self._pruneConnections = self._unthreadsafe_pruneConnections
        elif threadsafe >= 2:
            # Reentrant: returnConnection takes this lock too, and it could be
            # reached from PooledConnection.__del__ (garbage collection) while
            # this thread already holds the lock inside getConnection.
            self._lock = threading.RLock()
            self._nextCon = 0
            self._connections = []
            self.addConnection = self._threadsafe_addConnection
            self.getConnection = self._threadsafe_getConnection
            self.returnConnection = self._threadsafe_returnConnection
            self._pruneConnections = self._threadsafe_pruneConnections
            self.usedConnections = self._threadsafe_usedConnections

        self._db = dbModule
        self._args = args
        self._kwargs = kwargs
        self._maxConnections = maxConnections
        self._timeout = con_timeout
        self._prune = con_prune
        self._lastPruned = time.time()
        # Connections are created on demand by getConnection(), not up front.

    def createConnection(self):
        """Open a new raw connection with dbModule.connect and wrap it."""
        return PooledConnection(self, self._db.connect(*self._args, **self._kwargs))

    def _discardConnection(self, con):
        """Close the raw connection behind a wrapper the pool will not keep.

        Clearing con._con first also makes the wrapper's close()/__del__ a
        no-op, so a discarded wrapper never finds its way back into the pool.
        """
        raw = con._con
        con._con = None
        if raw is not None:
            raw.close()

    # --- DB modules whose connections may be shared (threadsafety >= 2) ---

    def _threadsafe_usedConnections(self):
        """Return the number of open connections (idle or given out)."""
        return len(self._connections)

    def _threadsafe_addConnection(self, con):
        """Add a PooledConnection to the shared list."""
        self._connections.append(con)

    def _threadsafe_getConnection(self):
        """Hand out a connection, preferring one that nobody holds.

        Preference order: an existing connection with no holders, then a new
        connection while fewer than maxConnections exist (one is always
        created for an empty pool), and finally the connection with the
        fewest current holders.
        """
        self._lock.acquire()
        try:
            con = None
            low_con = None
            low_out = None
            for c in self._connections:
                n = c.out()
                if n == 0:
                    con = c
                    break
                # "<=" keeps the last of several equally loaded connections.
                if low_out is None or n <= low_out:
                    low_con = c
                    low_out = n

            if con is None:
                if (not self._connections
                        or len(self._connections) < self._maxConnections):
                    con = self.createConnection()
                    self.addConnection(con)
                else:
                    con = low_con

            con.ref()
            # Safe to prune afterwards: the reference just taken keeps con
            # from being pruned.
            self._pruneConnections()
            return con
        finally:
            self._lock.release()

    def _threadsafe_returnConnection(self, con):
        """Release one hold on con; it stays in the shared list for reuse."""
        # deref() is a non-atomic read-modify-write on the same counter that
        # getConnection() increments under this lock. Without the lock an
        # update can be lost, and a count that drops to 0 too early lets the
        # connection be pruned while another thread still uses it.
        self._lock.acquire()
        try:
            con.deref()
        finally:
            self._lock.release()

    def _threadsafe_pruneConnections(self):
        """Close and forget idle, timed-out connections.

        Runs at most once per prune interval; getConnection() calls it with
        the lock held.
        """
        if self._lastPruned + self._prune < time.time():
            self._lastPruned = time.time()
            timeout = self._timeout
            self._connections = [c for c in self._connections
                                 if not c.prune(timeout)]

    # --- DB modules whose threads may share the module but not connections
    #     (DB-API threadsafety == 1) ---

    def _unthreadsafe_usedConnections(self):
        '''Return the number of *idle* connections waiting on the stack.

        Not guaranteed by the Queue class (qsize() is approximate), so don't
        count on this number.
        '''
        return self._queue.qsize()

    def _unthreadsafe_addConnection(self, con):
        '''Push con onto the idle stack without blocking.

        May raise Full when the stack is full; to be caught and dealt with by
        the caller.
        '''
        # put_nowait, not put(): a blocking put() never raises Full and would
        # hang the calling thread until some other thread takes a connection.
        self._queue.put_nowait(con)

    def _unthreadsafe_getConnection(self):
        '''Return an idle connection from the stack, creating one if none is idle.

        NOTE(review): the stack only bounds *idle* connections. A new
        connection is created whenever the stack is empty, so the number of
        connections given out at the same time is not limited by
        maxConnections; surplus ones are closed when returned.
        '''
        con = None
        while con is None:
            try:
                con = self._queue.get_nowait()
            except Empty:
                new_con = self.createConnection()
                try:
                    self.addConnection(new_con)
                except Full:
                    # Other threads refilled the stack meanwhile, so this new
                    # connection is surplus.
                    self._discardConnection(new_con)
                # Loop back so the connection is always taken from the stack.
        # Prune first: con is already off the stack, so it cannot be pruned.
        # Drawback: every other connection may be pruned and need recreating.
        self._pruneConnections()
        con.ref()
        return con

    def _unthreadsafe_returnConnection(self, con):
        """
        Put con back on the idle stack; if the stack is full, the connection
        is surplus and is closed instead.

        This should never be called explicitly outside of this module.
        """
        con.deref()
        try:
            self.addConnection(con)
        except Full:
            self._discardConnection(con)

    def _unthreadsafe_pruneConnections(self):
        """Close idle, timed-out connections on the stack.

        Runs at most once per prune interval.
        """
        self._lock.acquire()
        try:
            if self._lastPruned + self._prune >= time.time():
                return
            self._lastPruned = time.time()

            # Drain the stack top-first (most recently used first).
            keep = []
            while True:
                try:
                    con = self._queue.get_nowait()
                except Empty:
                    break
                if not con.prune(self._timeout):
                    keep.append(con)

            # Push back bottom-first so the most recently used connection is
            # on top again; this preserves the LIFO order that lets the
            # others expire.
            keep.reverse()
            for con in keep:
                try:
                    self._queue.put_nowait(con)
                except Full:
                    # Connections were returned while pruning; this one is
                    # surplus.
                    self._discardConnection(con)
        finally:
            self._lock.release()