This is an automated email from the git hooks/post-receive script. yoh pushed a commit to annotated tag v0.1 in repository python-mne.
commit 47e17da9884d03ed1a66ed4c2e262f010a34280d Author: Alexandre Gramfort <[email protected]> Date: Tue Mar 1 16:03:24 2011 -0500 ENH : refactoring time frequency for speed up in parallel settings --- examples/time_frequency/plot_time_frequency.py | 2 +- mne/tests/test_tfr.py | 12 +- mne/tfr.py | 152 ++++++++++++++----------- 3 files changed, 93 insertions(+), 73 deletions(-) diff --git a/examples/time_frequency/plot_time_frequency.py b/examples/time_frequency/plot_time_frequency.py index 8498588..d113dc3 100644 --- a/examples/time_frequency/plot_time_frequency.py +++ b/examples/time_frequency/plot_time_frequency.py @@ -50,7 +50,7 @@ evoked_data = np.mean(epochs, axis=0) # compute evoked fields frequencies = np.arange(4, 30, 3) # define frequencies of interest Fs = raw['info']['sfreq'] # sampling in Hz power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies, - n_cycles=2) + n_cycles=2, n_jobs=1, use_fft=False) ############################################################################### # View time-frequency plots diff --git a/mne/tests/test_tfr.py b/mne/tests/test_tfr.py index e3f59fe..9d923ac 100644 --- a/mne/tests/test_tfr.py +++ b/mne/tests/test_tfr.py @@ -1,11 +1,10 @@ import numpy as np import os.path as op -from numpy.testing import assert_allclose - import mne from mne import fiff from mne import time_frequency +from mne.tfr import cwt_morlet raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data', 'test_raw.fif') @@ -13,7 +12,7 @@ event_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data', 'test-eve.fif') def test_time_frequency(): - """Test IO for STC files + """Test time frequency transform (PSD and phase lock) """ # Set parameters event_id = 1 @@ -35,9 +34,8 @@ def test_time_frequency(): data, times, channel_names = mne.read_epochs(raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0)) epochs = np.array([d['epoch'] for d in data]) # as 3D matrix - evoked_data = np.mean(epochs, axis=0) # compute evoked fields - frequencies = np.arange(4, 20, 5) # define frequencies of interest + frequencies = np.arange(6, 20, 5) # define frequencies of interest Fs = raw['info']['sfreq'] # sampling in Hz power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies, n_cycles=2, use_fft=True) @@ -54,4 +52,6 @@ def test_time_frequency(): assert power.shape == phase_lock.shape assert np.sum(phase_lock >= 1) == 0 assert np.sum(phase_lock <= 0) == 0 - \ No newline at end of file + + tfr = cwt_morlet(epochs[0], Fs, frequencies, use_fft=True, n_cycles=2) + assert tfr.shape == (len(picks), len(frequencies), len(times)) diff --git a/mne/tfr.py b/mne/tfr.py index 05645a9..6a61cb4 100644 --- a/mne/tfr.py +++ b/mne/tfr.py @@ -69,71 +69,74 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _cwt_morlet_fft(x, Fs, freqs, mode="same", Ws=None): +def _cwt_fft(X, Ws, mode="same"): """Compute cwt with fft based convolutions + Return a generator over signals. """ - x = np.asarray(x) - freqs = np.asarray(freqs) + X = np.asarray(X) # Precompute wavelets for given frequency range to save time - n_samples = x.size - n_freqs = freqs.size - - if Ws is None: - Ws = morlet(Fs, freqs) + n_signals, n_times = X.shape + n_freqs = len(Ws) Ws_max_size = max(W.size for W in Ws) - size = n_samples + Ws_max_size - 1 + size = n_times + Ws_max_size - 1 # Always use 2**n-sized FFT fsize = 2**np.ceil(np.log2(size)) - fft_x = fftn(x, [fsize]) - - if mode == "full": - tfr = np.zeros((n_freqs, fsize), dtype=np.complex128) - elif mode == "same" or mode == "valid": - tfr = np.zeros((n_freqs, n_samples), dtype=np.complex128) + # precompute FFTs of Ws + fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128) for i, W in enumerate(Ws): - ret = ifftn(fft_x * fftn(W, [fsize]))[:n_samples + W.size - 1] - if mode == "valid": - sz = abs(W.size - n_samples) + 1 - offset = (n_samples - sz) / 2 - tfr[i, offset:(offset + sz)] = _centered(ret, sz) - else: - tfr[i] = _centered(ret, n_samples) - return tfr - - -def _cwt_morlet_convolve(x, Fs, freqs, mode='same', Ws=None): + fft_Ws[i] = fftn(W, [fsize]) + + for k, x in enumerate(X): + if mode == "full": + tfr = np.zeros((n_freqs, fsize), dtype=np.complex128) + elif mode == "same" or mode == "valid": + tfr = np.zeros((n_freqs, n_times), dtype=np.complex128) + + fft_x = fftn(x, [fsize]) + for i, W in enumerate(Ws): + ret = ifftn(fft_x * fft_Ws[i])[:n_times + W.size - 1] + if mode == "valid": + sz = abs(W.size - n_times) + 1 + offset = (n_times - sz) / 2 + tfr[i, offset:(offset + sz)] = _centered(ret, sz) + else: + tfr[i, :] = _centered(ret, n_times) + yield tfr + + +def _cwt_convolve(X, Ws, mode='same'): """Compute time freq decomposition with temporal convolutions + Return a generator over signals. """ - x = np.asarray(x) - freqs = np.asarray(freqs) + X = np.asarray(X) - if Ws is None: - Ws = morlet(Fs, freqs) + n_signals, n_times = X.shape + n_freqs = len(Ws) - n_samples = x.size # Compute convolutions - tfr = np.zeros((freqs.size, len(x)), dtype=np.complex128) - for i, W in enumerate(Ws): - ret = np.convolve(x, W, mode=mode) - if mode == "valid": - sz = abs(W.size - n_samples) + 1 - offset = (n_samples - sz) / 2 - tfr[i, offset:(offset + sz)] = ret - else: - tfr[i] = ret - return tfr - - -def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0): + for x in X: + tfr = np.zeros((n_freqs, n_times), dtype=np.complex128) + for i, W in enumerate(Ws): + ret = np.convolve(x, W, mode=mode) + if mode == "valid": + sz = abs(W.size - n_times) + 1 + offset = (n_times - sz) / 2 + tfr[i, offset:(offset + sz)] = ret + else: + tfr[i] = ret + yield tfr + + +def cwt_morlet(X, Fs, freqs, use_fft=True, n_cycles=7.0): """Compute time freq decomposition with Morlet wavelets Parameters ---------- - x : array - signal + X : array of shape [n_signals, n_times] + signals (one per line) Fs : float sampling Frequency @@ -143,35 +146,48 @@ def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0): Returns ------- - tfr : 2D array - Time Frequency Decomposition (Frequencies x Timepoints) + tfr : 3D array + Time Frequency Decompositions (n_signals x n_frequencies x n_times) """ mode = 'same' # mode = "valid" + n_signals, n_times = X.shape + n_frequencies = len(freqs) # Precompute wavelets for given frequency range to save time Ws = morlet(Fs, freqs, n_cycles=n_cycles) if use_fft: - return _cwt_morlet_fft(x, Fs, freqs, mode, Ws) + coefs = _cwt_fft(X, Ws, mode) else: - return _cwt_morlet_convolve(x, Fs, freqs, mode, Ws) + coefs = _cwt_convolve(X, Ws, mode) + tfrs = np.empty((n_signals, n_frequencies, n_times)) + for k, tfr in enumerate(coefs): + tfrs[k] = tfr -def _time_frequency_one_channel(epochs, c, Fs, frequencies, use_fft, n_cycles): - """Aux of time_frequency for parallel computing""" - n_epochs, _, n_times = epochs.shape - n_frequencies = len(frequencies) - psd_c = np.zeros((n_frequencies, n_times)) # PSD - plf_c = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock + return tfrs - for e in range(n_epochs): - tfr = cwt_morlet(epochs[e, c, :].ravel(), Fs, frequencies, - use_fft=use_fft, n_cycles=n_cycles) +def _time_frequency(X, Ws, use_fft): + """Aux of time_frequency for parallel computing over channels + """ + n_epochs, n_times = X.shape + n_frequencies = len(Ws) + psd = np.zeros((n_frequencies, n_times)) # PSD + plf = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock + + mode = 'same' + if use_fft: + tfrs = _cwt_fft(X, Ws, mode) + else: + tfrs = _cwt_convolve(X, Ws, mode) + + for tfr in tfrs: tfr_abs = np.abs(tfr) - psd_c += tfr_abs**2 - plf_c += tfr / tfr_abs - return psd_c, plf_c + psd += tfr_abs**2 + plf += tfr / tfr_abs + + return psd, plf def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25, @@ -213,6 +229,9 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25, n_frequencies = len(frequencies) n_epochs, n_channels, n_times = epochs.shape + # Precompute wavelets for given frequency range to save time + Ws = morlet(Fs, frequencies, n_cycles=n_cycles) + try: import joblib except ImportError: @@ -224,13 +243,14 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25, plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex) for c in range(n_channels): - psd[c,:,:], plf[c,:,:] = _time_frequency_one_channel(epochs, c, Fs, - frequencies, use_fft, n_cycles) + X = np.squeeze(epochs[:,c,:]) + psd[c], plf[c] = _time_frequency(X, Ws, use_fft) + else: from joblib import Parallel, delayed psd_plf = Parallel(n_jobs=n_jobs)( - delayed(_time_frequency_one_channel)( - epochs, c, Fs, frequencies, use_fft, n_cycles) + delayed(_time_frequency)( + np.squeeze(epochs[:,c,:]), Ws, use_fft) for c in range(n_channels)) psd = np.zeros((n_channels, n_frequencies, n_times)) -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
