This is an automated email from the git hooks/post-receive script. yoh pushed a commit to tag 0.4 in repository python-mne.
commit a7331a3e2677e4579582020898c29394d87ac089 Author: Daniel Strohmeier <[email protected]> Date: Wed Jul 18 17:59:33 2012 +0200 added comments on the pull request --- mne/epochs.py | 67 +++++++++++++++++++++++++++++------------------- mne/tests/test_epochs.py | 47 ++++++++++++++++++--------------- 2 files changed, 66 insertions(+), 48 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index fdc723d..f7199ef 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1,17 +1,20 @@ # Authors: Alexandre Gramfort <[email protected]> # Matti Hamalainen <[email protected]> +# Daniel Strohmeier <[email protected]> # # License: BSD (3-clause) -import copy +import copy as cp +import warnings + import numpy as np + import fiff -import warnings from .fiff import Evoked from .fiff.pick import pick_types, channel_indices_by_type from .fiff.proj import activate_proj, make_eeg_average_ref_proj from .baseline import rescale - +from .utils import check_random_state class Epochs(object): """List of Epochs @@ -122,7 +125,7 @@ class Epochs(object): self._bad_dropped = False # Handle measurement info - self.info = copy.deepcopy(raw.info) + self.info = cp.deepcopy(raw.info) if picks is not None: self.info['chs'] = [self.info['chs'][k] for k in picks] self.info['ch_names'] = [self.info['ch_names'][k] for k in picks] @@ -377,7 +380,7 @@ class Epochs(object): warnings.warn("Bad epochs have not been dropped, indexing will be " "inaccurate. Use drop_bad_epochs() or preload=True") - epochs = copy.copy(self) # XXX : should use deepcopy but breaks ... + epochs = cp.copy(self) # XXX : should use deepcopy but breaks ... epochs.events = np.atleast_2d(self.events[key]) if self.preload: @@ -386,6 +389,8 @@ class Epochs(object): else: if isinstance(key, list): key = np.array(key) + print key + print np.ndim(key) if np.ndim(key) == 0: epochs._data = self._data[key][np.newaxis, :, :] else: @@ -407,7 +412,7 @@ class Epochs(object): The averaged epochs """ evoked = Evoked(None) - evoked.info = copy.deepcopy(self.info) + evoked.info = cp.deepcopy(self.info) n_channels = len(self.ch_names) n_times = len(self.times) if self.preload: @@ -444,7 +449,7 @@ class Epochs(object): evoked.data = evoked.data[data_picks] return evoked - def crop(self, tmin, tmax): + def crop(self, tmin=None, tmax=None, copy=False): """Crops a time interval from epochs object. Parameters @@ -453,7 +458,9 @@ class Epochs(object): Start time of selection in seconds tmax : float End time of selection in seconds - + copy : bool + If False epochs is cropped in place + Returns ------- epochs : Epochs instance @@ -463,19 +470,27 @@ class Epochs(object): raise RuntimeError('Modifying data of epochs is only supported ' 'when preloading is used. Use preload=True ' 'in the constructor.') - if tmin < self.tmin: + if tmin is None: tmin = self.tmin - if tmax > self.tmax: + elif tmin < self.tmin: + warnings.warn("tmin is not in epochs' time interval." + "tmin is set to epochs.tmin") + tmin = self.tmin + if tmax is None: + tmax = self.tmax + elif tmax > self.tmax: + warnings.warn("tmax is not in epochs' time interval." + "tmax is set to epochs.tmax") tmax = self.tmax - sfreq = self.info['sfreq'] - first_samp = int((tmin - self.tmin) * sfreq) - last_samp = int((tmax - self.tmax) * sfreq) - 1 - - self.tmin = tmin - self.tmax = tmax - self._data = self._data[:, :, first_samp:last_samp] - return self + tmask = (self.times >= tmin) & (self.times <= tmax) + + this_epochs = self if not copy else cp.deepcopy(self) + this_epochs.tmin = tmin + this_epochs.tmax = tmax + this_epochs.times = this_epochs.times[tmask] + this_epochs._data = this_epochs._data[:, :, tmask] + return this_epochs def _is_good(e, ch_names, channel_type_idx, reject, flat): @@ -514,15 +529,15 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat): return True -def bootstrap(epochs, rng, return_idx=False): - """Compute average of epochs selected by bootstrapping +def bootstrap(epochs, random_state=None): + """Compute epochs selected by bootstrapping Parameters ---------- epochs : Epochs instance epochs data to be bootstrapped - rng : - random number generator. + random_state : None | int | np.random.RandomState + To specify the random generator state return_idx : bool If True the selected indices are provided as an output @@ -536,11 +551,9 @@ def bootstrap(epochs, rng, return_idx=False): 'when preloading is used. Use preload=True ' 'in the constructor.') - epochs_bootstrap = copy.deepcopy(epochs) + rng = check_random_state(random_state) + epochs_bootstrap = cp.deepcopy(epochs) n_events = len(epochs_bootstrap.events) idx = rng.randint(0, n_events, n_events) epochs_bootstrap = epochs_bootstrap[idx] - if return_idx: - return epochs_bootstrap, idx - else: - return epochs_bootstrap + return epochs_bootstrap diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c751db3..a7f745b 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -137,22 +137,23 @@ def test_indexing_slicing(): assert_array_equal(data[0], data_normal[idx]) pos += 1 - # using indexing with int + # using indexing with an int idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1) data = epochs2[idx].get_data() assert_array_equal(data, data_normal[idx]) - # using indexing with array + # using indexing with an array idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10) data = epochs2[idx].get_data() assert_array_equal(data, data_normal[idx]) - # using indexing with list of indices - #idx = list() - #for k in range(3): - # idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1)) - # data = epochs2[idx].get_data() - # assert_array_equal(data, data_normal[idx]) + # using indexing with a list of indices + idx = [0] + data = epochs2[idx].get_data() + assert_array_equal(data, data_normal[idx]) + idx = [0, 1] + data = epochs2[idx].get_data() + assert_array_equal(data, data_normal[idx]) def test_comparision_with_c(): @@ -175,33 +176,37 @@ def test_comparision_with_c(): def test_crop(): """Test of crop of epochs """ - epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks, + epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=False, reject=reject, flat=flat) - epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax, + data_normal = epochs.get_data() + + epochs2 = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True, reject=reject, flat=flat) - data_normal = epochs.get_data() # indices for slicing start_tsamp = tmin + 60 * epochs.info['sfreq'] end_tsamp = tmax - 60 * epochs.info['sfreq'] tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp) - assert((start_tsamp) > tmin) - assert((end_tsamp) < tmax) + assert_true(start_tsamp > tmin) + assert_true(end_tsamp < tmax) + epochs3 = epochs2.crop(start_tsamp, end_tsamp, copy=True) + data3 = epochs3.get_data() epochs2.crop(start_tsamp, end_tsamp) - data = epochs2.get_data() - assert_array_equal(data, data_normal[:, :, tmask]) - + data2 = epochs2.get_data() + assert_array_equal(data2, data_normal[:, :, tmask]) + assert_array_equal(data3, data_normal[:, :, tmask]) + def test_bootstrap(): - """Test of crop of epochs + """Test of bootstrapping of epochs """ - epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks, + epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True, reject=reject, flat=flat) data_normal = epochs._data - rng = np.random.RandomState(0) - epochs2, idx = bootstrap(epochs, rng, return_idx=True) + epochs2 = bootstrap(epochs, random_state=0) n_events = len(epochs.events) - assert_array_equal(epochs2._data, data_normal[idx]) + assert_true(len(epochs2.events) == len(epochs.events)) + assert_true(epochs._data.shape == epochs2._data.shape) -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
