from tables import openFile, Expr

### Uncomment the two lines below to control the number of threads
### numexpr uses when Expr.eval() is called
#import numexpr
#numexpr.set_num_threads(2)

def create_ntuple_file(filename, npoints, pmodel):
    '''
        create an hdf5 file with a single table ('/table') which
        contains npoints rows of type row_t (defined below)

        filename -- path of the hdf5 file; it is opened in 'w' mode
        npoints  -- number of rows to generate
        pmodel   -- polynomial coefficients handed to numpy.poly1d
                    (highest power first); they define the systematic
                    shift that is subtracted from column b

        Each row holds a = uniform(0, 10) and
        b = a - poly1d(pmodel)(a) + normal(0, 0.1).
    '''
    from numpy import random, poly1d
    from tables import IsDescription, Float32Col

    class row_t(IsDescription):
        '''
            the rows of the table to be created
        '''
        a = Float32Col()
        b = Float32Col()

    # The polynomial does not change from row to row, so build it once
    # instead of once per appended row.
    systematics = poly1d(pmodel)

    def append_row(h5row):
        '''
            consider this a single "event" being appended
            to the dataset (table)
        '''
        h5row['a'] = random.uniform(0,10)

        h5row['b'] = h5row['a'] # reality (or model)
        h5row['b'] = h5row['b'] - systematics(h5row['a']) # systematics
        h5row['b'] = h5row['b'] + random.normal(0,0.1) # noise

        h5row.append()

    h5file = openFile(filename, 'w')
    try:
        h5table = h5file.createTable('/', 'table', row_t, "Data")
        h5row = h5table.row

        # recording data to file...
        for n in xrange(npoints):
            append_row(h5row)
    finally:
        # Close the handle even when row generation fails part-way.
        h5file.close()

def create_ntuple_file_if_needed(filename, npoints, pmodel):
    '''
        make sure filename exists and its table holds npoints rows.

        If the file is already there, its '/table' node is checked:
        when the number of rows differs from npoints the file is
        removed. If the file is missing (or was just removed) a new
        one is created with create_ntuple_file.

        Note: only the row count is compared. An existing file with
        the right number of rows is reused even if it was generated
        with different pmodel values.
    '''
    from os import path, remove

    print('model parameters: %s' % (pmodel,))

    if path.exists(filename):
        h5file = openFile(filename, 'r')
        try:
            wrong_size = len(h5file.root.table) != npoints
        finally:
            # Always release the handle, including when the size matches
            # and the file is kept.
            h5file.close()
        if wrong_size:
            remove(filename)

    if not path.exists(filename):
        create_ntuple_file(filename, npoints, pmodel)

def fn(p, h5table):
    '''
        actual function we are going to minimize: the sum over all
        rows of (a - (b + p0*a*a + p1*a + p2))**2.

        p       -- sequence of parameters; the expression refers to
                   p0, p1 and p2, so at least three entries are needed
        h5table -- pytables Table object providing columns a and b

        returns the "sum of squares" of the element-wise expression
    '''
    # Work on a copy of the column mapping: storing the parameters in
    # h5table.colinstances itself would leave p0, p1, ... entries behind
    # on the table object after the call.
    uv = dict(h5table.colinstances)

    # store parameters under names like p0, p1, p2, etc. so they can
    # be used in the Expr object.
    for i, value in enumerate(p):
        uv['p%d' % i] = value

    # systematic shift on b is a polynomial in a
    db = 'p0 * a*a  +  p1 * a  +  p2'

    # the element-wise function
    fn_str = '(a - (b + %s))**2' % db

    expr = Expr(fn_str, uservars=uv)

    # Evaluate once and sum the resulting array. Iterating over the Expr
    # object afterwards would compute the expression a second time and
    # add the values up one row at a time in Python.
    return expr.eval().sum()

if __name__ == '__main__':
    # usage:
    #     python pytables_expr_test.py [npoints]
    #
    # Hint: try this with 10M points
    from sys import argv
    from time import time

    # number of rows; can be overridden by the first command-line argument
    npoints = 1000000
    if len(argv) > 1:
        npoints = int(argv[1])

    # one data file per row count, so different sizes do not clash
    filename = 'tmp.'+str(npoints)+'.hdf5'

    pmodel = [-0.04,0.002,0.001]

    print('creating file (if it doesn\'t exist)...')
    create_ntuple_file_if_needed(filename, npoints, pmodel)

    h5file = openFile(filename, 'r')
    try:
        h5table = h5file.root.table

        print('evaluating function')
        starttime = time()
        print(fn([0.,0.,0.], h5table))
        print('evaluated file in %s seconds.' % (time()-starttime))
    finally:
        # release the file handle even if the evaluation fails
        h5file.close()
