import numpy as np
import tables
import cPickle
import zlib
import blosc
import os
from time import time


# Benchmark dimensions: N arrays are generated, each of length L.
N = 1000
L = 100 * 1000

# Compression levels handed to zlib and to blosc, respectively.
zlib_lvl = 6
blosc_lvl = 9

def write_a_hdf(arrays):
    """Store every array as its own Array node in ``out/test_a.h5``.

    Each node is created directly under the root group and is named
    ``array<i>``, where ``<i>`` is the position of the array in *arrays*.
    No filters are passed, so nothing is compressed by this function.

    The file is opened in ``"w"`` mode, so an existing file is replaced.
    """
    h5file = tables.openFile("out/test_a.h5", "w")
    try:
        for index, array in enumerate(arrays):
            h5file.createArray(h5file.root, 'array%i' % index, array)
    finally:
        # Close even if a write fails, so the handle is never leaked.
        h5file.close()

def write_ca_hdf(arrays, clib):
    """Store all arrays in one compressed CArray in ``out/test_ca_<clib>.h5``.

    arrays -- non-empty sequence of equally shaped arrays; the shape of
              ``arrays[0]`` defines the trailing dimensions of the CArray.
    clib   -- compression library name passed to ``tables.Filters``;
              ``"zlib"`` selects ``zlib_lvl``, anything else ``blosc_lvl``.

    The leading dimension of the CArray is the module-level ``N``, not
    ``len(arrays)``, so callers are expected to pass exactly ``N`` arrays.
    """
    h5file = tables.openFile("out/test_ca_%s.h5"%clib, "w")
    try:
        clevel = zlib_lvl if clib == "zlib" else blosc_lvl
        ca = h5file.createCArray(h5file.root, 'carray',
                                 tables.Float64Atom(),
                                 (N,) + arrays[0].shape,
                                 filters=tables.Filters(clevel, clib))
        # Fill the pre-sized array one row at a time.
        for index, array in enumerate(arrays):
            ca[index] = array
    finally:
        # Close even if creation or a row write fails.
        h5file.close()

def write_ea_hdf(arrays, clib):
    """Store all arrays in one compressed EArray in ``out/test_ea_<clib>.h5``.

    arrays -- non-empty sequence of equally shaped arrays; the shape of
              ``arrays[0]`` defines the trailing dimensions of the EArray.
    clib   -- compression library name passed to ``tables.Filters``;
              ``"zlib"`` selects ``zlib_lvl``, anything else ``blosc_lvl``.

    The EArray starts with a zero-length first dimension and the whole
    sequence is appended in a single call.
    """
    h5file = tables.openFile("out/test_ea_%s.h5"%clib, "w")
    try:
        clevel = zlib_lvl if clib == "zlib" else blosc_lvl
        ea = h5file.createEArray(h5file.root, 'earray',
                                 tables.Float64Atom(),
                                 (0,) + arrays[0].shape,
                                 filters=tables.Filters(clevel, clib))
        ea.append(arrays)
    finally:
        # Close even if creation or the append fails.
        h5file.close()

def write_vl_hdf(arrays):
    """Store each array as one row of a VLArray in ``out/test_vl.h5``.

    The VLArray uses a ``Float64Atom`` and is created without filters.
    Rows are appended in the order of *arrays*.
    """
    h5file = tables.openFile("out/test_vl.h5", "w")
    try:
        vlarray = h5file.createVLArray(h5file.root, "vlarray",
                                       tables.Float64Atom())
        for array in arrays:
            vlarray.append(array)
    finally:
        # Close even if an append fails, so the handle is never leaked.
        h5file.close()

def write_vl_zlib_hdf(arrays):
    """Store pickled, zlib-compressed arrays in ``out/test_vl_zlib.h5``.

    Each array is pickled with the highest cPickle protocol, compressed
    with zlib at level ``zlib_lvl`` and appended as one row of a VLArray
    that uses a ``VLStringAtom``.
    """
    h5file = tables.openFile("out/test_vl_zlib.h5", "w")
    try:
        vlarray = h5file.createVLArray(h5file.root, "vlarray",
                                       tables.VLStringAtom())
        for array in arrays:
            parray = cPickle.dumps(array, cPickle.HIGHEST_PROTOCOL)
            zarray = zlib.compress(parray, zlib_lvl)
            vlarray.append(zarray)
    finally:
        # Close even if pickling, compression or an append fails.
        h5file.close()

