This is an automated email from the git hooks/post-receive script. yoh pushed a commit to annotated tag v0.1 in repository python-mne.
commit bda9bd53e3382d342c66ad5d6548e98e9cd82515 Author: Alexandre Gramfort <[email protected]> Date: Wed Jun 8 12:16:05 2011 -0400 factoring joblib parallel code --- examples/stats/plot_cluster_stats_evoked.py | 3 +- mne/minimum_norm/time_frequency.py | 19 +----- mne/parallel.py | 50 +++++++++++++++ mne/stats/cluster_level.py | 95 ++++++++++++++++++----------- mne/stats/permutations.py | 20 +----- mne/time_frequency/tfr.py | 27 +++----- 6 files changed, 123 insertions(+), 91 deletions(-) diff --git a/examples/stats/plot_cluster_stats_evoked.py b/examples/stats/plot_cluster_stats_evoked.py index 34a4ad1..63d344c 100644 --- a/examples/stats/plot_cluster_stats_evoked.py +++ b/examples/stats/plot_cluster_stats_evoked.py @@ -58,7 +58,8 @@ condition2 = condition2[:, 0, :] # take only one channel to get a 2D array threshold = 6.0 T_obs, clusters, cluster_p_values, H0 = \ permutation_cluster_test([condition1, condition2], - n_permutations=1000, threshold=threshold, tail=1) + n_permutations=1000, threshold=threshold, tail=1, + n_jobs=2) ############################################################################### # Plot diff --git a/mne/minimum_norm/time_frequency.py b/mne/minimum_norm/time_frequency.py index 1262669..b36adce 100644 --- a/mne/minimum_norm/time_frequency.py +++ b/mne/minimum_norm/time_frequency.py @@ -10,6 +10,7 @@ from ..source_estimate import SourceEstimate from ..time_frequency.tfr import cwt, morlet from ..baseline import rescale from .inverse import combine_xyz, prepare_inverse_operator +from ..parallel import parallel_func def _compute_power(data, K, sel, Ws, source_ori, use_fft, Vh): @@ -89,23 +90,7 @@ def source_induced_power(epochs, inverse_operator, bands, lambda2=1.0 / 9.0, Number of jobs to run in parallel """ - if n_jobs == -1: - try: - import multiprocessing - n_jobs = multiprocessing.cpu_count() - except ImportError: - print "multiprocessing not installed. Cannot run in parallel." - n_jobs = 1 - - try: - from scikits.learn.externals.joblib import Parallel, delayed - parallel = Parallel(n_jobs) - my_compute_power = delayed(_compute_power) - except ImportError: - print "joblib not installed. Cannot run in parallel." - n_jobs = 1 - my_compute_power = _compute_power - parallel = list + parallel, my_compute_power, n_jobs = parallel_func(_compute_power, n_jobs) # # Set up the inverse according to the parameters diff --git a/mne/parallel.py b/mne/parallel.py new file mode 100644 index 0000000..12e4164 --- /dev/null +++ b/mne/parallel.py @@ -0,0 +1,50 @@ +"""Parralle util function +""" + +# Author: Alexandre Gramfort <[email protected]> +# +# License: Simplified BSD + + +def parallel_func(func, n_jobs, verbose=5): + """Return parallel instance with delayed function + + Util function to use joblib only if available + + Parameters + ---------- + func: callable + A function + n_jobs: int + Number of jobs to run in parallel + verbose: int + Verbosity level + + Returns + ------- + parallel: instance of joblib.Parallel or list + The parallel object + my_func: callable + func if not parallel or delayed(func) + n_jobs: int + Number of jobs >= 0 + """ + try: + from scikits.learn.externals.joblib import Parallel, delayed + parallel = Parallel(n_jobs, verbose=verbose) + my_func = delayed(func) + + if n_jobs == -1: + try: + import multiprocessing + n_jobs = multiprocessing.cpu_count() + except ImportError: + print "multiprocessing not installed. Cannot run in parallel." + n_jobs = 1 + + except ImportError: + print "joblib not installed. Cannot run in parallel." + n_jobs = 1 + my_func = func + parallel = list + return parallel, my_func, n_jobs diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index c1bda6c..eea5e10 100644 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -10,6 +10,7 @@ import numpy as np from scipy import stats, sparse, ndimage from .parametric import f_oneway +from ..parallel import parallel_func def _get_components(x_in, connectivity): @@ -123,9 +124,23 @@ def _pval_from_histogram(T, H0, tail): return pval +def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity): + np.random.shuffle(X_full) + X_shuffle_list = [X_full[s] for s in slices] + T_obs_surr = stat_fun(*X_shuffle_list) + _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail, + connectivity) + + if len(perm_clusters_sums) > 0: + return np.max(perm_clusters_sums) + else: + return 0 + + def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67, n_permutations=1000, tail=0, - connectivity=None, verbose=True): + connectivity=None, n_jobs=1, + verbose=5): """Cluster-level statistical permutation test For a list of 2d-arrays of data, e.g. power values, calculate some @@ -154,8 +169,10 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67, Defines connectivity between features. The matrix is assumed to be symmetric and only the upper triangular half is used. Defaut is None, i.e, no connectivity. - verbose: boolean - If True print some text. + verbose : int + If > 0, print some text during computation. + n_jobs : int + Number of permutations to run in parallel (requires joblib package.) Returns ------- @@ -195,24 +212,16 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67, slices = [slice(splits_idx[k], splits_idx[k + 1]) for k in range(len(X))] + parallel, my_one_permutation, _ = parallel_func(_one_permutation, n_jobs, + verbose) + # Step 2: If we have some clusters, repeat process on permuted data # ------------------------------------------------------------------- if len(clusters) > 0: - H0 = np.zeros(n_permutations) # histogram - for i_s in range(n_permutations): - if verbose: - print "Permutation %d / %d" % (i_s + 1, n_permutations) - np.random.shuffle(X_full) - X_shuffle_list = [X_full[s] for s in slices] - T_obs_surr = stat_fun(*X_shuffle_list) - _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail, - connectivity) - - if len(perm_clusters_sums) > 0: - H0[i_s] = np.max(perm_clusters_sums) - else: - H0[i_s] = 0 - + H0 = parallel(my_one_permutation(X_full, slices, stat_fun, tail, + threshold, connectivity) + for _ in range(n_permutations)) + H0 = np.array(H0) cluster_pv = _pval_from_histogram(cluster_stats, H0, tail) return T_obs, clusters, cluster_pv, H0 else: @@ -229,9 +238,28 @@ def ttest_1samp(X): return T +def _one_1samp_permutation(n_samples, shape_ones, X_copy, threshold, tail, + connectivity, stat_fun): + # new surrogate data with random sign flip + signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones)) + X_copy *= signs + + # Recompute statistic on randomized data + T_obs_surr = stat_fun(X_copy) + _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail, + connectivity) + + if len(perm_clusters_sums) > 0: + idx_max = np.argmax(np.abs(perm_clusters_sums)) + return perm_clusters_sums[idx_max] # get max with sign info + else: + return 0.0 + + def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000, tail=0, stat_fun=ttest_1samp, - connectivity=None): + connectivity=None, n_jobs=1, + verbose=5): """Non-parametric cluster-level 1 sample T-test From a array of observations, e.g. signal amplitudes or power spectrum @@ -259,6 +287,11 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000, Defines connectivity between features. The matrix is assumed to be symmetric and only the upper triangular half is used. Defaut is None, i.e, no connectivity. + verbose : int + If > 0, print some text during computation. + n_jobs : int + Number of permutations to run in parallel (requires joblib package.) + Returns ------- @@ -294,26 +327,16 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000, clusters, cluster_stats = _find_clusters(T_obs, threshold, tail, connectivity) + parallel, my_one_1samp_permutation, _ = parallel_func(_one_1samp_permutation, + n_jobs, verbose) + # Step 2: If we have some clusters, repeat process on permuted data # ------------------------------------------------------------------- if len(clusters) > 0: - H0 = np.empty(n_permutations) # histogram - for i_s in range(n_permutations): - # new surrogate data with random sign flip - signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones)) - X_copy *= signs - - # Recompute statistic on randomized data - T_obs_surr = stat_fun(X_copy) - _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail, - connectivity) - - if len(perm_clusters_sums) > 0: - idx_max = np.argmax(np.abs(perm_clusters_sums)) - H0[i_s] = perm_clusters_sums[idx_max] # get max with sign info - else: - H0[i_s] = 0 - + H0 = parallel(my_one_1samp_permutation(n_samples, shape_ones, X_copy, + threshold, tail, connectivity, stat_fun) + for _ in range(n_permutations)) + H0 = np.array(H0) cluster_pv = _pval_from_histogram(cluster_stats, H0, tail) return T_obs, clusters, cluster_pv, H0 diff --git a/mne/stats/permutations.py b/mne/stats/permutations.py index 59c9e74..39e141c 100644 --- a/mne/stats/permutations.py +++ b/mne/stats/permutations.py @@ -9,6 +9,8 @@ from math import sqrt import numpy as np +from ..parallel import parallel_func + def bin_perm_rep(ndim, a=0, b=1): """bin_perm_rep(ndim) -> ndim permutations with repetitions of (a,b). @@ -128,23 +130,7 @@ def permutation_t_test(X, n_permutations=10000, tail=0, n_jobs=1): else: perms = np.sign(0.5 - np.random.rand(n_permutations, n_samples)) - try: - from scikits.learn.externals.joblib import Parallel, delayed - parallel = Parallel(n_jobs) - my_max_stat = delayed(_max_stat) - except ImportError: - print "joblib not installed. Cannot run in parallel." - n_jobs = 1 - my_max_stat = _max_stat - parallel = list - - if n_jobs == -1: - try: - import multiprocessing - n_jobs = multiprocessing.cpu_count() - except ImportError: - print "multiprocessing not installed. Cannot run in parallel." - n_jobs = 1 + parallel, my_max_stat, n_jobs = parallel_func(_max_stat, n_jobs) max_abs = np.concatenate(parallel(my_max_stat(X, X2, p, dof_scaling) for p in np.array_split(perms, n_jobs))) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index ff97ed4..6af02ec 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -12,6 +12,7 @@ import numpy as np from scipy import linalg from scipy.fftpack import fftn, ifftn from ..baseline import rescale +from ..parallel import parallel_func def morlet(Fs, freqs, n_cycles=7, sigma=None): @@ -276,15 +277,7 @@ def single_trial_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7, # Precompute wavelets for given frequency range to save time Ws = morlet(Fs, frequencies, n_cycles=n_cycles) - try: - from scikits.learn.externals.joblib import Parallel, delayed - parallel = Parallel(n_jobs) - my_cwt = delayed(cwt) - except ImportError: - print "joblib not installed. Cannot run in parallel." - n_jobs = 1 - my_cwt = cwt - parallel = list + parallel, my_cwt, _ = parallel_func(cwt, n_jobs) print "Computing time-frequency power on single epochs..." @@ -347,13 +340,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7, # Precompute wavelets for given frequency range to save time Ws = morlet(Fs, frequencies, n_cycles=n_cycles) - try: - import joblib - except ImportError: - print "joblib not installed. Cannot run in parallel." - n_jobs = 1 + parallel, my_time_frequency, _ = parallel_func(_time_frequency, n_jobs) - if n_jobs == 1: + if my_time_frequency is _time_frequency: # not parallel psd = np.empty((n_channels, n_frequencies, n_times)) plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex) @@ -362,11 +351,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7, psd[c], plf[c] = _time_frequency(X, Ws, use_fft) else: - from joblib import Parallel, delayed - psd_plf = Parallel(n_jobs=n_jobs)( - delayed(_time_frequency)( - np.squeeze(epochs[:, c, :]), Ws, use_fft) - for c in range(n_channels)) + psd_plf = parallel(my_time_frequency(np.squeeze(epochs[:, c, :]), + Ws, use_fft) + for c in range(n_channels)) psd = np.zeros((n_channels, n_frequencies, n_times)) plf = np.zeros((n_channels, n_frequencies, n_times), dtype=np.complex) -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
