This is an automated email from the git hooks/post-receive script. yoh pushed a commit to annotated tag v0.2 in repository python-mne.
commit 866040876254090dae33d2abd2d8545f8ba7c335 Author: Martin Luessi <[email protected]> Date: Tue Sep 27 17:26:02 2011 -0400 added indexing and slicing operations for epoch --- mne/epochs.py | 85 +++++++++++++++++++++++++++++++++++++++++++++--- mne/tests/test_epochs.py | 40 +++++++++++++++++++++++ 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index c1a52da..a3832cc 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -79,6 +79,17 @@ class Epochs(object): Return Evoked object containing averaged epochs as a 2D array [n_channels x n_times]. + drop_bad_epochs() : None + Drop all epochs marked as bad. Should be used before indexing and + slicing operations. + + Indexing and Slicing: + ------- + epochs = Epochs(...) + + epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times]) + + epochs[start:stop] : Return Epochs object with a subset of epochs """ def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0), @@ -96,6 +107,7 @@ class Epochs(object): self.preload = preload self.reject = reject self.flat = flat + self.bad_dropped = False # Handle measurement info self.info = copy.deepcopy(raw.info) @@ -183,7 +195,9 @@ class Epochs(object): self._reject_setup() if self.preload: - self._data = self._get_data_from_disk() + self._data, good_events = self._get_data_from_disk() + self.events = self.events[good_events,:] + self.bad_dropped = True def drop_picks(self, bad_picks): """Drop some picks @@ -206,6 +220,28 @@ class Epochs(object): if self.preload: self._data = self._data[:, idx, :] + def drop_bad_epochs(self): + """Drop bad epochs. + + Should be used before slicing operations. + + Warning: Operation is slow since all epochs have to be read from disk + """ + if self.bad_dropped: + return + + good = [] + n_events = len(self.events) + for idx in range(n_events): + epoch = self._get_epoch_from_disk(idx) + if self._is_good_epoch(epoch): + good.append(idx) + + self.events = self.events[good,:] + self.bad_dropped = True + + print "%d bad epochs dropped" % (n_events - len(good)) + def _get_epoch_from_disk(self, idx): """Load one epoch from disk""" sfreq = self.raw.info['sfreq'] @@ -235,18 +271,20 @@ class Epochs(object): data = np.empty((n_events, n_channels, n_times)) cnt = 0 n_reject = 0 + event_idx = [] for k in range(n_events): e = self._get_epoch_from_disk(k) if self._is_good_epoch(e): data[cnt] = self._get_epoch_from_disk(k) + event_idx.append(k) cnt += 1 else: n_reject += 1 print "Rejecting %d epochs." % n_reject - return data[:cnt] + return data[:cnt], event_idx def _is_good_epoch(self, data): - """Determine is epoch is good + """Determine if epoch is good """ n_times = len(self.times) if self.reject is None and self.flat is None: @@ -268,7 +306,8 @@ class Epochs(object): if self.preload: return self._data else: - return self._get_data_from_disk() + data, _ = self._get_data_from_disk() + return data def _reject_setup(self): """Setup reject process @@ -312,12 +351,48 @@ class Epochs(object): return epoch def __repr__(self): - s = "n_events : %s" % len(self.events) + if not self.bad_dropped: + s = "n_events : %s (good & bad)" % len(self.events) + else: + s = "n_events : %s (all good)" % len(self.events) s += ", tmin : %s (s)" % self.tmin s += ", tmax : %s (s)" % self.tmax s += ", baseline : %s" % str(self.baseline) return "Epochs (%s)" % s + def __getslice__(self, start, end): + """Return an Epoch object with a subset of epochs. + """ + if not self.bad_dropped: + print "Warning: bad epochs have not been dropped, indexing will " \ + "be inccurate. Use drop_bad_epochs() or preload=True" + + epoch_slice = copy.copy(self) + epoch_slice.events = self.events[start:end] + + if self.preload: + epoch_slice._data = self._data[start:end] + + return epoch_slice + + def __getitem__(self, index): + """Return epoch at index + """ + if index < 0 or index >= len(self.events): + raise IndexError("Epoch index out of bounds") + + if self.preload: + epoch = epoch = self._data[index] + else: + epoch = self._get_epoch_from_disk(index) + + if not self._is_good_epoch(epoch): + print "Warning: Bad epoch with index %d returned. Use " \ + "drop_bad_epochs() or preload=True to prevent this." \ + % (index) + + return epoch + def average(self): """Compute average of epochs diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 5d97441..94abf4f 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -70,6 +70,46 @@ def test_preload_epochs(): data_no_preload = epochs.get_data() assert_array_equal(data_preload, data_no_preload) +def test_indexing_slicing(): + """Test of indexing and slicing operations + """ + epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks, + baseline=(None, 0), preload=False, + reject=reject, flat=flat) + + data_normal = epochs.get_data() + + n_good_events = data_normal.shape[0] + + # indices for slicing + start_index = 1 + end_index = n_good_events - 1 + + assert((end_index - start_index) > 0) + + for preload in [True, False]: + epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax, + picks=picks, baseline=(None, 0), preload=preload, + reject=reject, flat=flat) + + if not preload: + epochs2.drop_bad_epochs() + + # get slice + epochs2_sliced = epochs2[start_index:end_index] + + # using get_data() + data_epochs2_sliced = epochs2_sliced.get_data() + assert_array_equal(data_epochs2_sliced, \ + data_normal[start_index:end_index]) + + # using indexing + pos = 0 + for idx in range(start_index, end_index): + assert_array_equal(epochs2_sliced[pos], data_normal[idx]) + pos += 1 + + def test_comparision_with_c(): """Test of average obtained vs C code -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
