import os
import numpy
import sys
import tables
import time

from datetime import datetime
from dateutil.relativedelta import relativedelta


def GenerateDates(start_date, end_date):
    """Return the ordinals of monthly dates from start_date up to end_date.

    The sequence starts at start_date and advances one calendar month per
    step; end_date itself is included when a step lands exactly on it.

    Args:
        start_date: First date of the sequence. Must support comparison,
            ``toordinal()`` and addition of a ``relativedelta`` (the caller
            in this file passes ``datetime`` objects).
        end_date: Inclusive upper bound, comparable with start_date.

    Returns:
        A list with the ``toordinal()`` value of every generated date, in
        increasing order. Empty when start_date is later than end_date.
    """
    result = []
    months = 0
    current = start_date

    while current <= end_date:
        result.append(current.toordinal())
        months += 1
        # Offset from the fixed start instead of accumulating "+1 month":
        # repeated addition clamps the day once (Jan 31 -> Feb 28/29) and
        # then keeps the clamped day for every later month.
        current = start_date + relativedelta(months=months)

    return result

# Monthly date ordinals from 2012-01-01 through 2062-01-01 (both endpoints
# included); main() uses the length of this list to size the results column.
ALL_DATES = GenerateDates(
    start_date=datetime(year=2012, month=1, day=1),
    end_date=datetime(year=2062, month=1, day=1)
)


def main(num_sim):
    """Benchmark writing simulated well results to a compressed HDF5 file.

    Creates ``tutorial_<num_sim>.h5`` in the current working directory with
    a single ``Wells`` table holding one row per well (KB0001 .. KB1200),
    overwrites every row's ``results`` cell with random numbers, and prints
    the elapsed time of both phases plus the size of the file on disk.

    Args:
        num_sim: Number of simulations. It is the leading dimension of each
            row's ``results`` array and is embedded in the output file name.
    """
    start = time.time()

    h5file = 'tutorial_%d.h5' % num_sim
    # Shape of one row's results cell; reused for the column definition and
    # for the random data written in the second phase.
    results_shape = (num_sim, len(ALL_DATES), 7)
    well_names = ['KB%04d' % d for d in xrange(1, 1201)]

    # Open the file in 'w'rite mode
    fileh = tables.openFile(h5file, mode='w')
    try:
        filters = tables.Filters(complevel=1, complib='blosc')

        table = fileh.createTable(
            fileh.root, 'Wells',
            {'name': tables.StringCol(itemsize=25),
             'results': tables.Float32Col(shape=results_shape)},
            'Wells:Results', filters=filters, expectedrows=len(well_names))

        # Phase 1: append one row per well. Only the name is assigned here;
        # the results cell is filled in phase 2.
        well_row = table.row
        for well in well_names:
            well_row['name'] = well
            well_row.append()

        # Flush the table buffers
        table.flush()

        print('Number of simulations   : %-4d' % num_sim)
        print('H5 file creation time   : %0.3fs' % (time.time() - start))

        start = time.time()

        # Phase 2: overwrite the results cell of each row, one row per call.
        for i in xrange(len(table)):
            table.modifyColumn(i, i + 1, 1,
                               column=numpy.random.random(size=results_shape),
                               colname='results')

        table.flush()

        print('Saving results for table: %0.3fs' % (time.time() - start))
        print('H5 file size (MB)       : %d\n'
              % int(os.path.getsize(h5file) / (1024.0 ** 2.0)))
    finally:
        # Close even when a step above fails so the HDF5 handle is not
        # leaked; closing also flushes all the remaining buffers.
        fileh.close()


if __name__ == '__main__':
    # The number of simulations is the first command-line argument. Exit
    # with a usage message rather than a traceback when it is missing or
    # is not an integer; any further arguments are ignored.
    try:
        simulations = int(sys.argv[1])
    except (IndexError, ValueError):
        sys.exit('usage: %s NUM_SIM' % sys.argv[0])
    main(simulations)

    