On Fri, Jul 10, 2009 at 3:55 PM, John Schulman<[email protected]> wrote: > After reading the "Cython for Numpy Users" page, I tried rewriting a > function in cython, which I had previously implemented in python and > then scipy.weave. > It runs just as slow as the python version, which is 100 times slower > than the C version. > Maybe someone can tell me what is slowing it down. > Thanks, > John > > #cython_scripts.py > from __future__ import division > import numpy as np > cimport numpy as np > > DTYPE = np.int > ctypedef np.int_t DTYPE_t > > DTYPE2 = np.double > ctypedef np.double_t DTYPE2_t > > cdef inline int int_max(int a, int b): return a if a >= b else b > cdef inline int int_min(int a, int b): return a if a <= b else b > > def firstpass_labels(np.ndarray[DTYPE2_t,ndim=2] arr,list links,int > s_back,double thresh): > cdef int n_s,n_ch > n_s = arr.shape[0]; n_ch = arr.shape[1] > cdef np.ndarray labels = np.zeros([n_s,n_ch],dtype=DTYPE) > cdef DTYPE_t c_label > c_label = 1 > cdef int i_s, i_ch, j_s, j_ch, j_ch_ind, j_sstart > > cdef np.ndarray links_arr = np.zeros([n_ch,n_ch+1],dtype=DTYPE) > for (source,targs) in enumerate(links): > links_arr[source,0] = len(targs) > links_arr[source,1:(len(targs)+1)] = links[source] > > for i_s in range(n_s): > for i_ch in range(n_ch): > if arr[i_s,i_ch] > thresh: > j_sstart = int_max(0,i_s-s_back) > for j_s in range(j_sstart,i_s+1): > for j_ch_ind in range(1,links_arr[i_ch][0]+1): > j_ch = links_arr[i_ch,j_ch_ind] > if labels[j_s,j_ch] != 0: > labels[i_s,i_ch] = labels[j_s,j_ch] > > if labels[i_s,i_ch] == 0: > labels[i_s,i_ch] = c_label > c_label += 1 > > return labels > > > > And here's the function I use to run it: > #test_cy.py > import time > import numpy as np > > import pyximport; pyximport.install() > import cython_scripts > > > > arr = np.random.random( (400000,3)) > s_back = 3 > thresh = .7 > links2 = [[0,1],[0,1,2],[1,2]] > t = time.time() > print cython_scripts.firstpass_labels(arr,links2,s_back,thresh) > print 
"cython %f"%(time.time()-t)
I get a 70-80x speedup with the following changes. The key change is declaring labels and links_arr as typed 2-D buffers (np.ndarray[DTYPE_t, ndim=2]) and indexing links_arr[i_ch,0] instead of links_arr[i_ch][0], so that every array access in the inner loop compiles to direct C indexing rather than Python object calls. You can boost it further by turning off bounds checking and declaring every buffer's mode to be 'c'; see the documentation here: http://docs.cython.org/docs/numpy_tutorial.html#tuning-indexing-further #--------------------------------------------------------------------------- from __future__ import division import numpy as np cimport numpy as np DTYPE = np.int ctypedef np.int_t DTYPE_t DTYPE2 = np.double ctypedef np.double_t DTYPE2_t cdef inline int int_max(int a, int b): return a if a >= b else b cdef inline int int_min(int a, int b): return a if a <= b else b def firstpass_labels(np.ndarray[DTYPE2_t,ndim=2] arr,list links,int s_back,double thresh): cdef int n_s,n_ch n_s = arr.shape[0]; n_ch = arr.shape[1] # XXX: changed! cdef np.ndarray[DTYPE_t, ndim=2] labels = np.zeros([n_s,n_ch],dtype=DTYPE) cdef DTYPE_t c_label c_label = 1 cdef int i_s, i_ch, j_s, j_ch, j_ch_ind, j_sstart #XXX: changed! cdef np.ndarray[DTYPE_t, ndim=2] links_arr = np.zeros([n_ch,n_ch+1],dtype=DTYPE) for (source,targs) in enumerate(links): links_arr[source,0] = len(targs) links_arr[source,1:(len(targs)+1)] = links[source] for i_s in range(n_s): for i_ch in range(n_ch): if arr[i_s,i_ch] > thresh: j_sstart = int_max(0,i_s-s_back) for j_s in range(j_sstart,i_s+1): for j_ch_ind in range(1,links_arr[i_ch,0]+1): j_ch = links_arr[i_ch,j_ch_ind] if labels[j_s,j_ch] != 0: labels[i_s,i_ch] = labels[j_s,j_ch] if labels[i_s,i_ch] == 0: labels[i_s,i_ch] = c_label c_label += 1 return labels #------------------------------------------------------------------------------------------------------------------- _______________________________________________ Cython-dev mailing list [email protected] http://codespeak.net/mailman/listinfo/cython-dev
