This is an automated email from the git hooks/post-receive script. yoh pushed a commit to tag 0.4 in repository python-mne.
commit bf7fa3447926b9962bb3cc00b5aec54be7d6d7d7 Author: Daniel Strohmeier <[email protected]> Date: Wed Jul 18 15:44:30 2012 +0200 added bootstrap and crop function to epochs --- mne/epochs.py | 70 +++++++++++++++++++++++++++++++++++++++++++++--- mne/tests/test_epochs.py | 53 ++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 4 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index a49f60d..ceb9618 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -384,10 +384,12 @@ class Epochs(object): if isinstance(key, slice): epochs._data = self._data[key] else: - #make sure data remains a 3D array - #Note: np.atleast_3d() doesn't do what we want - epochs._data = np.array([self._data[key]]) - + if isinstance(key, list): + key = np.array(key) + if np.ndim(key) == 0: + epochs._data = self._data[key][np.newaxis, :, :] + else: + epochs._data = self._data[key] return epochs def average(self, keep_only_data_channels=True): @@ -441,6 +443,39 @@ class Epochs(object): evoked.info['nchan'] = len(data_picks) evoked.data = evoked.data[data_picks] return evoked + + def crop(self, tmin, tmax): + """Crops a time interval from epochs object. + + Parameters + ---------- + tmin : float + Start time of selection in seconds + tmax : float + End time of selection in seconds + + Returns + ------- + epochs : Epochs instance + The bootstrap samples + """ + if not self.preload: + raise RuntimeError('Modifying data of epochs is only supported ' + 'when preloading is used. Use preload=True ' + 'in the constructor.') + if tmin < self.tmin: + tmin = self.tmin + if tmax > self.tmax: + tmax = self.tmax + + sfreq = self.info['sfreq'] + first_samp = int((tmin - self.tmin) * sfreq) + last_samp = int((tmax - self.tmax) * sfreq) - 1 + + self.tmin = tmin + self.tmax = tmax + self._data = self._data[:, :, first_samp:last_samp] + return self def _is_good(e, ch_names, channel_type_idx, reject, flat): @@ -477,3 +512,30 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat): return False return True + + +def bootstrap(epochs, rng): + """Compute average of epochs selected by bootstrapping + + Parameters + ---------- + epochs : Epochs instance + epochs data to be bootstrapped + rng: + random number generator. + + Returns + ------- + epochs : Epochs instance + The bootstrap samples + """ + if not epochs.preload: + raise RuntimeError('Modifying data of epochs is only supported ' + 'when preloading is used. Use preload=True ' + 'in the constructor.') + + epochs_bootstrap = copy.deepcopy(epochs) + n_events = len(epochs_bootstrap.events) + idx = rng.randint(0, n_events, n_events) + epochs_bootstrap = epochs_bootstrap[idx] + return epochs_bootstrap, idx diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 29bfc1c..d2eb23c 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -8,6 +8,7 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal import numpy as np from .. import fiff, Epochs, read_events, pick_events +from ..epochs import bootstrap raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data', 'test_raw.fif') @@ -135,6 +136,23 @@ def test_indexing_slicing(): data = epochs2_sliced[pos].get_data() assert_array_equal(data[0], data_normal[idx]) pos += 1 + + # using indexing with int + idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1) + data = epochs2[idx].get_data() + assert_array_equal(data, data_normal[idx]) + + # using indexing with array + idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10) + data = epochs2[idx].get_data() + assert_array_equal(data, data_normal[idx]) + + # using indexing with list of indices + #idx = list() + #for k in range(3): + # idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1)) + # data = epochs2[idx].get_data() + # assert_array_equal(data, data_normal[idx]) def test_comparision_with_c(): @@ -152,3 +170,38 @@ def test_comparision_with_c(): assert_true(evoked.nave == c_evoked.nave) assert_array_almost_equal(evoked_data, c_evoked_data, 10) assert_array_almost_equal(evoked.times, c_evoked.times, 12) + + +def test_crop(): + """Test of crop of epochs + """ + epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks, + baseline=(None, 0), preload=False, + reject=reject, flat=flat) + epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax, + picks=picks, baseline=(None, 0), preload=True, + reject=reject, flat=flat) + data_normal = epochs.get_data() + + # indices for slicing + start_tsamp = tmin + 60 * epochs.info['sfreq'] + end_tsamp = tmax - 60 * epochs.info['sfreq'] + tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp) + assert((start_tsamp) > tmin) + assert((end_tsamp) < tmax) + epochs2.crop(start_tsamp, end_tsamp) + data = epochs2.get_data() + assert_array_equal(data, data_normal[:, :, tmask]) + + +def test_bootstrap(): + """Test of crop of epochs + """ + epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks, + baseline=(None, 0), preload=True, + reject=reject, flat=flat) + data_normal = epochs._data + rng = np.random.RandomState(0) + epochs2, idx = bootstrap(epochs, rng) + n_events = len(epochs.events) + assert_array_equal(epochs2._data, data_normal[idx]) -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
