try:
    import rtree
except ImportError:
    raise ImportError( 'The rtree package is required. Please install from http://pypi.python.org/pypi/Rtree.' )

from rtree.index import Rtree, CustomStorage, Property

import transaction
from persistent import Persistent
from persistent.dict import PersistentDict
from BTrees.LOBTree import LOBTree

import random


class SpatialObject(Persistent):
    ''' Persistent record tying an inserted object to its id and coordinates '''
    def __init__(self, id, object, coordinates):
        ''' Keeps id, object and coordinates unchanged as attributes of the
            same name.
        '''
        self.id, self.object, self.coordinates = id, object, coordinates

# Set to True to print a trace of the storage and buffer operations.
# The flag is evaluated once, when the module is imported.
DEBUG = False
if DEBUG:
    def log(msg):
        ''' Prints msg (debug tracing is enabled) '''
        # print(msg) with a single argument behaves identically as a Python 2
        #  print statement and as a Python 3 function call.
        print(msg)
else:
    def log(msg):
        ''' Discards msg (debug tracing is disabled) '''
        return None
    
class SpatialIndex(Persistent):
    ''' A spatial index. You can insert objects with their coordinates and later
         perform queries on the index.

        The r-tree pages are kept in self.pageData (through a SpatialStorage)
        and every inserted item is kept in self.idToItem under its generated
        id. The Rtree instance itself is created lazily by getTree() and cached
        in the volatile attribute _v_tree.
    '''

    def __init__(self, settings = None, initialValuesGenerator = None):
        ''' Init. settings provide many means to customize the spatial tree.
            E.g. setting leaf_capacity and near_minimum_overlap_factor to "good"
            values can help to reduce the possibility of write conflict errors.
            settings is a mapping of property name to value; None (the default)
            or an empty mapping keeps the rtree defaults.

            Below is a list of the currently available properties. For more info
            about these and other properties see the rtree docs and/or code.

                writethrough
                buffering_capacity
                pagesize
                leaf_capacity
                near_minimum_overlap_factor
                type
                variant
                dimension
                index_capacity
                index_pool_capacity
                point_pool_capacity
                region_pool_capacity
                tight_mbr
                fill_factor
                split_distribution_factor
                tpr_horizon
                reinsert_factor

            If you supply an initialValuesGenerator you can build a spatial index
            from initial values. This is much faster than doing repeated insert()s.

            Raises ValueError if settings contains an unknown property name.
        '''
        Persistent.__init__( self )
        # None instead of a shared mutable {} default
        if settings is None:
            settings = {}
        self.settings = PersistentDict( settings )
        self.pageData = LOBTree()
        self.idToItem = LOBTree()

        # this creates the tree and creates header and root pages
        self.getTree( initialValuesGenerator )

    def getTree(self, initialValuesGenerator = None):
        ''' Creates the r-tree if it is not already created yet and returns it.

            initialValuesGenerator is handed to Rtree() and may only be given
            while the tree has not been created yet.

            Raises ValueError if the tree already exists and
            initialValuesGenerator is given, if this object has no settings
            attribute, or if a setting is not a known rtree property.
        '''
        # Compare against None on purpose: the truth value of the cached tree
        #  or of the settings must not decide whether they exist. Empty
        #  settings are the default and perfectly valid.
        tree = getattr( self, '_v_tree', None )
        if tree is not None:
            if initialValuesGenerator:
                raise ValueError( 'the tree already exists, initial values '
                                  'can only be supplied when it is created' )
            return tree

        settings = getattr( self, 'settings', None )
        if settings is None:
            raise ValueError( 'invalid spatial index' )

        # create r-tree property object
        properties = Property()
        for name, value in settings.items():
            if not hasattr( properties, name ):
                raise ValueError( 'Invalid setting "%s"' % name )
            setattr( properties, name, value )

        # create r-tree storage object
        storage = SpatialStorage( self.pageData )

        # create r-tree
        if initialValuesGenerator:
            tree = Rtree( storage, initialValuesGenerator, properties = properties )
        else:
            tree = Rtree( storage, properties = properties )
        self._v_tree = tree
        return tree

    tree = property( lambda self: self.getTree() )

    def insert(self, object, coordinates):
        ''' Inserts object with bounds into this index. Returns the added item. '''
        # Generate a random 63-bit id (the id has to fit into a signed 64-bit
        #  key of the LOBTree).
        # The collision probabilities can be seen here http://en.wikipedia.org/wiki/Birthday_attack .
        # Ideally we could use 128-bit values (or generate ids in an ascending sequence).
        # 128-bit values don't work, because the delete function accepts only 64-bit values.
        # Generating an ascending sequence might increase the possibility of ConflictErrors, but I am
        #  not sure about this.
        # So we stick to random ids and simply draw a new one in the unlikely
        #  case that the id is already taken.
        self._registerDataManager()
        idToItem = self.idToItem
        id = random.getrandbits( 63 )
        while id in idToItem:
            id = random.getrandbits( 63 )
        item = SpatialObject( id, object, coordinates )
        idToItem[ id ] = item
        self.tree.add( id, coordinates, object )
        return item

    add = insert

    def delete(self, item):
        ''' Deletes an item (as returned by insert()) from this index.
            Raises KeyError if the id of the item is not in this index.
        '''
        self._registerDataManager()
        del self.idToItem[ item.id ]
        self.tree.delete( item.id, item.coordinates )

    def count(self, coordinates):
        ''' Counts the number of objects within coordinates '''
        self._registerDataManager()
        return self.tree.count( coordinates )

    def intersection(self, coordinates, object = True):
        ''' Yields all objects which intersect the given bounds.
            If object is True, the raw object given to insert() is yielded, else
            the item returned by insert().
            This is a generator, nothing is queried before it is iterated.
        '''
        self._registerDataManager()
        # only the ids are queried, the objects come from idToItem
        for id in self.tree.intersection( coordinates, objects = False ):
            item = self.idToItem[ id ]
            if object:
                yield item.object
            else:
                yield item

    def nearest(self, coordinates, num_results = 1, object = True):
        ''' Yields the num_results objects which are closest to coordinates.
            If object is True, the raw object given to insert() is yielded, else
            the item returned by insert().
            This is a generator, nothing is queried before it is iterated.
        '''
        self._registerDataManager()
        # only the ids are queried, the objects come from idToItem
        for id in self.tree.nearest( coordinates, num_results, objects = False ):
            item = self.idToItem[ id ]
            if object:
                yield item.object
            else:
                yield item

    def leaves(self):
        ''' Returns all leaves in the tree. A leaf is a tuple (id, child_ids, bounds) '''
        self._registerDataManager()
        return self.tree.leaves()

    def get_bounds(self, coordinate_interleaved = None):
        ''' Returns the bounds of the whole tree.
            coordinate_interleaved is passed on to the get_bounds() of the tree.
        '''
        self._registerDataManager()
        return self.tree.get_bounds( coordinate_interleaved )

    bounds = property( get_bounds )

    def clearBuffer(self, blockWrites):
        ''' Clears the buffer of the r-tree, if the tree has been created.
            If blockWrites is True the storage ignores all page writes while
            the buffer is cleared.
        '''
        tree = getattr( self, '_v_tree', None )
        if tree is None:
            return
        storage = tree.customstorage
        if blockWrites:
            storage.blockWrites = True
        try:
            log( 'PRE-CLEAR' )
            tree.clearBuffer()
            log( 'POST-CLEAR' )
        finally:
            # never leave the storage in write-blocking mode, even if
            #  clearing the buffer failed
            if blockWrites:
                storage.blockWrites = False

    def _registerDataManager(self):
        ''' This registers a custom data manager to flush all the buffers when
             they are dirty. Does nothing if a data manager is registered
             already (see _unregisterDataManager). '''
        registered = getattr( self, '_v_dataManagerRegistered', False )
        if registered:
            return
        self._v_dataManagerRegistered = True

        # This is a hack: zodb's transaction module sorts data managers only
        #  on commit(). That's not good for us, our data manager's
        #  savepoint/rollback/abort calls need to be executed before the
        #  connection's savepoint/rollback/abort. So join() of the current
        #  transaction is wrapped to re-sort the resources after every join.
        t = transaction.get()
        org_join = t.join
        def join(resource):
            org_join( resource )
            t._resources = sorted( t._resources, transaction._transaction.rm_cmp )
        t.join = join
        t.join( SpatialDataManager(self) )

    def _unregisterDataManager(self):
        ''' Resets the flag set by _registerDataManager, so that the next
             operation registers a data manager again. '''
        self._v_dataManagerRegistered = False

    
    