def write_vl_blosc_hdf(arrays):
    """Store pickled, blosc-compressed arrays in ``out/test_vl_blosc.h5``.

    Each array is pickled with the highest cPickle protocol and the pickle
    bytes are compressed with ``blosc.compress`` at level ``blosc_lvl``,
    passing the array's item size as the second argument. The result is
    appended as one row of a VLArray that uses a ``VLStringAtom``.
    """
    h5file = tables.openFile("out/test_vl_blosc.h5", "w")
    try:
        vlarray = h5file.createVLArray(h5file.root, "vlarray",
                                       tables.VLStringAtom())
        for array in arrays:
            parray = cPickle.dumps(array, cPickle.HIGHEST_PROTOCOL)
            zarray = blosc.compress(parray, array.dtype.itemsize, blosc_lvl)
            vlarray.append(zarray)
    finally:
        # Close even if pickling, compression or an append fails.
        h5file.close()

def write_vl_blosc2_hdf(arrays):
    """Store ``blosc.pack_array`` output in ``out/test_vl_blosc2.h5``.

    Each array is packed with ``blosc.pack_array`` (called with its default
    settings) and appended as one row of a VLArray that uses a
    ``VLStringAtom``.
    """
    h5file = tables.openFile("out/test_vl_blosc2.h5", "w")
    try:
        vlarray = h5file.createVLArray(h5file.root, "vlarray",
                                       tables.VLStringAtom())
        for array in arrays:
            vlarray.append(blosc.pack_array(array))
    finally:
        # Close even if packing or an append fails.
        h5file.close()

def read_a_hdf():
    """Read every node under the root of ``out/test_a.h5`` into a list.

    Returns a list with the contents of each child node of the root group.
    Nodes named ``array<i>`` are returned in increasing ``<i>`` order, i.e.
    the order in which they were written.
    """
    prefix = 'array'

    def write_position(node):
        # Sort key: numeric suffix for 'array<i>' nodes; any other node is
        # placed after them, keeping its relative order (sort is stable).
        name = node._v_name
        suffix = name[len(prefix):]
        if name.startswith(prefix) and suffix.isdigit():
            return (0, int(suffix))
        return (1, 0)

    h5file = tables.openFile("out/test_a.h5", "r")
    try:
        # iterNodes yields children sorted by name as strings, which puts
        # 'array10' before 'array2'; re-sort to recover the write order.
        nodes = sorted(h5file.iterNodes(h5file.root), key=write_position)
        return [node.read() for node in nodes]
    finally:
        # Close even if a read fails, so the handle is never leaked.
        h5file.close()

def read_ca_hdf(clib):
    """Read the first ``N`` rows of the CArray in ``out/test_ca_<clib>.h5``.

    clib -- compression library name used to build the file name.

    Returns a list with one entry per row, read one row at a time. The row
    count is the module-level ``N``.
    """
    h5file = tables.openFile("out/test_ca_%s.h5"%clib, "r")
    try:
        return [h5file.root.carray[i] for i in xrange(N)]
    finally:
        # Single close (the original closed twice); runs on errors too.
        h5file.close()

def read_ea_hdf(clib):
    """Read the first ``N`` rows of the EArray in ``out/test_ea_<clib>.h5``.

    clib -- compression library name used to build the file name.

    Returns a list with one entry per row, read one row at a time. The row
    count is the module-level ``N``.
    """
    h5file = tables.openFile("out/test_ea_%s.h5"%clib, "r")
    try:
        return [h5file.root.earray[i] for i in xrange(N)]
    finally:
        # Close even if a row read fails, so the handle is never leaked.
        h5file.close()

def read_vl_hdf():
    """Return the result of reading the whole VLArray in ``out/test_vl.h5``.

    The value is whatever ``vlarray.read()`` returns when called without
    arguments.
    """
    h5file = tables.openFile("out/test_vl.h5", "r")
    try:
        return h5file.root.vlarray.read()
    finally:
        # Close even if the read fails, so the handle is never leaked.
        h5file.close()

