# -*- coding: utf-8 -*-
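"""
Benchmark: sum all primes below each value in ``inputs`` twice -- once
serially in pure Python and once in parallel with an OpenCL kernel (one
work-item per input) -- then print both timings and compare the results.

Created on Wed Nov 14 10:39:59 2012

@author: Joe
"""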

##############################################################################
#Define Modules and inputs
##############################################################################
import math, time, numpy
import pyopencl as cl
# Exclusive upper bounds: for each value v the script sums all primes below v.
# int32 matches the OpenCL kernel's `__global const int *a` input argument.
inputs = numpy.array(
    [100000, 100100, 100200, 100300, 100400, 100500, 100600, 100700],
    dtype=numpy.int32
)
# Tiny debugging input (every sum is 2):
# inputs = numpy.array([3] * 8, dtype=numpy.int32)
##############################################################################
#Define Functions
##############################################################################
def isprime(n):
    """Return True if n is prime and False otherwise.

    Plain trial division by every integer from 2 up to ceil(sqrt(n)). This is
    the same trial-division scheme as the OpenCL ``prime`` helper in this
    file, which keeps the serial/OpenCL timing comparison like-for-like.

    :param n: integer to test; values below 2 are never prime.
    :returns: bool
    """
    if n < 2:
        return False
    if n == 2:
        # Needed: the loop below would test 2 % 2 == 0 and report 2 composite.
        return True
    # Named `limit` (not `max`) so the builtin max() is not shadowed.
    limit = int(math.ceil(math.sqrt(n)))
    divisor = 2
    while divisor <= limit:
        if n % divisor == 0:
            return False
        divisor += 1
    return True

def sum_primes(n):
    """Return the sum of all primes p with 2 <= p < n (0 when n <= 2)."""
    # Generator expression: no intermediate list of primes is materialised.
    return sum(x for x in xrange(2, n) if isprime(x))

##############################################################################
#Serial (pure Python, one input at a time)
##############################################################################
result_ser = []
start_time_serial = time.time()
for i in range(len(inputs)):
    result_ser.append(sum_primes(inputs[i]))
    
print "Time Serial elapsed: ", time.time() - start_time_serial, "s"
print result_ser
##############################################################################
#OpenCL (one work-item per input)
##############################################################################
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

mf = cl.mem_flags

prg = cl.Program(ctx, """

    #define CEILING_POS(X) ((X-(int)(X)) > 0 ? (int)(X+1) : (int)(X))
    #define CEILING_NEG(X) ((X-(int)(X)) < 0 ? (int)(X-1) : (int)(X))
    #define CEILING(X) ( ((X) > 0) ? CEILING_POS(X) : CEILING_NEG(X) )
    
    bool prime(int count)
    {
        if(count < 2)
            return false;
        
        if(count == 2)
            return true;
            
        int count_max = (int)(CEILING(sqrt((double)count)));
        int i = 2;
        
        while(i <= count_max)
        {
            if(count % i == 0)
              return false;

        i = i + 1;
        }
        return true;
    }
    
    __kernel void sum(__global const int *a, __global int *c)
    {
    int gid = get_global_id(0);
    int count;
    int a1 = a[gid];
    int temp2;
    temp2 = 0;
    for( count = 2; count < a1 ; count++)
    {
      if( prime(count) )
           temp2 =  temp2 + count;
    }
    
    c[gid] = temp2;
    }
    """).build()
    

    
results_open = []
start_time = time.time()
a_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=inputs)
dest_buf = cl.Buffer(ctx, mf.READ_WRITE, inputs.nbytes)
prg.sum(queue, inputs.shape, None, a_buf, dest_buf)
a_prime = numpy.empty_like(inputs)
cl.enqueue_copy(queue, a_prime, dest_buf).wait()
queue.finish()

print "Time open elapsed: ", time.time() - start_time, "s"
print a_prime, numpy.sum(numpy.array(result_ser) - a_prime)