class SpatialStorage(CustomStorage):
    """ A storage which saves the pages in a BTree mapping.

        While blockWrites is True, storeByteArray() ignores all writes.
    """
    def __init__(self, mapping):
        """ mapping maps page ids to the page data and is used as given """
        CustomStorage.__init__( self )
        self.mapping = mapping
        self.blockWrites = False

    def create(self, returnError):
        """ Called when the storage is created on the C side """

    def destroy(self, returnError):
        """ Called when the storage is destroyed on the C side """

    def clear(self):
        """ Clear all our data """
        self.mapping.clear()

    def loadByteArray(self, page, returnError):
        """ Returns the data for page or returns an error """
        log( 'READ page:%s' % page )
        try:
            return self.mapping[page]
        except KeyError:
            returnError.contents.value = self.InvalidPageError

    def _newPageId(self):
        """ Returns a page id which is not used by any stored page """
        mapping = self.mapping
        if not mapping:
            return 0
        # The highest key + 1 is always free. len(mapping) is not: once a
        #  page has been deleted it can be the id of a page that still exists.
        maxKey = getattr( mapping, 'maxKey', None )
        if maxKey is not None:
            # BTree mappings
            return maxKey() + 1
        # any other mapping, e.g. a plain dict
        return max( mapping ) + 1

    def storeByteArray(self, page, data, returnError):
        """ Stores the data for page and returns the id of the page.
            If page is NewPage, the data is stored under a new page id.
            Storing to an unknown page sets InvalidPageError and returns 0.
            While blockWrites is set nothing is stored and page is returned.
        """
        if self.blockWrites:
            log( 'STORE BLOCKED page:%s' % page )
            return page
        if page == self.NewPage:
            newPageId = self._newPageId()
            log( 'STORE NEW pageId:%s' % newPageId )
            self.mapping[newPageId] = data
            return newPageId

        log( 'STORE pageId:%s' % page )
        if page not in self.mapping:
            # same way of reporting the error as in load/deleteByteArray
            returnError.contents.value = self.InvalidPageError
            return 0
        self.mapping[page] = data
        if DEBUG:
            # Debug aid only: log the unsigned int found at bytes 8..12 of
            #  page 0. Skipped when logging is off, so that normal writes do
            #  not depend on the contents of page 0.
            import struct
            nodes = -1
            header = self.mapping.get( 0 )
            if header is not None and len( header ) >= 12:
                nodes = struct.unpack( 'I', header[8:12] )[0]
            log( 'STORE nodes:%s' % nodes )
        return page

    def deleteByteArray(self, page, returnError):
        """ Deletes a page """
        log( 'DELETE pageId:%s' % page )
        try:
            del self.mapping[page]
        except KeyError:
            returnError.contents.value = self.InvalidPageError

    @property
    def hasData(self):
        """ Returns true if this storage contains some data """
        return bool( self.mapping )



