After reading the "Cython for Numpy Users" page, I tried rewriting a
function in cython, which I had previously implemented in python and
then scipy.weave.
It runs just as slow as the python version, which is 100 times slower
than the C version.
Maybe someone can tell me what is slowing it down.
Thanks,
John

#cython_scripts.py
from __future__ import division
import numpy as np
cimport numpy as np

DTYPE = np.int
ctypedef np.int_t DTYPE_t

DTYPE2 = np.double
ctypedef np.double_t DTYPE2_t

cdef inline int int_max(int a, int b): return a if a >= b else b
cdef inline int int_min(int a, int b): return a if a <= b else b

def firstpass_labels(np.ndarray[DTYPE2_t,ndim=2] arr,list links,int
s_back,double thresh):
    cdef int n_s,n_ch
    n_s = arr.shape[0]; n_ch = arr.shape[1]
    cdef np.ndarray labels = np.zeros([n_s,n_ch],dtype=DTYPE)
    cdef DTYPE_t c_label
    c_label = 1
    cdef int i_s, i_ch, j_s, j_ch, j_ch_ind, j_sstart

    cdef np.ndarray links_arr = np.zeros([n_ch,n_ch+1],dtype=DTYPE)
    for (source,targs) in enumerate(links):
        links_arr[source,0] = len(targs)
        links_arr[source,1:(len(targs)+1)] = links[source]

    for i_s in range(n_s):
        for i_ch in range(n_ch):
            if arr[i_s,i_ch] > thresh:
                j_sstart = int_max(0,i_s-s_back)
                for j_s in range(j_sstart,i_s+1):
                    for j_ch_ind in range(1,links_arr[i_ch][0]+1):
                        j_ch = links_arr[i_ch,j_ch_ind]
                        if labels[j_s,j_ch] != 0:
                            labels[i_s,i_ch] = labels[j_s,j_ch]

                if labels[i_s,i_ch] == 0:
                    labels[i_s,i_ch] = c_label
                    c_label += 1

    return labels



And here's the function I use to run it:
#test_cy.py
import time
import numpy as np

import pyximport; pyximport.install()
import cython_scripts



arr = np.random.random( (400000,3))
s_back = 3
thresh = .7
links2 = [[0,1],[0,1,2],[1,2]]
t = time.time()
print cython_scripts.firstpass_labels(arr,links2,s_back,thresh)
print "cython %f"%(time.time()-t)
_______________________________________________
Cython-dev mailing list
[email protected]
http://codespeak.net/mailman/listinfo/cython-dev

Reply via email to