# -*- coding: utf-8 -*-
"""
Created on Thu Nov 15 21:07:58 2012

@author: Joe
"""

##############################################################################
#Define Modules and Inputs
##############################################################################
import os
import time

import numpy
import pyopencl as cl
import pyopencl.array

# Benchmark operands: two separate 2x2 float32 arrays holding 1.0 .. 4.0.
# The values are built in NumPy's default precision and then narrowed to
# float32, the element type the OpenCL kernels below declare.
a = (numpy.arange(4).reshape(2, 2) + 1.0).astype(numpy.float32)
b = (numpy.arange(4).reshape(2, 2) + 1.0).astype(numpy.float32)
##############################################################################
#Functions
##############################################################################
def add(a, b):
    """Return ``a + b`` (element-wise when the operands are arrays)."""
    total = a + b
    return total

def sub(a, b):
    """Return ``a - b`` (element-wise when the operands are arrays)."""
    difference = a - b
    return difference
    
def mult(a, b):
    """Return ``a * b`` (element-wise when the operands are arrays)."""
    product = a * b
    return product
    
def const_mult(a, b):
    """Return ``a * b`` scaled by the constant 3.0."""
    product = a * b
    return product * 3.0
    
def dev(a, b):
    """Return the quotient ``a / b`` (element-wise when the operands are arrays)."""
    quotient = a / b
    return quotient
    
def const_dev(a, b):
    """Return the quotient ``a / b`` further divided by the constant 3.0."""
    quotient = a / b
    return quotient / 3.0
    
def dot(a, b):
    """Return ``numpy.dot(a, b)``."""
    result = numpy.dot(a, b)
    return result

def trig(a, b):
    """Return ``sin(a) * cos(b)``, evaluated with NumPy."""
    sine_of_a = numpy.sin(a)
    cosine_of_b = numpy.cos(b)
    return sine_of_a * cosine_of_b
    
# Host-side functions run by the serial section, in execution order.
functions = [
    add,
    sub,
    mult,
    const_mult,
    dev,
    const_dev,
    dot,
    trig,
]

##############################################################################
#Serial
##############################################################################
# Run every function in `functions` once on the host with the operands a and b,
# collecting the results in list order, then report the total wall-clock time.
start_time_serial = time.time()
time_each_serial = []  # never filled in this script; kept so the name stays defined
# Iterate over the list itself rather than a hard-coded count, so adding or
# removing an entry in `functions` cannot fall out of sync with this loop.
c_serial = [func(a, b) for func in functions]

# One pre-formatted string prints the same text under Python 2 and Python 3.
print("Total Time Serial Elapsed:  %s s" % (time.time() - start_time_serial,))

##############################################################################
#OpenCL
##############################################################################
# Ask PyOpenCL to show the OpenCL compiler's output when the program is built.
# PyOpenCL reads this setting from the process environment, so binding a plain
# Python variable has no effect on its own; setdefault keeps any value the
# user already exported.
PYOPENCL_COMPILER_OUTPUT = 1
os.environ.setdefault("PYOPENCL_COMPILER_OUTPUT", str(PYOPENCL_COMPILER_OUTPUT))

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

mf = cl.mem_flags
# Read-only device buffers initialised from the two host operands.
a_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=a)
b_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=b)
# Output buffer reused by every element-wise kernel. The kernels write one
# float per work-item and the result is read back into an array shaped like
# `a`, so a.nbytes is all that is ever touched.
dest_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.nbytes)

# Element-wise float kernels mirroring the host functions above (`sum` is the
# counterpart of `add`). Each work-item handles the element at
# get_global_id(0). The constants carry an `f` suffix so the arithmetic stays
# in single precision: an unsuffixed 3.0 is a double literal, and double
# support is optional on OpenCL devices.
prg = cl.Program(ctx, """
    __kernel void sum(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] + b[gid];
    }

    __kernel void sub(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] - b[gid];
    }

    __kernel void mult(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] * b[gid];
    }

    __kernel void const_mult(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] * b[gid] * 3.0f;
    }

    __kernel void dev(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] / b[gid];
    }

    __kernel void const_dev(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = a[gid] / b[gid] / 3.0f;
    }

    __kernel void trig(__global const float *a, __global const float *b, __global float *c)
    {
      int gid = get_global_id(0);
      c[gid] = sin(a[gid]) * cos(b[gid]);
    }
    """).build()

def _run_elementwise(kernel):
    """Launch one element-wise kernel on a_buf/b_buf and return its result.

    kernel -- a kernel taken from ``prg`` with the (a, b, c) buffer signature.
    Returns a new host array with the shape and dtype of ``a``.
    """
    # The kernels index with get_global_id(0) only, so launch a 1-D range with
    # one work-item per element. Launching over a.shape (a 2-D range) would
    # only cover indices 0 .. a.shape[0]-1 and leave the rest of dest_buf
    # unwritten.
    kernel(queue, (a.size,), None, a_buf, b_buf, dest_buf)
    # Read the result into the shared staging array, then snapshot it because
    # `temp` is overwritten by the next kernel.
    cl.enqueue_copy(queue, temp, dest_buf)
    return numpy.copy(temp)


# Run the OpenCL counterparts in the same order as the serial section and
# report the total wall-clock time.
start_time_open = time.time()
c_opencl = []
temp = numpy.empty_like(a)

for kernel in (prg.sum, prg.sub, prg.mult, prg.const_mult, prg.dev,
               prg.const_dev):
    c_opencl.append(_run_elementwise(kernel))

# NOTE(review): per the PyOpenCL documentation, array.dot is an inner product
# over all elements and returns a device array, whereas the serial reference
# uses numpy.dot, which is a matrix product for 2-D input. The two entries are
# therefore presumably not comparable -- confirm what was intended.
a_temp = cl.array.to_device(queue, a)
b_temp = cl.array.to_device(queue, b)
c_opencl.append(cl.array.dot(a_temp, b_temp, numpy.float32, queue))

c_opencl.append(_run_elementwise(prg.trig))


# One pre-formatted string prints the same text under Python 2 and Python 3.
print("Time OpenCl elapsed:  %s s" % (time.time() - start_time_open,))