Skip to content

Commit 9194f27

Browse files
committed
Multiple minor fixes:
Save bad channels and save annotations moved to separate functions. _save_psds moved to public domain save_psds In dashboard set reference after channel interpolation Add return in ICA plots.
1 parent a44b8cc commit 9194f27

File tree

3 files changed

+50
-32
lines changed

3 files changed

+50
-32
lines changed

sleepeeg/base.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -141,37 +141,56 @@ def plot(
141141
"""
142142
kwargs.setdefault("theme", "dark")
143143
kwargs.setdefault("bad_color", "r")
144-
kwargs.setdefault("scalings", "auto")
145144
kwargs["block"] = True
146145

147146
self.mne_raw.plot(**kwargs)
148147

149148
if save_annotations:
150-
self.mne_raw.annotations.save(
151-
self.output_dir / self.__class__.__name__ / "annotations.txt",
152-
overwrite=overwrite,
153-
)
149+
self.save_annotations(overwrite=overwrite)
154150

155151
if save_bad_channels:
156-
new_bads = self.mne_raw.info["bads"]
157-
158-
if new_bads:
159-
from natsort import natsorted
160-
161-
fpath = self.output_dir / self.__class__.__name__ / "bad_channels.txt"
162-
old_bads = []
163-
if fpath.exists():
164-
with open(fpath, "r") as f:
165-
old_bads = f.read().split()
166-
167-
with open(fpath, "w") as f:
168-
bads = (
169-
natsorted(new_bads)
170-
if overwrite
171-
else natsorted(set(old_bads + new_bads))
172-
)
173-
for bad in bads:
174-
f.write(f"{bad}\n")
152+
self.save_bad_channels(overwrite=overwrite)
153+
154+
def save_bad_channels(self, overwrite=False):
155+
"""Adds bad channels from info["bads"] to the "bad_channels.txt" file.
156+
157+
Args:
158+
overwrite: Whether to overwrite the file if exists.
159+
If False will add unique new channels to the file.
160+
Defaults to False.
161+
"""
162+
new_bads = self.mne_raw.info["bads"]
163+
164+
if new_bads:
165+
from natsort import natsorted
166+
167+
fpath = self.output_dir / self.__class__.__name__ / "bad_channels.txt"
168+
old_bads = []
169+
if fpath.exists():
170+
with open(fpath, "r") as f:
171+
old_bads = f.read().split()
172+
173+
with open(fpath, "w") as f:
174+
bads = (
175+
natsorted(new_bads)
176+
if overwrite
177+
else natsorted(set(old_bads + new_bads))
178+
)
179+
for bad in bads:
180+
f.write(f"{bad}\n")
181+
182+
def save_annotations(self, overwrite=False):
183+
"""Writes annotations to "annotations.txt" file.
184+
185+
Args:
186+
overwrite: Whether to overwrite the file if exists.
187+
If False and the file exists will throw an exception.
188+
Defaults to False.
189+
"""
190+
self.mne_raw.annotations.save(
191+
self.output_dir / self.__class__.__name__ / "annotations.txt",
192+
overwrite=overwrite,
193+
)
175194

176195
def plot_sensors(
177196
self, legend: Iterable[str] = None, legend_args: dict = None, **kwargs
@@ -1143,7 +1162,7 @@ def plot_topomap_collage(
11431162
)
11441163

11451164
@logger_wraps()
1146-
def _save_psds(self, overwrite):
1165+
def save_psds(self, overwrite):
11471166
import re
11481167

11491168
for stage, spectrum in self.psds.items():

sleepeeg/dashboard.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,6 @@ def create_dashboard(
286286
bandpass_filter_freqs[1],
287287
)
288288

289-
pipe.set_eeg_reference(ref_channels=ref_channels)
290-
291289
if path_to_bad_channels is not None:
292290
pipe.read_bad_channels(path=path_to_bad_channels)
293291
bads = pipe.mne_raw.info["bads"]
@@ -297,6 +295,7 @@ def create_dashboard(
297295

298296
bads = literal_eval(pipe.mne_raw.info["description"])
299297

298+
pipe.set_eeg_reference(ref_channels=ref_channels)
300299
s_pipe, sleep_stages = _init_s_pipe(pipe, path_to_hypnogram, hypno_freq)
301300
min_psd, max_psd = _hypno_psd(
302301
s_pipe,

sleepeeg/pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,15 +222,15 @@ def fit(self, filter_kwargs: dict = None, **fit_kwargs):
222222

223223
def plot_sources(self, **kwargs):
224224
"""A wrapper for :py:meth:`mne:mne.preprocessing.ICA.plot_sources`."""
225-
self.mne_ica.plot_sources(inst=self.mne_raw, **kwargs)
225+
return self.mne_ica.plot_sources(inst=self.mne_raw, **kwargs)
226226

227227
def plot_components(self, **kwargs):
228228
"""A wrapper for :py:meth:`mne:mne.preprocessing.ICA.plot_components`."""
229-
self.mne_ica.plot_components(inst=self.mne_raw, **kwargs)
229+
return self.mne_ica.plot_components(inst=self.mne_raw, **kwargs)
230230

231231
def plot_properties(self, picks=None, **kwargs):
232232
"""A wrapper for :py:meth:`mne:mne.preprocessing.ICA.plot_properties`."""
233-
self.mne_ica.plot_properties(self.mne_raw, picks=picks, **kwargs)
233+
return self.mne_ica.plot_properties(self.mne_raw, picks=picks, **kwargs)
234234

235235
@logger_wraps()
236236
def apply(self, exclude=None, **kwargs):
@@ -321,7 +321,7 @@ def compute_psds_per_stage(
321321
**psd_kwargs,
322322
)
323323
if save:
324-
self._save_psds(overwrite)
324+
self.save_psds(overwrite)
325325

326326
@logger_wraps()
327327
def read_spectra(self, dirpath: str | None = None):
@@ -740,4 +740,4 @@ def compute_psds_per_stage(
740740
self.psds[stage] = avg_func(spectra, axis=0)
741741

742742
if save:
743-
self._save_psds(overwrite)
743+
self.save_psds(overwrite)

0 commit comments

Comments
 (0)