def read_vl_zlib_hdf():
    """Read back the arrays stored by ``write_vl_zlib_hdf``.

    Every row of the VLArray in ``out/test_vl_zlib.h5`` is decompressed
    with zlib and unpickled. Returns the resulting objects as a list, in
    row order.

    NOTE: rows are unpickled, so this must only be used on files this
    script wrote itself; unpickling untrusted data is unsafe.
    """
    h5file = tables.openFile("out/test_vl_zlib.h5", "r")
    try:
        return [cPickle.loads(zlib.decompress(row))
                for row in h5file.root.vlarray]
    finally:
        # Close even if decompression or unpickling fails.
        h5file.close()

def read_vl_blosc_hdf():
    """Read back the arrays stored by ``write_vl_blosc_hdf``.

    Every row of the VLArray in ``out/test_vl_blosc.h5`` is decompressed
    with blosc and unpickled. Returns the resulting objects as a list, in
    row order.

    NOTE: rows are unpickled, so this must only be used on files this
    script wrote itself; unpickling untrusted data is unsafe.
    """
    h5file = tables.openFile("out/test_vl_blosc.h5", "r")
    try:
        return [cPickle.loads(blosc.decompress(row))
                for row in h5file.root.vlarray]
    finally:
        # Close even if decompression or unpickling fails.
        h5file.close()

def read_vl_blosc2_hdf():
    """Read back the arrays stored by ``write_vl_blosc2_hdf``.

    Every row of the VLArray in ``out/test_vl_blosc2.h5`` is passed to
    ``blosc.unpack_array``. Returns the results as a list, in row order.
    """
    h5file = tables.openFile("out/test_vl_blosc2.h5", "r")
    try:
        return [blosc.unpack_array(row) for row in h5file.root.vlarray]
    finally:
        # Close even if unpacking fails, so the handle is never leaked.
        h5file.close()


if __name__ == "__main__":
    arrays = [np.linspace(0, i, L) for i in range(1,N+1)]
    
    if not os.path.exists("out"):
        os.mkdir("out")

    t0 = time()
    write_a_hdf(arrays)
    print "Time to write (A):", round(time()-t0, 3)
    t0 = time()
    arrays2 = read_a_hdf()
    print "Time to read (A):", round(time()-t0, 3)
    assert len(arrays2) == N

    for clib in ('zlib', 'blosc'):
        t0 = time()
        write_ca_hdf(arrays, clib)
        print "Time to write (CA, %s): %.3f" % (clib, time()-t0)
        t0 = time()
        arrays2 = read_ca_hdf(clib)
        print "Time to read (CA, %s): %.3f" % (clib, time()-t0)
        assert len(arrays2) == N

        t0 = time()
        write_ea_hdf(arrays, clib)
        print "Time to write (EA, %s): %.3f" % (clib, time()-t0)
        t0 = time()
        arrays2 = read_ea_hdf(clib)
        print "Time to read (EA, %s): %.3f" % (clib, time()-t0)
        assert len(arrays2) == N

    t0 = time()
    write_vl_hdf(arrays)
    print "Time to write (VL):", round(time()-t0, 3)
    t0 = time()
    arrays2 = read_vl_hdf()
    print "Time to read (VL):", round(time()-t0, 3)
    assert len(arrays2) == N

    t0 = time()
    write_vl_zlib_hdf(arrays)
    print "Time to write (VL, zlib):", round(time()-t0, 3)
    t0 = time()
    arrays2 = read_vl_zlib_hdf()
    print "Time to read (VL, zlib):", round(time()-t0, 3)
    assert len(arrays2) == N

    t0 = time()
    write_vl_blosc_hdf(arrays)
    print "Time to write (VL, blosc):", round(time()-t0, 3)
    t0 = time()
    arrays2 = read_vl_blosc_hdf()
    print "Time to read (VL, blosc):", round(time()-t0, 3)
    assert len(arrays2) == N

    t0 = time()
    write_vl_blosc2_hdf(arrays)
    print "Time to write (VL, blosc2):", round(time()-t0, 3)
    t0 = time()
    arrays2 = read_vl_blosc2_hdf()
    print "Time to read (VL, blosc2):", round(time()-t0, 3)
    assert len(arrays2) == N