class SpatialDataManager(object):
    ''' Transaction data manager which clears the buffer of a spatial index
         at the transaction boundaries. '''
    transaction_manager = None

    class Savepoint(object):
        ''' Savepoint of a SpatialDataManager. Creating it clears the buffer
             with writes enabled, rolling it back clears the buffer with
             writes blocked. '''
        def __init__(self, dataManager):
            self.dataManager = dataManager
            dataManager.clearBuffer( False )

        def rollback(self):
            ''' Clears the buffer while the storage blocks all writes '''
            self.dataManager.clearBuffer( True )

    def __init__(self, spatialIndex):
        ''' spatialIndex is the index whose buffer is managed '''
        self.spatialIndex = spatialIndex

    def clearBuffer(self, blockWrites):
        ''' Forwards to clearBuffer() of the spatial index '''
        index = self.spatialIndex
        index.clearBuffer( blockWrites )

    def unregister(self):
        ''' Tells the spatial index that this data manager is done '''
        index = self.spatialIndex
        index._unregisterDataManager()

    def savepoint(self):
        ''' Returns a new Savepoint bound to this data manager '''
        return self.Savepoint( self )

    def abort(self, transaction):
        ''' Clears the buffer with writes blocked, then unregisters '''
        self.clearBuffer( True )
        self.unregister()

    def tpc_begin(self, transaction):
        ''' Clears the buffer with writes enabled '''
        self.clearBuffer( False )

    def commit(self, transaction):
        ''' Nothing to do '''
        return None

    def tpc_vote(self, transaction):
        ''' Nothing to do '''
        return None

    def tpc_finish(self, transaction):
        ''' Unregisters from the spatial index '''
        self.unregister()

    def tpc_abort(self, transaction):
        ''' Unregisters from the spatial index '''
        self.unregister()

    def sortKey(self):
        ''' Returns the negated sys.maxint as sort key '''
        # presumably chosen so that this data manager sorts in front of all
        #  other data managers of the transaction
        import sys
        smallest = -sys.maxint
        return smallest