diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
index 18a90b6..a07b054 100644
--- a/docs/build/html/.buildinfo
+++ b/docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 7d440e1bf4ab7dd814452f500decfcbf
+config: 7410ac572283522d5aaa00048d56ca4f
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/_modules/index.html b/docs/build/html/_modules/index.html
index 974f3c5..893126a 100644
--- a/docs/build/html/_modules/index.html
+++ b/docs/build/html/_modules/index.html
@@ -3,11 +3,14 @@
 from .spectral import stft_power
 from .numba import _detrend, _rms
 from .io import set_log_level, is_tensorpac_installed, is_pyriemann_installed
-from .others import (moving_transform, trimbothstd, get_centered_indices,
-                     sliding_window, _merge_close, _zerocrossings)
-
-
-logger = logging.getLogger('yasa')
-
-__all__ = ['art_detect', 'spindles_detect', 'SpindlesResults', 'sw_detect', 'SWResults',
-           'rem_detect', 'REMResults']
+from .others import (
+    moving_transform,
+    trimbothstd,
+    get_centered_indices,
+    sliding_window,
+    _merge_close,
+    _zerocrossings,
+)
+
+
+logger = logging.getLogger("yasa")
+
+__all__ = [
+    "art_detect",
+    "spindles_detect",
+    "SpindlesResults",
+    "sw_detect",
+    "SWResults",
+    "rem_detect",
+    "REMResults",
+    "compare_detection",
+]


#############################################################################
# DATA PREPROCESSING
#############################################################################
+
def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None, check_amp=True):
    """Helper functions for preprocessing of data and hypnogram."""
    # 1) Extract data as a 2D NumPy array
    if isinstance(data, mne.io.BaseRaw):
-        sf = data.info['sfreq']  # Extract sampling frequency
+        sf = data.info["sfreq"]  # Extract sampling frequency
        ch_names = data.ch_names  # Extract channel names
-        data = data.get_data() * 1e6  # Convert from V to uV
+        data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
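+        # NOTE: recent versions of MNE perform the V-to-uV conversion internally
+        # via the ``units`` argument, replacing the previous manual ``* 1e6`` scaling.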
    else:
-        assert sf is not None, 'sf must be specified if not using MNE Raw.'
+        assert sf is not None, "sf must be specified if not using MNE Raw."
+        if isinstance(sf, np.ndarray):  # Deal with sf = array(100.) --> 100
+            sf = float(sf)
+        assert isinstance(sf, (int, float)), "sf must be int or float."
    data = np.asarray(data, dtype=np.float64)
-    assert data.ndim in [1, 2], 'data must be 1D (times) or 2D (chan, times).'
+    assert data.ndim in [1, 2], "data must be 1D (times) or 2D (chan, times)."
    if data.ndim == 1:
        # Force to 2D array: (n_chan, n_samples)
        data = data[None, ...]
@@ -145,40 +166,41 @@
    # 2) Check channel names
    if ch_names is None:
-        ch_names = ['CHAN' + str(i).zfill(3) for i in range(n_chan)]
+        ch_names = ["CHAN" + str(i).zfill(3) for i in range(n_chan)]
    else:
        assert len(ch_names) == n_chan

    # 3) Check hypnogram
    if hypno is not None:
        hypno = np.asarray(hypno, dtype=int)
-        assert hypno.ndim == 1, 'Hypno must be one dimensional.'
-        assert hypno.size == n_samples, 'Hypno must have same size as data.'
+        assert hypno.ndim == 1, "Hypno must be one dimensional."
+        assert hypno.size == n_samples, "Hypno must have same size as data."
        unique_hypno = np.unique(hypno)
-        logger.info('Number of unique values in hypno = %i', unique_hypno.size)
-        assert include is not None, 'include cannot be None if hypno is given'
+        logger.info("Number of unique values in hypno = %i", unique_hypno.size)
+        assert include is not None, "include cannot be None if hypno is given"
        include = np.atleast_1d(np.asarray(include))
-        assert include.size >= 1, '`include` must have at least one element.'
-        assert hypno.dtype.kind == include.dtype.kind, ('hypno and include must have same dtype')
-        assert np.in1d(hypno, include).any(), ('None of the stages specified '
-                                               'in `include` are present in '
-                                               'hypno.')
+        assert include.size >= 1, "`include` must have at least one element."
+        assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype"
+        assert np.in1d(hypno, include).any(), (
+            "None of the stages specified " "in `include` are present in " "hypno."
+        )

    # 4) Check data amplitude
-    logger.info('Number of samples in data = %i', n_samples)
-    logger.info('Sampling frequency = %.2f Hz', sf)
-    logger.info('Data duration = %.2f seconds', n_samples / sf)
+    logger.info("Number of samples in data = %i", n_samples)
+    logger.info("Sampling frequency = %.2f Hz", sf)
+    logger.info("Data duration = %.2f seconds", n_samples / sf)
    all_ptp = np.ptp(data, axis=-1)
    all_trimstd = trimbothstd(data, cut=0.05)
    bad_chan = np.zeros(n_chan, dtype=bool)
    for i in range(n_chan):
-        logger.info('Trimmed standard deviation of %s = %.4f uV' % (ch_names[i], all_trimstd[i]))
-        logger.info('Peak-to-peak amplitude of %s = %.4f uV' % (ch_names[i], all_ptp[i]))
-        if check_amp and not (0.1 < all_trimstd[i] < 1e3):
-            logger.error('Wrong data amplitude for %s '
-                         '(trimmed STD = %.3f). Unit of data MUST be uV! '
-                         'Channel will be skipped.'
-                         % (ch_names[i], all_trimstd[i]))
+        logger.info("Trimmed standard deviation of %s = %.4f uV" % (ch_names[i], all_trimstd[i]))
+        logger.info("Peak-to-peak amplitude of %s = %.4f uV" % (ch_names[i], all_ptp[i]))
+        if check_amp and not (0.1 < all_trimstd[i] < 1e3):
+            logger.error(
+                "Wrong data amplitude for %s "
+                "(trimmed STD = %.3f). Unit of data MUST be uV! "
+                "Channel will be skipped." % (ch_names[i], all_trimstd[i])
+            )
            bad_chan[i] = True

    # 5) Create sleep stage vector mask
@@ -218,50 +240,56 @@
            aggdict["DistanceSpindleToSW"] = aggfunc
        else:  # REM
-            aggdict = {'Start': 'count',
-                       'Duration': aggfunc,
-                       'LOCAbsValPeak': aggfunc,
-                       'ROCAbsValPeak': aggfunc,
-                       'LOCAbsRiseSlope': aggfunc,
-                       'ROCAbsRiseSlope': aggfunc,
-                       'LOCAbsFallSlope': aggfunc,
-                       'ROCAbsFallSlope': aggfunc}
+            aggdict = {
+                "Start": "count",
+                "Duration": aggfunc,
+                "LOCAbsValPeak": aggfunc,
+                "ROCAbsValPeak": aggfunc,
+                "LOCAbsRiseSlope": aggfunc,
+                "ROCAbsRiseSlope": aggfunc,
+                "LOCAbsFallSlope": aggfunc,
+                "ROCAbsFallSlope": aggfunc,
+            }

        # Apply grouping, after masking
        df_grp = self._events.loc[mask, :].groupby(grouper, sort=sort, as_index=False).agg(aggdict)
-        df_grp = df_grp.rename(columns={'Start': 'Count'})
+        df_grp = df_grp.rename(columns={"Start": "Count"})

        # Calculate density (= number per min of each stage)
        if self._hypno is not None and grp_stage is True:
-            stages = np.unique(self._events['Stage'])
+            stages = np.unique(self._events["Stage"])
            dur = {}
            for st in stages:
                # Get duration in minutes of each stage present in dataframe
@@ -292,37 +322,42 @@
            # Insert new density column in grouped dataframe after count
            df_grp.insert(
-                loc=df_grp.columns.get_loc('Count') + 1, column='Density',
-                value=df_grp.apply(lambda rw: rw['Count'] / dur[rw['Stage']], axis=1))
+                loc=df_grp.columns.get_loc("Count") + 1,
+                column="Density",
+                value=df_grp.apply(lambda rw: rw["Count"] / dur[rw["Stage"]], axis=1),
+            )

        return df_grp.set_index(grouper)

    def get_mask(self):
        """get_mask"""
        from yasa.others import _index_to_events
+
        mask = np.zeros(self._data.shape, dtype=int)
-        for i in self._events['IdxChannel'].unique():
-            ev_chan = self._events[self._events['IdxChannel'] == i]
-            idx_ev = _index_to_events(
-                ev_chan[['Start', 'End']].to_numpy() * self._sf)
+        for i in self._events["IdxChannel"].unique():
+            ev_chan = self._events[self._events["IdxChannel"] == i]
+            idx_ev = _index_to_events(ev_chan[["Start", "End"]].to_numpy() * self._sf)
            mask[i, idx_ev] = 1
        return np.squeeze(mask)

-    def get_sync_events(self, center, time_before, time_after, filt=(None, None), mask=None,
-                        as_dataframe=True):
+    def get_sync_events(
+        self, center, time_before, time_after, filt=(None, None), mask=None, as_dataframe=True
+    ):
        """Get_sync_events (not for REM, spindles & SW only)"""
        from yasa.others import get_centered_indices
+
        assert time_before >= 0
        assert time_after >= 0
        bef = int(self._sf * time_before)
        aft = int(self._sf * time_after)
        # TODO: Step size is determined by sf: 0.01 sec at 100 Hz, 0.002 sec at
        # 500 Hz, 0.00390625 sec at 256 Hz. Should we add resample=100 (Hz) or step_size=0.01?
-        time = np.arange(-bef, aft + 1, dtype='int') / self._sf
+        time = np.arange(-bef, aft + 1, dtype="int") / self._sf
        if any(filt):
            data = mne.filter.filter_data(
-                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method='fir', verbose=False)
+                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
+            )
        else:
            data = self._data
@@ -332,18 +367,19 @@
        output = []
-        for i in masked_events['IdxChannel'].unique():
+        for i in masked_events["IdxChannel"].unique():
            # Copy is required to merge with the stage later on
-            ev_chan = masked_events[masked_events['IdxChannel'] == i].copy()
-            ev_chan['Event'] = np.arange(ev_chan.shape[0])
+            ev_chan = masked_events[masked_events["IdxChannel"] == i].copy()
+            ev_chan["Event"] = np.arange(ev_chan.shape[0])
            peaks = (ev_chan[center] * self._sf).astype(int).to_numpy()
            # Get centered indices
            idx, idx_valid = get_centered_indices(data[i, :], peaks, bef, aft)
            # If no good epochs are returned raise a warning
            if len(idx_valid) == 0:
                logger.error(
-                    'Time before and/or time after exceed data bounds, please '
-                    'lower the temporal window around center. Skipping channel.')
+                    "Time before and/or time after exceed data bounds, please "
+                    "lower the temporal window around center. Skipping channel."
+                )
                continue

            # Get data at indices and time vector
@@ -356,15 +392,15 @@
            coincidence = (x * y).sum()
            if scaled:
                # Handle division by zero error
-                denom = (x.sum() * y.sum())
+                denom = x.sum() * y.sum()
                if denom == 0:
                    coincidence = np.nan
                else:
@@ -402,31 +438,144 @@
        return coinc_mat

-    def plot_average(self, event_type, center='Peak', hue='Channel', time_before=1,
-                     time_after=1, filt=(None, None), mask=None, figsize=(6, 4.5), **kwargs):
+    def compare_channels(self, score="f1", max_distance_sec=0):
+        """
+        Compare detected events across channels.
+        See full documentation in the methods of SpindlesResults and SWResults.
+        """
+        from itertools import product
+
+        assert score in ["f1", "precision", "recall"], f"Invalid scoring metric: {score}"
+
+        # Extract events and channels
+        detected = self.summary()
+        chan = detected["Channel"].unique()
+
+        # Get indices of start in deciseconds, rounding to the nearest decisecond (100 ms).
+        # This is needed for three reasons:
+        # 1. Speed up the for loop
+        # 2. Avoid memory error in yasa.compare_detection
+        # 3. Make sure that max_distance works even when self and other have different sf.
+        # TODO: Only the Start of the event is currently supported. Add more flexibility?
+        detected["Start"] = (detected["Start"] * 10).round().astype(int)
+        max_distance = int(10 * max_distance_sec)
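+        # e.g. an event starting at 12.437 s becomes 124 ds; with max_distance_sec=0.2,
+        # max_distance = 2 ds, so starts at 124 and 126 ds count as the same event.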
+
+        # Initialize output dataframe / dict
+        scores = pd.DataFrame(index=chan, columns=chan, dtype=float)
+        scores.index.name = "Channel"
+        scores.columns.name = "Channel"
+        pairs = list(product(chan, repeat=2))
+
+        # Loop across pairs of channels
+        for c_index, c_col in pairs:
+            idx_chan1 = detected[detected["Channel"] == c_index]["Start"]
+            idx_chan2 = detected[detected["Channel"] == c_col]["Start"]
+            # DANGER: Note how we invert idx_chan2 and idx_chan1 here. This is because
+            # idx_chan1 (the index of the dataframe) should be the ground-truth.
+            res = compare_detection(idx_chan2, idx_chan1, max_distance)
+            scores.loc[c_index, c_col] = res[score]
+
+        return scores
+
+    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
+        """
+        Compare detected events between two detection methods, or against a ground-truth scoring.
+        See full documentation in the methods of SpindlesResults and SWResults.
+        """
+        detected = self.summary()
+        if isinstance(other, (SpindlesResults, SWResults, REMResults)):
+            groundtruth = other.summary()
+        elif isinstance(other, pd.DataFrame):
+            assert "Start" in other.columns
+            assert "Channel" in other.columns
+            groundtruth = other[["Start", "Channel"]].copy()
+        else:
+            raise ValueError(
+                f"Invalid argument other: {other}. It must be a YASA detection output or a Pandas "
+                f"DataFrame with the columns Start and Channel"
+            )
+
+        # Get indices of start in deciseconds, rounding to the nearest decisecond (100 ms).
+        # This is needed for three reasons:
+        # 1. Speed up the for loop
+        # 2. Avoid memory error in yasa.compare_detection
+        # 3. Make sure that max_distance works even when self and other have different sf.
+        detected["Start"] = (detected["Start"] * 10).round().astype(int)
+        groundtruth["Start"] = (groundtruth["Start"] * 10).round().astype(int)
+        max_distance = int(10 * max_distance_sec)
+
+        # Find channels that are present in both self and other
+        chan_detected = detected["Channel"].unique()
+        chan_groundtruth = groundtruth["Channel"].unique()
+        chan_both = np.intersect1d(chan_detected, chan_groundtruth)  # Sort
+
+        if not len(chan_both):
+            raise ValueError(
+                f"No intersecting channel between self and other:\n"
+                f"{chan_detected}\n{chan_groundtruth}"
+            )
+
+        # The output is a pandas.DataFrame (n_chan, n_metrics).
+        scores = pd.DataFrame(
+            index=chan_both, columns=["precision", "recall", "f1", "n_self", "n_other"], dtype=float
+        )
+        scores.index.name = "Channel"
+
+        # Loop on each channel
+        for c_index in chan_both:
+            idx_detected = detected[detected["Channel"] == c_index]["Start"]
+            idx_groundtruth = groundtruth[groundtruth["Channel"] == c_index]["Start"]
+            if other_is_groundtruth:
+                res = compare_detection(idx_detected, idx_groundtruth, max_distance)
+            else:
+                res = compare_detection(idx_groundtruth, idx_detected, max_distance)
+            scores.loc[c_index, "precision"] = res["precision"]
+            scores.loc[c_index, "recall"] = res["recall"]
+            scores.loc[c_index, "f1"] = res["f1"]
+            scores.loc[c_index, "n_self"] = len(idx_detected)
+            scores.loc[c_index, "n_other"] = len(idx_groundtruth)
+
+        scores["n_self"] = scores["n_self"].astype(int)
+        scores["n_other"] = scores["n_other"].astype(int)
+
+        return scores
+
+    def plot_average(
+        self,
+        event_type,
+        center="Peak",
+        hue="Channel",
+        time_before=1,
+        time_after=1,
+        filt=(None, None),
+        mask=None,
+        figsize=(6, 4.5),
+        **kwargs,
+    ):
        """Plot the average event (not for REM, spindles & SW only)"""
        import seaborn as sns
        import matplotlib.pyplot as plt
-        df_sync = self.get_sync_events(center=center, time_before=time_before,
-                                       time_after=time_after, filt=filt, mask=mask)
+        df_sync = self.get_sync_events(
+            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
+        )
        assert not df_sync.empty, "Could not calculate event-locked data."
-        assert hue in ['Stage', 'Channel'], "hue must be 'Channel' or 'Stage'"
+        assert hue in ["Stage", "Channel"], "hue must be 'Channel' or 'Stage'"
        assert hue in df_sync.columns, "%s is not present in data." % hue
-        if event_type == 'spindles':
+        if event_type == "spindles":
            title = "Average spindle"
        else:  # "sw":
            title = "Average SW"

        # Start figure
        fig, ax = plt.subplots(1, 1, figsize=figsize)
-        sns.lineplot(data=df_sync, x='Time', y='Amplitude', hue=hue, ax=ax, **kwargs)
+        sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue=hue, ax=ax, **kwargs)
        # ax.legend(frameon=False, loc='lower right')
-        ax.set_xlim(df_sync['Time'].min(), df_sync['Time'].max())
+        ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
        ax.set_title(title)
-        ax.set_xlabel('Time (sec)')
-        ax.set_ylabel('Amplitude (uV)')
+        ax.set_xlabel("Time (sec)")
+        ax.set_ylabel("Amplitude (uV)")
        return ax

    def plot_detection(self):
@@ -452,19 +601,15 @@
"""set_log_level(verbose)
- (data,sf,ch_names,hypno,include,mask,n_chan,n_samples,bad_chan
- )=_check_data_hypno(data,sf,ch_names,hypno,include)
+ (data,sf,ch_names,hypno,include,mask,n_chan,n_samples,bad_chan)=_check_data_hypno(
+ data,sf,ch_names,hypno,include
+ )# If all channels are badifsum(bad_chan)==n_chan:
- logger.warning('All channels have bad amplitude. Returning None.')
+ logger.warning("All channels have bad amplitude. Returning None.")returnNone# Check detection thresholds
-    if 'rel_pow' not in thresh.keys():
-        thresh['rel_pow'] = 0.20
-    if 'corr' not in thresh.keys():
-        thresh['corr'] = 0.65
-    if 'rms' not in thresh.keys():
-        thresh['rms'] = 1.5
-    do_rel_pow = thresh['rel_pow'] not in [None, "none", "None"]
-    do_corr = thresh['corr'] not in [None, "none", "None"]
-    do_rms = thresh['rms'] not in [None, "none", "None"]
+    if "rel_pow" not in thresh.keys():
+        thresh["rel_pow"] = 0.20
+    if "corr" not in thresh.keys():
+        thresh["corr"] = 0.65
+    if "rms" not in thresh.keys():
+        thresh["rms"] = 1.5
+    do_rel_pow = thresh["rel_pow"] not in [None, "none", "None"]
+    do_corr = thresh["corr"] not in [None, "none", "None"]
+    do_rms = thresh["rms"] not in [None, "none", "None"]
    n_thresh = sum([do_rel_pow, do_corr, do_rms])
-    assert n_thresh >= 1, 'At least one threshold must be defined.'
+    assert n_thresh >= 1, "At least one threshold must be defined."

    # Filtering
    nfast = next_fast_len(n_samples)
    # 1) Broadband bandpass filter (optional -- careful of lower freq for PAC)
-    data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method='fir', verbose=0)
+    data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method="fir", verbose=0)
    # 2) Sigma bandpass filter
    # The width of the transition band is set to 1.5 Hz on each side,
    # meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located at
    # 11.25 and 15.75 Hz.
    data_sigma = filter_data(
-        data, sf, freq_sp[0], freq_sp[1], l_trans_bandwidth=1.5, h_trans_bandwidth=1.5,
-        method='fir', verbose=0)
+        data,
+        sf,
+        freq_sp[0],
+        freq_sp[1],
+        l_trans_bandwidth=1.5,
+        h_trans_bandwidth=1.5,
+        method="fir",
+        verbose=0,
+    )

    # Hilbert power (to define the instantaneous frequency / power)
    analytic = signal.hilbert(data_sigma, N=nfast)[:, :n_samples]
    inst_phase = np.angle(analytic)
    inst_pow = np.square(np.abs(analytic))
-    inst_freq = (sf / (2 * np.pi) * np.diff(inst_phase, axis=-1))
+    inst_freq = sf / (2 * np.pi) * np.diff(inst_phase, axis=-1)

    # Extract the SO signal for coupling
    # if coupling:
@@ -770,7 +933,8 @@
        # Note that even if the threshold is None we still need to calculate it
        # for the individual spindles parameter (RelPow).
        f, t, Sxx = stft_power(
-            data_broad[i, :], sf, window=2, step=.2, band=freq_broad, interp=False, norm=True)
+            data_broad[i, :], sf, window=2, step=0.2, band=freq_broad, interp=False, norm=True
+        )
        idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
        rel_pow = Sxx[idx_sigma].sum(0)
@@ -778,39 +942,47 @@
        # Note that we could also have used the `interp=True` option in the
        # `stft_power` function, however 2D interpolation is much slower than
        # 1D interpolation.
-        func = interp1d(t, rel_pow, kind='cubic', bounds_error=False, fill_value=0)
+        func = interp1d(t, rel_pow, kind="cubic", bounds_error=False, fill_value=0)
        t = np.arange(n_samples) / sf
        rel_pow = func(t)

        if do_corr:
-            _, mcorr = moving_transform(x=data_sigma[i, :], y=data_broad[i, :], sf=sf, window=.3,
-                                        step=.1, method='corr', interp=True)
+            _, mcorr = moving_transform(
+                x=data_sigma[i, :],
+                y=data_broad[i, :],
+                sf=sf,
+                window=0.3,
+                step=0.1,
+                method="corr",
+                interp=True,
+            )
        if do_rms:
-            _, mrms = moving_transform(x=data_sigma[i, :], sf=sf, window=.3, step=.1, method='rms',
-                                       interp=True)
+            _, mrms = moving_transform(
+                x=data_sigma[i, :], sf=sf, window=0.3, step=0.1, method="rms", interp=True
+            )
            # Let's define the thresholds
            if hypno is None:
-                thresh_rms = mrms.mean() + thresh['rms'] * trimbothstd(mrms, cut=0.10)
+                thresh_rms = mrms.mean() + thresh["rms"] * trimbothstd(mrms, cut=0.10)
            else:
-                thresh_rms = mrms[mask].mean() + thresh['rms'] * trimbothstd(mrms[mask], cut=0.10)
+                thresh_rms = mrms[mask].mean() + thresh["rms"] * trimbothstd(mrms[mask], cut=0.10)
            # Avoid too high threshold caused by Artefacts / Motion during Wake
            thresh_rms = min(thresh_rms, 10)
-            logger.info('Moving RMS threshold = %.3f', thresh_rms)
+            logger.info("Moving RMS threshold = %.3f", thresh_rms)

        # Boolean vector of supra-threshold indices
        idx_sum = np.zeros(n_samples)
        if do_rel_pow:
-            idx_rel_pow = (rel_pow >= thresh['rel_pow']).astype(int)
+            idx_rel_pow = (rel_pow >= thresh["rel_pow"]).astype(int)
            idx_sum += idx_rel_pow
-            logger.info('N supra-threshold relative power = %i', idx_rel_pow.sum())
+            logger.info("N supra-threshold relative power = %i", idx_rel_pow.sum())
        if do_corr:
-            idx_mcorr = (mcorr >= thresh['corr']).astype(int)
+            idx_mcorr = (mcorr >= thresh["corr"]).astype(int)
            idx_sum += idx_mcorr
-            logger.info('N supra-threshold moving corr = %i', idx_mcorr.sum())
+            logger.info("N supra-threshold moving corr = %i", idx_mcorr.sum())
        if do_rms:
            idx_mrms = (mrms >= thresh_rms).astype(int)
            idx_sum += idx_mrms
-            logger.info('N supra-threshold moving RMS = %i', idx_mrms.sum())
+            logger.info("N supra-threshold moving RMS = %i", idx_mrms.sum())

        # Make sure that we do not detect spindles outside mask
        if hypno is not None:
@@ -823,7 +995,7 @@
        # Sampling frequency = 256 Hz --> w = 25 samples = 97 ms
        w = int(0.1 * sf)
        # Critical bugfix March 2022, see https://github.com/raphaelvallat/yasa/pull/55
-        idx_sum = np.convolve(idx_sum, np.ones(w), mode='same') / w
+        idx_sum = np.convolve(idx_sum, np.ones(w), mode="same") / w
        # And we then find indices that are strictly greater than 2, i.e. we
        # find the 'true' beginning and 'true' end of the events by finding
        # where at least two out of the three thresholds were crossed.
@@ -831,7 +1003,7 @@
        # If no events are found, skip to next channel
        if not len(where_sp):
-            logger.warning('No spindle were found in channel %s.', ch_names[i])
+            logger.warning("No spindles were found in channel %s.", ch_names[i])
            continue

        # Merge events that are too close
@@ -849,7 +1021,7 @@
        # If no events of good duration are found, skip to next channel
        if all(~good_dur):
-            logger.warning('No spindle were found in channel %s.', ch_names[i])
+            logger.warning("No spindles were found in channel %s.", ch_names[i])
            continue

        # Initialize empty variables
@@ -886,7 +1058,8 @@
            # Number of oscillations
            peaks, peaks_params = signal.find_peaks(
-                sp_det, distance=distance, prominence=(None, None))
+                sp_det, distance=distance, prominence=(None, None)
+            )
            sp_osc[j] = len(peaks)

            # For frequency and amplitude, we can also optionally use these
@@ -897,7 +1070,7 @@
            # Peak location & symmetry index
            # pk is expressed in samples since the beginning of the spindle
-            pk = peaks[peaks_params['prominences'].argmax()]
+            pk = peaks[peaks_params["prominences"].argmax()]
            sp_pro[j] = sp_start[j] + pk / sf
            sp_sym[j] = pk / sp_det.size
@@ -910,70 +1083,82 @@
            sp_sta[j] = hypno[sp[j]][0]

        # Create a dataframe
-        sp_params = {'Start': sp_start,
-                     'Peak': sp_pro,
-                     'End': sp_end,
-                     'Duration': sp_dur,
-                     'Amplitude': sp_amp,
-                     'RMS': sp_rms,
-                     'AbsPower': sp_abs,
-                     'RelPower': sp_rel,
-                     'Frequency': sp_freq,
-                     'Oscillations': sp_osc,
-                     'Symmetry': sp_sym,
-                     # 'SOPhase': sp_cou,
-                     'Stage': sp_sta}
+        sp_params = {
+            "Start": sp_start,
+            "Peak": sp_pro,
+            "End": sp_end,
+            "Duration": sp_dur,
+            "Amplitude": sp_amp,
+            "RMS": sp_rms,
+            "AbsPower": sp_abs,
+            "RelPower": sp_rel,
+            "Frequency": sp_freq,
+            "Oscillations": sp_osc,
+            "Symmetry": sp_sym,
+            # 'SOPhase': sp_cou,
+            "Stage": sp_sta,
+        }
        df_chan = pd.DataFrame(sp_params)[good_dur]

        # We need at least 50 detected spindles to apply the Isolation Forest.
        if remove_outliers and df_chan.shape[0] >= 50:
-            col_keep = ['Duration', 'Amplitude', 'RMS', 'AbsPower', 'RelPower',
-                        'Frequency', 'Oscillations', 'Symmetry']
+            col_keep = [
+                "Duration",
+                "Amplitude",
+                "RMS",
+                "AbsPower",
+                "RelPower",
+                "Frequency",
+                "Oscillations",
+                "Symmetry",
+            ]
            ilf = IsolationForest(
-                contamination='auto', max_samples='auto', verbose=0, random_state=42)
+                contamination="auto", max_samples="auto", verbose=0, random_state=42
+            )
            good = ilf.fit_predict(df_chan[col_keep])
            good[good == -1] = 0
-            logger.info('%i outliers were removed in channel %s.'
-                        % ((good == 0).sum(), ch_names[i]))
+            logger.info(
+                "%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
+            )
            # Remove outliers from DataFrame
            df_chan = df_chan[good.astype(bool)]

-        logger.info('%i spindles were found in channel %s.'
-                    % (df_chan.shape[0], ch_names[i]))
+        logger.info("%i spindles were found in channel %s." % (df_chan.shape[0], ch_names[i]))

        # ####################################################################
        # END SINGLE CHANNEL DETECTION
        # ####################################################################
-        df_chan['Channel'] = ch_names[i]
-        df_chan['IdxChannel'] = i
+        df_chan["Channel"] = ch_names[i]
+        df_chan["IdxChannel"] = i
        df = pd.concat([df, df_chan], axis=0, ignore_index=True)

    # If no spindles were detected, return None
    if df.empty:
-        logger.warning('No spindles were found in data. Returning None.')
+        logger.warning("No spindles were found in data. Returning None.")
        return None

    # Remove useless columns
    to_drop = []
    if hypno is None:
-        to_drop.append('Stage')
+        to_drop.append("Stage")
    else:
-        df['Stage'] = df['Stage'].astype(int)
+        df["Stage"] = df["Stage"].astype(int)
    # if not coupling:
    #     to_drop.append('SOPhase')
    if len(to_drop):
        df = df.drop(columns=to_drop)

    # Find spindles that are present on at least two channels
-    if multi_only and df['Channel'].nunique() > 1:
+    if multi_only and df["Channel"].nunique() > 1:
        # We round to the nearest second
        idx_good = np.logical_or(
-            df['Start'].round(0).duplicated(keep=False),
-            df['End'].round(0).duplicated(keep=False)).to_list()
+            df["Start"].round(0).duplicated(keep=False), df["End"].round(0).duplicated(keep=False)
+        ).to_list()
        df = df[idx_good].reset_index(drop=True)

-    return SpindlesResults(events=df, data=data, sf=sf, ch_names=ch_names,
-                           hypno=hypno, data_filt=data_sigma)
+    return SpindlesResults(
+        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_sigma
+    )

    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
        """Return a summary of the spindles detection, optionally grouped across channels and/or
        stage.
@@ -1018,8 +1203,14 @@
        sort : bool
            If True, sort group keys when grouping.
        """
-        return super().summary(event_type='spindles', grp_chan=grp_chan, grp_stage=grp_stage,
-                               aggfunc=aggfunc, sort=sort, mask=mask)
+        return super().summary(
+            event_type="spindles",
+            grp_chan=grp_chan,
+            grp_stage=grp_stage,
+            aggfunc=aggfunc,
+            sort=sort,
+            mask=mask,
+        )
+    def compare_channels(self, score="f1", max_distance_sec=0):
+        """
+        Compare detected spindles across channels.
+
+        This is a wrapper around the :py:func:`yasa.compare_detection` function. Please
+        refer to the documentation of this function for more details.
+
+        Parameters
+        ----------
+        score : str
+            The performance metric to compute. Accepted values are "precision", "recall"
+            (aka sensitivity) and "f1" (default). The F1-score is the harmonic mean of
+            precision and recall, and is usually the preferred metric to evaluate the
+            agreement between two channels. All three metrics are bounded by 0 and 1,
+            where 1 indicates perfect agreement.
+        max_distance_sec : float
+            The maximum distance between two spindles (in seconds) for them to be
+            considered the same event.
+
+            .. warning:: To reduce computation cost, YASA rounds the start time of each
+                spindle to the nearest decisecond (= 100 ms). This means that the lowest
+                possible resolution is 100 ms, regardless of the sampling frequency of the
+                data. Two spindles starting at 500 ms and 540 ms on their respective
+                channels will therefore always be considered the same event, even when
+                ``max_distance_sec=0``.
+
+        Returns
+        -------
+        scores : :py:class:`pandas.DataFrame`
+            A Pandas DataFrame with the output scores, of shape (n_chan, n_chan).
+
+        Notes
+        -----
+        Some use cases of this function:
+
+        1. What proportion of spindles detected in one channel are also detected on
+           another channel (if using ``score="recall"``)?
+        2. What is the overall agreement in the detected events between channels?
+        3. Is the agreement better in channels that are close to one another?
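+
+        Examples
+        --------
+        A minimal, hypothetical sketch (``raw`` is assumed to be a multi-channel
+        :py:class:`mne.io.BaseRaw` with sleep EEG data):
+
+        >>> sp = yasa.spindles_detect(raw)  # doctest: +SKIP
+        >>> # Pairwise F1-score agreement between channels, with a 200-ms tolerance
+        >>> sp.compare_channels(score="f1", max_distance_sec=0.2)  # doctest: +SKIP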
+ """
+ returnsuper().compare_channels(score,max_distance_sec)
+
+
+    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
+        """
+        Compare the detected spindles against either another YASA detection or against custom
+        annotations (e.g. ground-truth human scoring).
+
+        This function is a wrapper around the :py:func:`yasa.compare_detection` function. Please
+        refer to the documentation of this function for more details.
+
+        Parameters
+        ----------
+        other : dataframe or detection results
+            This can be either a) the output of another YASA detection, for example if you want
+            to test the impact of tweaking some parameters on the detected events, or b) a pandas
+            DataFrame with custom annotations, obtained with another detection method outside of
+            YASA, or with manual labelling. If b), the dataframe must contain the "Start" and
+            "Channel" columns, with the start of each event in seconds from the beginning of the
+            recording and the channel name, respectively. The channel names should match the
+            output of the summary() method.
+        max_distance_sec : float
+            The maximum distance between two spindles (in seconds) for them to be considered the
+            same event.
+
+            .. warning:: To reduce computation cost, YASA rounds the start time of each spindle
+                to the nearest decisecond (= 100 ms). This means that the lowest possible
+                resolution is 100 ms, regardless of the sampling frequency of the data.
+        other_is_groundtruth : bool
+            If True (default), ``other`` will be considered as the ground-truth scoring. If
+            False, the current detection will be considered as the ground-truth, and the
+            precision and recall scores will be inverted. This parameter has no effect on the
+            F1-score.
+
+            .. note:: When ``other`` is the ground-truth (default), the recall score is the
+                fraction of events in ``other`` that were successfully detected by the current
+                detection, and the precision score is the proportion of events detected by the
+                current detection that are also present in ``other``.
+
+        Returns
+        -------
+        scores : :py:class:`pandas.DataFrame`
+            A Pandas DataFrame with the channel names as index, and the following columns:
+
+            * ``precision``: Precision score, aka positive predictive value
+            * ``recall``: Recall score, aka sensitivity
+            * ``f1``: F1-score
+            * ``n_self``: Number of detected events in ``self`` (current method).
+            * ``n_other``: Number of detected events in ``other``.
+
+        Notes
+        -----
+        Some use cases of this function:
+
+        1. How well does the YASA detection perform against ground-truth human annotations?
+        2. If I change the threshold(s) of the detection, do the detected events still match
+           those obtained with the default parameters?
+        3. Which detection thresholds give the highest agreement with the ground-truth scoring?
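+
+        Examples
+        --------
+        A minimal, hypothetical sketch comparing two runs of the detection with different
+        thresholds (``raw`` is an assumed :py:class:`mne.io.BaseRaw`):
+
+        >>> sp = yasa.spindles_detect(raw)  # doctest: +SKIP
+        >>> sp_strict = yasa.spindles_detect(raw, thresh={"corr": 0.80})  # doctest: +SKIP
+        >>> # Precision/recall/F1 per channel, using ``sp`` as the ground-truth
+        >>> sp_strict.compare_detection(sp, max_distance_sec=0.1)  # doctest: +SKIP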
+ """
+ returnsuper().compare_detection(other,max_distance_sec,other_is_groundtruth)
+
    def get_mask(self):
-        """Return a boolean array indicating for each sample in data if this
+        """
+        Return a boolean array indicating for each sample in data if this
        sample is part of a detected event (True) or not (False).
        """
        return super().get_mask()
    def get_sync_events(
+        self,
+        center="Peak",
+        time_before=1,
+        time_after=1,
+        filt=(None, None),
+        mask=None,
+        as_dataframe=True,
+    ):
        """
        Return the raw or filtered data of each detected event after
        centering to a specific timepoint.
@@ -1116,12 +1411,26 @@
        'IdxChannel' : Index of channel in data
        'Stage': Sleep stage in which the events occurred (if available)
        """
-        return super().get_sync_events(center=center, time_before=time_before,
-                                       time_after=time_after, filt=filt, mask=mask,
-                                       as_dataframe=as_dataframe)
+        return super().get_sync_events(
+            center=center,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            as_dataframe=as_dataframe,
+        )
"""set_log_level(verbose)
- (data,sf,ch_names,hypno,include,mask,n_chan,n_samples,bad_chan
- )=_check_data_hypno(data,sf,ch_names,hypno,include)
+ (data,sf,ch_names,hypno,include,mask,n_chan,n_samples,bad_chan)=_check_data_hypno(
+ data,sf,ch_names,hypno,include
+ )# If all channels are badifsum(bad_chan)==n_chan:
- logger.warning('All channels have bad amplitude. Returning None.')
+ logger.warning("All channels have bad amplitude. Returning None.")returnNone# Define time vector
@@ -1402,13 +1731,21 @@
    # Bandpass filter
    nfast = next_fast_len(n_samples)
    data_filt = filter_data(
-        data, sf, freq_sw[0], freq_sw[1], method='fir', verbose=0, l_trans_bandwidth=0.2,
-        h_trans_bandwidth=0.2)
+        data,
+        sf,
+        freq_sw[0],
+        freq_sw[1],
+        method="fir",
+        verbose=0,
+        l_trans_bandwidth=0.2,
+        h_trans_bandwidth=0.2,
+    )

    # Extract the spindles-related sigma signal for coupling
    if coupling:
        is_tensorpac_installed()
        import tensorpac.methods as tpm
+
        # The width of the transition band is set to 1.5 Hz on each side,
        # meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located
        # at 11.25 and 15.75 Hz. The frequency band for the amplitude signal
@@ -1419,10 +1756,17 @@
        # If no peaks are detected, return None
        if len(idx_neg_peaks) == 0 or len(idx_pos_peaks) == 0:
-            logger.warning('No SW were found in channel %s.', ch_names[i])
+            logger.warning("No SW were found in channel %s.", ch_names[i])
            continue

        # Make sure that the last detected peak is a positive one
@@ -1464,12 +1808,12 @@
        idx_pos_peaks = idx_neg_peaks + closest_pos_peaks

        # Now we compute the PTP amplitude and keep only the good peaks
-        sw_ptp = (np.abs(data_filt[i, idx_neg_peaks]) + data_filt[i, idx_pos_peaks])
+        sw_ptp = np.abs(data_filt[i, idx_neg_peaks]) + data_filt[i, idx_pos_peaks]
        good_ptp = np.logical_and(sw_ptp > amp_ptp[0], sw_ptp < amp_ptp[1])

        # If good_ptp is all False
        if all(~good_ptp):
-            logger.warning('No SW were found in channel %s.', ch_names[i])
+            logger.warning("No SW were found in channel %s.", ch_names[i])
            continue

        sw_ptp = sw_ptp[good_ptp]
@@ -1518,28 +1862,30 @@
        sw_sta = np.zeros(sw_dur.shape)

        # And we apply a set of thresholds to remove bad slow waves
-        good_sw = np.logical_and.reduce((
-            # Data edges
-            previous_neg_zc != 0,
-            following_neg_zc != 0,
-            previous_pos_zc != 0,
-            following_pos_zc != 0,
-            # Duration criteria
-            sw_dur == sw_dur_both_phase,  # dur = negative + positive
-            sw_dur <= dur_neg[1] + dur_pos[1],  # dur < max(neg) + max(pos)
-            sw_dur >= dur_neg[0] + dur_pos[0],  # dur > min(neg) + min(pos)
-            neg_phase_dur > dur_neg[0],
-            neg_phase_dur < dur_neg[1],
-            pos_phase_dur > dur_pos[0],
-            pos_phase_dur < dur_pos[1],
-            # Sanity checks
-            sw_midcrossing > sw_start,
-            sw_midcrossing < sw_end,
-            sw_slope > 0,
-        ))
+        good_sw = np.logical_and.reduce(
+            (
+                # Data edges
+                previous_neg_zc != 0,
+                following_neg_zc != 0,
+                previous_pos_zc != 0,
+                following_pos_zc != 0,
+                # Duration criteria
+                sw_dur == sw_dur_both_phase,  # dur = negative + positive
+                sw_dur <= dur_neg[1] + dur_pos[1],  # dur < max(neg) + max(pos)
+                sw_dur >= dur_neg[0] + dur_pos[0],  # dur > min(neg) + min(pos)
+                neg_phase_dur > dur_neg[0],
+                neg_phase_dur < dur_neg[1],
+                pos_phase_dur > dur_pos[0],
+                pos_phase_dur < dur_pos[1],
+                # Sanity checks
+                sw_midcrossing > sw_start,
+                sw_midcrossing < sw_end,
+                sw_slope > 0,
+            )
+        )

        if all(~good_sw):
-            logger.warning('No SW were found in channel %s.', ch_names[i])
+            logger.warning("No SW were found in channel %s.", ch_names[i])
            continue

        # Filter good events
@@ -1556,28 +1902,33 @@
        sw_sta = sw_sta[good_sw]

        # Create a dictionary
-        sw_params = OrderedDict({
-            'Start': sw_start,
-            'NegPeak': sw_idx_neg,
-            'MidCrossing': sw_midcrossing,
-            'PosPeak': sw_idx_pos,
-            'End': sw_end,
-            'Duration': sw_dur,
-            'ValNegPeak': data_filt[i, idx_neg_peaks],
-            'ValPosPeak': data_filt[i, idx_pos_peaks],
-            'PTP': sw_ptp,
-            'Slope': sw_slope,
-            'Frequency': 1 / sw_dur,
-            'Stage': sw_sta,
-        })
+        sw_params = OrderedDict(
+            {
+                "Start": sw_start,
+                "NegPeak": sw_idx_neg,
+                "MidCrossing": sw_midcrossing,
+                "PosPeak": sw_idx_pos,
+                "End": sw_end,
+                "Duration": sw_dur,
+                "ValNegPeak": data_filt[i, idx_neg_peaks],
+                "ValPosPeak": data_filt[i, idx_pos_peaks],
+                "PTP": sw_ptp,
+                "Slope": sw_slope,
+                "Frequency": 1 / sw_dur,
+                "Stage": sw_sta,
+            }
+        )

        # Add the phase (in radians) of the slow-oscillation signal at maximum
        # spindles-related sigma amplitude within XX-second centered epochs.
        if coupling:
            # Get phase and amplitude for each centered epoch
-            time_before = time_after = coupling_params['time']
-            assert float(sf * time_before).is_integer(), (
-                "Invalid time parameter for coupling. Must be a whole number of samples.")
+            time_before = time_after = coupling_params["time"]
+            assert float(
+                sf * time_before
+            ).is_integer(), (
+                "Invalid time parameter for coupling. Must be a whole number of samples."
+            )
            bef = int(sf * time_before)
            aft = int(sf * time_after)
            # Center of each epoch is defined as the negative peak of the SW
@@ -1591,69 +1942,74 @@
            # Now we need to append it back to the original unmasked shape
            # to avoid an error when idx.shape[0] != idx_valid.shape, i.e.
            # some epochs were out of data bounds.
-            sw_params['SigmaPeak'] = np.ones(n_peaks) * np.nan
+            sw_params["SigmaPeak"] = np.ones(n_peaks) * np.nan
            # Timestamp at sigma peak, expressed in seconds from negative peak
            # e.g. -0.39, 0.5, 1, 2 -- limits are [time_before, time_after]
            time_sigpk = (idx_max_amp - bef) / sf
            # convert to absolute time from beginning of the recording
            # time_sigpk only includes valid epochs
            time_sigpk_abs = sw_idx_neg[idx_valid] + time_sigpk
-            sw_params['SigmaPeak'][idx_valid] = time_sigpk_abs
+            sw_params["SigmaPeak"][idx_valid] = time_sigpk_abs

            # 2) PhaseAtSigmaPeak
            # Find SW phase at max sigma amplitude in epoch
            pha_at_max = np.squeeze(np.take_along_axis(sw_pha_ev, idx_max_amp[..., None], axis=1))
-            sw_params['PhaseAtSigmaPeak'] = np.ones(n_peaks) * np.nan
-            sw_params['PhaseAtSigmaPeak'][idx_valid] = pha_at_max
+            sw_params["PhaseAtSigmaPeak"] = np.ones(n_peaks) * np.nan
+            sw_params["PhaseAtSigmaPeak"][idx_valid] = pha_at_max

            # 3) Normalized Direct PAC, with thresholding
            # Unreliable values are set to 0
-            ndp = np.squeeze(tpm.norm_direct_pac(
-                sw_pha_ev[None, ...], sp_amp_ev[None, ...], p=coupling_params['p']))
-            sw_params['ndPAC'] = np.ones(n_peaks) * np.nan
-            sw_params['ndPAC'][idx_valid] = ndp
+            ndp = np.squeeze(
+                tpm.norm_direct_pac(
+                    sw_pha_ev[None, ...], sp_amp_ev[None, ...], p=coupling_params["p"]
+                )
+            )
+            sw_params["ndPAC"] = np.ones(n_peaks) * np.nan
+            sw_params["ndPAC"][idx_valid] = ndp

            # Make sure that Stage is the last column of the dataframe
-            sw_params.move_to_end('Stage')
+            sw_params.move_to_end("Stage")

        # Convert to dataframe, keeping only good events
        df_chan = pd.DataFrame(sw_params)

        # Remove all duplicates
-        df_chan = df_chan.drop_duplicates(subset=['Start'], keep=False)
-        df_chan = df_chan.drop_duplicates(subset=['End'], keep=False)
+        df_chan = df_chan.drop_duplicates(subset=["Start"], keep=False)
+        df_chan = df_chan.drop_duplicates(subset=["End"], keep=False)

        # We need at least 50 detected slow waves to apply the Isolation Forest
        if remove_outliers and df_chan.shape[0] >= 50:
-            col_keep = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Slope', 'Frequency']
-            ilf = IsolationForest(contamination='auto', max_samples='auto',
-                                  verbose=0, random_state=42)
+            col_keep = ["Duration", "ValNegPeak", "ValPosPeak", "PTP", "Slope", "Frequency"]
+            ilf = IsolationForest(
+                contamination="auto", max_samples="auto", verbose=0, random_state=42
+            )
            good = ilf.fit_predict(df_chan[col_keep])
            good[good == -1] = 0
-            logger.info('%i outliers were removed in channel %s.'
-                        % ((good == 0).sum(), ch_names[i]))
+            logger.info(
+                "%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
+            )
            # Remove outliers from DataFrame
            df_chan = df_chan[good.astype(bool)]

-        logger.info('%i slow-waves were found in channel %s.'
-                    % (df_chan.shape[0], ch_names[i]))
+        logger.info("%i slow-waves were found in channel %s." % (df_chan.shape[0], ch_names[i]))

        # ####################################################################
        # END SINGLE CHANNEL DETECTION
        # ####################################################################
-        df_chan['Channel'] = ch_names[i]
-        df_chan['IdxChannel'] = i
+        df_chan["Channel"] = ch_names[i]
+        df_chan["IdxChannel"] = i
        df = pd.concat([df, df_chan], axis=0, ignore_index=True)

    # If no SW were detected, return None
    if df.empty:
-        logger.warning('No SW were found in data. Returning None.')
+        logger.warning("No SW were found in data. Returning None.")
        return None

    if hypno is None:
-        df = df.drop(columns=['Stage'])
+        df = df.drop(columns=["Stage"])
    else:
-        df['Stage'] = df['Stage'].astype(int)
+        df["Stage"] = df["Stage"].astype(int)

-    return SWResults(events=df, data=data, sf=sf, ch_names=ch_names,
-                     hypno=hypno, data_filt=data_filt)
+    return SWResults(
+        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt
+    )
    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
        """Return a summary of the SW detection, optionally grouped across channels and/or
        stage.
@@ -1696,8 +2052,14 @@
        sort : bool
            If True, sort group keys when grouping.
        """
-        return super().summary(event_type='sw', grp_chan=grp_chan, grp_stage=grp_stage,
-                               aggfunc=aggfunc, sort=sort, mask=mask)
+        return super().summary(
+            event_type="sw",
+            grp_chan=grp_chan,
+            grp_stage=grp_stage,
+            aggfunc=aggfunc,
+            sort=sort,
+            mask=mask,
+        )
    def find_cooccurring_spindles(self, spindles, lookaround=1.2):
        """Given a spindles detection summary dataframe, find slow-waves that co-occur with
@@ -1753,13 +2115,13 @@
        cooccurring_spindle_peaks = []

        # Find intersecting channels
-        common_ch = np.intersect1d(self._events['Channel'].unique(), spindles['Channel'].unique())
+        common_ch = np.intersect1d(self._events["Channel"].unique(), spindles["Channel"].unique())
        assert len(common_ch), "No common channel(s) were found."

        # Loop across channels
-        for chan in self._events['Channel'].unique():
+        for chan in self._events["Channel"].unique():
            sw_chan_peaks = self._events[self._events["Channel"] == chan]["NegPeak"].to_numpy()
-            sp_chan_peaks = spindles[spindles["Channel"] == chan]['Peak'].to_numpy()
+            sp_chan_peaks = spindles[spindles["Channel"] == chan]["Peak"].to_numpy()
            # Loop across individual slow-waves
            for sw_negpeak in sw_chan_peaks:
                start = sw_negpeak - lookaround
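                # e.g. with the default lookaround=1.2 s, the search window opens
                # 1.2 seconds before each slow-wave negative peak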
@@ -1777,7 +2139,103 @@
        # Add columns to self._events: IN-PLACE MODIFICATION!
        self._events["CooccurringSpindle"] = ~np.isnan(distance_sp_to_sw_peak)
        self._events["CooccurringSpindlePeak"] = cooccurring_spindle_peaks
-        self._events['DistanceSpindleToSW'] = distance_sp_to_sw_peak
+        self._events["DistanceSpindleToSW"] = distance_sp_to_sw_peak
+    def compare_channels(self, score="f1", max_distance_sec=0):
+        """
+        Compare detected slow-waves across channels.
+
+        This is a wrapper around the :py:func:`yasa.compare_detection` function. Please
+        refer to the documentation of this function for more details.
+
+        Parameters
+        ----------
+        score : str
+            The performance metric to compute. Accepted values are "precision", "recall"
+            (aka sensitivity) and "f1" (default). The F1-score is the harmonic mean of
+            precision and recall, and is usually the preferred metric to evaluate the
+            agreement between two channels. All three metrics are bounded by 0 and 1,
+            where 1 indicates perfect agreement.
+        max_distance_sec : float
+            The maximum distance between two slow-waves (in seconds) for them to be
+            considered the same event.
+
+            .. warning:: To reduce computation cost, YASA rounds the start time of each
+                slow-wave to the nearest decisecond (= 100 ms). This means that the lowest
+                possible resolution is 100 ms, regardless of the sampling frequency of the
+                data. Two slow-waves starting at 500 ms and 540 ms on their respective
+                channels will therefore always be considered the same event, even when
+                ``max_distance_sec=0``.
+
+        Returns
+        -------
+        scores : :py:class:`pandas.DataFrame`
+            A Pandas DataFrame with the output scores, of shape (n_chan, n_chan).
+
+        Notes
+        -----
+        Some use cases of this function:
+
+        1. What proportion of slow-waves detected in one channel are also detected on
+           another channel (if using ``score="recall"``)?
+        2. What is the overall agreement in the detected events between channels?
+        3. Is the agreement better in channels that are close to one another?
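+
+        Examples
+        --------
+        A minimal, hypothetical sketch (``raw`` is assumed to be a multi-channel
+        :py:class:`mne.io.BaseRaw` with sleep EEG data):
+
+        >>> sw = yasa.sw_detect(raw)  # doctest: +SKIP
+        >>> # Pairwise recall: proportion of slow-waves also found on the other channel
+        >>> sw.compare_channels(score="recall", max_distance_sec=0.1)  # doctest: +SKIP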
+ """
+ returnsuper().compare_channels(score,max_distance_sec)
+
+
+    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
+        """
+        Compare the detected slow-waves against either another YASA detection or against custom
+        annotations (e.g. ground-truth human scoring).
+
+        This function is a wrapper around the :py:func:`yasa.compare_detection` function. Please
+        refer to the documentation of this function for more details.
+
+        Parameters
+        ----------
+        other : dataframe or detection results
+            This can be either a) the output of another YASA detection, for example if you want
+            to test the impact of tweaking some parameters on the detected events, or b) a pandas
+            DataFrame with custom annotations, obtained with another detection method outside of
+            YASA, or with manual labelling. If b), the dataframe must contain the "Start" and
+            "Channel" columns, with the start of each event in seconds from the beginning of the
+            recording and the channel name, respectively. The channel names should match the
+            output of the summary() method.
+        max_distance_sec : float
+            The maximum distance between two slow-waves (in seconds) for them to be considered
+            the same event.
+
+            .. warning:: To reduce computation cost, YASA rounds the start time of each
+                slow-wave to the nearest decisecond (= 100 ms). This means that the lowest
+                possible resolution is 100 ms, regardless of the sampling frequency of the data.
+        other_is_groundtruth : bool
+            If True (default), ``other`` will be considered as the ground-truth scoring. If
+            False, the current detection will be considered as the ground-truth, and the
+            precision and recall scores will be inverted. This parameter has no effect on the
+            F1-score.
+
+            .. note:: When ``other`` is the ground-truth (default), the recall score is the
+                fraction of events in ``other`` that were successfully detected by the current
+                detection, and the precision score is the proportion of events detected by the
+                current detection that are also present in ``other``.
+
+        Returns
+        -------
+        scores : :py:class:`pandas.DataFrame`
+            A Pandas DataFrame with the channel names as index, and the following columns:
+
+            * ``precision``: Precision score, aka positive predictive value
+            * ``recall``: Recall score, aka sensitivity
+            * ``f1``: F1-score
+            * ``n_self``: Number of detected events in ``self`` (current method).
+            * ``n_other``: Number of detected events in ``other``.
+
+        Notes
+        -----
+        Some use cases of this function:
+
+        1. How well does the YASA detection perform against ground-truth human annotations?
+        2. If I change the threshold(s) of the detection, do the detected events still match
+           those obtained with the default parameters?
+        3. Which detection thresholds give the highest agreement with the ground-truth scoring?
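+
+        Examples
+        --------
+        A minimal, hypothetical sketch (``scorer_df`` is an assumed DataFrame of
+        human-scored events with "Start" and "Channel" columns):
+
+        >>> sw = yasa.sw_detect(raw)  # doctest: +SKIP
+        >>> # Per-channel precision/recall/F1 against the human scoring, 300-ms tolerance
+        >>> sw.compare_detection(scorer_df, max_distance_sec=0.3)  # doctest: +SKIP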
+ """
+ returnsuper().compare_detection(other,max_distance_sec,other_is_groundtruth)
    def get_coincidence_matrix(self, scaled=True):
        """Return the (scaled) coincidence matrix.
@@ -1834,8 +2292,15 @@
    def get_sync_events(
+        self,
+        center="NegPeak",
+        time_before=0.4,
+        time_after=0.8,
+        filt=(None, None),
+        mask=None,
+        as_dataframe=True,
+    ):
        """
        Return the raw data of each detected event after
        centering to a specific timepoint.
@@ -1874,11 +2339,25 @@
        'Stage': Sleep stage in which the events occurred (if available)
        """
        return super().get_sync_events(
-            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask,
-            as_dataframe=as_dataframe)
+            center=center,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            as_dataframe=as_dataframe,
+        )

            Optional arguments that are passed to :py:func:`seaborn.lineplot`.
        """
        return super().plot_average(
-            event_type='sw', center=center, hue=hue, time_before=time_before,
-            time_after=time_after, filt=filt, mask=mask, figsize=figsize, **kwargs)
+            event_type="sw",
+            center=center,
+            hue=hue,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            figsize=figsize,
+            **kwargs,
+        )
def rem_detect(
+    loc,
+    roc,
+    sf,
+    hypno=None,
+    include=4,
+    amplitude=(50, 325),
+    duration=(0.3, 1.2),
+    freq_rem=(0.5, 5),
+    remove_outliers=False,
+    verbose=False,
+):
    """Rapid eye movements (REMs) detection.

    This detection requires both the left EOG (LOC) and right EOG (ROC).
@@ -1951,10 +2448,9 @@
    .. warning::
        The default unit of :py:class:`mne.io.BaseRaw` is Volts.
        Therefore, if passing data from a :py:class:`mne.io.BaseRaw`,
-        you need to multiply the data by 1e6 to convert to micro-Volts
-        (1 V = 1,000,000 uV), e.g.:
+        make sure to use ``units="uV"`` to get the data in micro-Volts, e.g.:

-        >>> data = raw.get_data() * 1e6  # Make sure that data is in uV
+        >>> data = raw.get_data(units="uV")  # Make sure that data is in uV
    sf : float
        Sampling frequency of the data, in Hz.
    hypno : array_like
@@ -2063,18 +2559,18 @@
    # Safety checks
    loc = np.squeeze(np.asarray(loc, dtype=np.float64))
    roc = np.squeeze(np.asarray(roc, dtype=np.float64))
-    assert loc.ndim == 1, 'LOC must be 1D.'
-    assert roc.ndim == 1, 'ROC must be 1D.'
-    assert loc.size == roc.size, 'LOC and ROC must have the same size.'
+    assert loc.ndim == 1, "LOC must be 1D."
+    assert roc.ndim == 1, "ROC must be 1D."
+    assert loc.size == roc.size, "LOC and ROC must have the same size."
    data = np.vstack((loc, roc))

-    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
-     ) = _check_data_hypno(data, sf, ['LOC', 'ROC'], hypno, include)
+    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
+        data, sf, ["LOC", "ROC"], hypno, include
+    )

    # If any channel is bad
    if any(bad_chan):
-        logger.warning('At least one channel has bad amplitude. '
-                       'Returning None.')
+        logger.warning("At least one channel has bad amplitude. " "Returning None.")
        return None

    # Bandpass filter
@@ -2088,9 +2584,14 @@
    # - distance: required distance in samples between neighboring peaks.
    # - prominence: required prominence of peaks.
    # - wlen: limit search for bases to a specific window.
-    hmin, hmax = amplitude[0]**2, amplitude[1]**2
-    pks, pks_params = signal.find_peaks(negp, height=(hmin, hmax), distance=(duration[0] * sf),
-                                        prominence=(0.8 * hmin), wlen=(duration[1] * sf))
+    hmin, hmax = amplitude[0] ** 2, amplitude[1] ** 2
+    pks, pks_params = signal.find_peaks(
+        negp,
+        height=(hmin, hmax),
+        distance=(duration[0] * sf),
+        prominence=(0.8 * hmin),
+        wlen=(duration[1] * sf),
+    )

    # Intersect with sleep stage vector
    # We do that before calculating the features in order to gain some time
@@ -2101,78 +2602,98 @@
    # If no peaks are detected, return None
    if len(pks) == 0:
-        logger.warning('No REMs were found in data. Returning None.')
+        logger.warning("No REMs were found in data. Returning None.")
        return None

    # Hypnogram
    if hypno is not None:
        # The sleep stage at the beginning of the REM is considered.
-        rem_sta = hypno[pks_params['left_bases']]
+        rem_sta = hypno[pks_params["left_bases"]]
    else:
        rem_sta = np.zeros(pks.shape)

    # Calculate time features
-    pks_params['Start'] = pks_params['left_bases'] / sf
-    pks_params['Peak'] = pks / sf
-    pks_params['End'] = pks_params['right_bases'] / sf
-    pks_params['Duration'] = pks_params['End'] - pks_params['Start']
+    pks_params["Start"] = pks_params["left_bases"] / sf
+    pks_params["Peak"] = pks / sf
+    pks_params["End"] = pks_params["right_bases"] / sf
+    pks_params["Duration"] = pks_params["End"] - pks_params["Start"]
    # Time points in minutes (HH:MM:SS)
    # pks_params['StartMin'] = pd.to_timedelta(pks_params['Start'], unit='s').dt.round('s')  # noqa
    # pks_params['PeakMin'] = pd.to_timedelta(pks_params['Peak'], unit='s').dt.round('s')  # noqa
    # pks_params['EndMin'] = pd.to_timedelta(pks_params['End'], unit='s').dt.round('s')  # noqa
    # Absolute LOC / ROC value at peak (filtered)
-    pks_params['LOCAbsValPeak'] = abs(data_filt[0, pks])
-    pks_params['ROCAbsValPeak'] = abs(data_filt[1, pks])
+    pks_params["LOCAbsValPeak"] = abs(data_filt[0, pks])
+    pks_params["ROCAbsValPeak"] = abs(data_filt[1, pks])
    # Absolute rising and falling slope
-    dist_pk_left = (pks - pks_params['left_bases']) / sf
-    dist_pk_right = (pks_params['right_bases'] - pks) / sf
-    locrs = (data_filt[0, pks] - data_filt[0, pks_params['left_bases']]) / dist_pk_left
-    rocrs = (data_filt[1, pks] - data_filt[1, pks_params['left_bases']]) / dist_pk_left
-    locfs = (data_filt[0, pks_params['right_bases']] - data_filt[0, pks]) / dist_pk_right
-    rocfs = (data_filt[1, pks_params['right_bases']] - data_filt[1, pks]) / dist_pk_right
-    pks_params['LOCAbsRiseSlope'] = abs(locrs)
-    pks_params['ROCAbsRiseSlope'] = abs(rocrs)
-    pks_params['LOCAbsFallSlope'] = abs(locfs)
-    pks_params['ROCAbsFallSlope'] = abs(rocfs)
+    dist_pk_left = (pks - pks_params["left_bases"]) / sf
+    dist_pk_right = (pks_params["right_bases"] - pks) / sf
+    locrs = (data_filt[0, pks] - data_filt[0, pks_params["left_bases"]]) / dist_pk_left
+    rocrs = (data_filt[1, pks] - data_filt[1, pks_params["left_bases"]]) / dist_pk_left
+    locfs = (data_filt[0, pks_params["right_bases"]] - data_filt[0, pks]) / dist_pk_right
+    rocfs = (data_filt[1, pks_params["right_bases"]] - data_filt[1, pks]) / dist_pk_right
+    pks_params["LOCAbsRiseSlope"] = abs(locrs)
+    pks_params["ROCAbsRiseSlope"] = abs(rocrs)
+    pks_params["LOCAbsFallSlope"] = abs(locfs)
+    pks_params["ROCAbsFallSlope"] = abs(rocfs)
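+    # The slopes are expressed in uV/s: the (filtered) amplitude difference between
+    # base and peak, divided by the corresponding time interval in seconds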
+    pks_params["Stage"] = rem_sta  # Sleep stage

    # Convert to Pandas DataFrame
    df = pd.DataFrame(pks_params)

    # Make sure that the sign of ROC and LOC is opposite
-    df['IsOppositeSign'] = (np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks]))
+    df["IsOppositeSign"] = np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])
    df = df[np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])]

    # Remove bad duration
    tmin, tmax = duration
-    good_dur = np.logical_and(pks_params['Duration'] >= tmin, pks_params['Duration'] < tmax)
+    good_dur = np.logical_and(pks_params["Duration"] >= tmin, pks_params["Duration"] < tmax)
    df = df[good_dur]

    # Keep only useful columns
-    df = df[['Start', 'Peak', 'End', 'Duration', 'LOCAbsValPeak', 'ROCAbsValPeak',
-             'LOCAbsRiseSlope', 'ROCAbsRiseSlope', 'LOCAbsFallSlope', 'ROCAbsFallSlope', 'Stage']]
+    df = df[
+        [
+            "Start",
+            "Peak",
+            "End",
+            "Duration",
+            "LOCAbsValPeak",
+            "ROCAbsValPeak",
+            "LOCAbsRiseSlope",
+            "ROCAbsRiseSlope",
+            "LOCAbsFallSlope",
+            "ROCAbsFallSlope",
+            "Stage",
+        ]
+    ]

    if hypno is None:
-        df = df.drop(columns=['Stage'])
+        df = df.drop(columns=["Stage"])
    else:
-        df['Stage'] = df['Stage'].astype(int)
+        df["Stage"] = df["Stage"].astype(int)

    # We need at least 50 detected REMs to apply the Isolation Forest.
    if remove_outliers and df.shape[0] >= 50:
-        col_keep = ['Duration', 'LOCAbsValPeak', 'ROCAbsValPeak', 'LOCAbsRiseSlope',
-                    'ROCAbsRiseSlope', 'LOCAbsFallSlope', 'ROCAbsFallSlope']
-        ilf = IsolationForest(contamination='auto', max_samples='auto',
-                              verbose=0, random_state=42)
+        col_keep = [
+            "Duration",
+            "LOCAbsValPeak",
+            "ROCAbsValPeak",
+            "LOCAbsRiseSlope",
+            "ROCAbsRiseSlope",
+            "LOCAbsFallSlope",
+            "ROCAbsFallSlope",
+        ]
+        ilf = IsolationForest(contamination="auto", max_samples="auto", verbose=0, random_state=42)
        good = ilf.fit_predict(df[col_keep])
        good[good == -1] = 0
-        logger.info('%i outliers were removed.', (good == 0).sum())
+        logger.info("%i outliers were removed.", (good == 0).sum())
        # Remove outliers from DataFrame
        df = df[good.astype(bool)]

-    logger.info('%i REMs were found in data.', df.shape[0])
+    logger.info("%i REMs were found in data.", df.shape[0])
    df = df.reset_index(drop=True)

-    return REMResults(events=df, data=data, sf=sf, ch_names=ch_names,
-                      hypno=hypno, data_filt=data_filt)
+    return REMResults(
+        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt
+    )
    def summary(self, grp_stage=False, mask=None, aggfunc="mean", sort=True):
        """Return a summary of the REM detection, optionally grouped across stage.

        Parameters
@@ -2217,8 +2738,14 @@
"""# ``grp_chan`` is always False for REM detection because the# REMs are always detected on a combination of LOC and ROC.
- returnsuper().summary(event_type='rem',grp_chan=False,grp_stage=grp_stage,
- aggfunc=aggfunc,sort=sort,mask=mask)
[docs]defget_mask(self):"""Return a boolean array indicating for each sample in data if this
@@ -2226,14 +2753,15 @@
"""# We cannot use super() because "Channel" is not present in _events.fromyasa.othersimport_index_to_events
+
mask=np.zeros(self._data.shape,dtype=int)
- idx_ev=_index_to_events(
- self._events[['Start','End']].to_numpy()*self._sf)
+ idx_ev=_index_to_events(self._events[["Start","End"]].to_numpy()*self._sf)mask[:,idx_ev]=1returnmask
    def get_sync_events(
+        self, center="Peak", time_before=0.4, time_after=0.4, filt=(None, None), mask=None
+    ):
        """
        Return the raw or filtered data of each detected event after
        centering to a specific timepoint.
@@ -2268,6 +2796,7 @@
        'IdxChannel' : Index of channel in data
        """
        from yasa.others import get_centered_indices
+
        assert time_before >= 0
        assert time_after >= 0
        bef = int(self._sf * time_before)
@@ -2275,7 +2804,8 @@
        mask = self._check_mask(mask)
        masked_events = self._events.loc[mask, :]

-        time = np.arange(-bef, aft + 1, dtype='int') / self._sf
+        time = np.arange(-bef, aft + 1, dtype="int") / self._sf
        # Get location of peaks in data
        peaks = (masked_events[center] * self._sf).astype(int).to_numpy()
        # Get centered indices (here we could use the second channel as well).
        idx, idx_valid = get_centered_indices(data[0, :], peaks, bef, aft)
        # If no good epochs are returned raise a warning
        assert len(idx_valid), (
-            'Time before and/or time after exceed data bounds, please '
-            'lower the temporal window around center.')
+            "Time before and/or time after exceed data bounds, please "
+            "lower the temporal window around center."
+        )

        # Initialize empty dataframe
        df_sync = pd.DataFrame()
@@ -2300,16 +2831,24 @@
    ###########################################################################
    set_log_level(verbose)

-    (data, sf, _, hypno, include, _, n_chan, n_samples, _
-     ) = _check_data_hypno(data, sf, ch_names=None, hypno=hypno, include=include, check_amp=False)
+    (data, sf, _, hypno, include, _, n_chan, n_samples, _) = _check_data_hypno(
+        data, sf, ch_names=None, hypno=hypno, include=include, check_amp=False
+    )

-    assert isinstance(n_chan_reject, int), 'n_chan_reject must be int.'
-    assert n_chan_reject >= 1, 'n_chan_reject must be >= 1.'
-    assert n_chan_reject <= n_chan, 'n_chan_reject must be <= n_chan.'
+    assert isinstance(n_chan_reject, int), "n_chan_reject must be int."
+    assert n_chan_reject >= 1, "n_chan_reject must be >= 1."
+    assert n_chan_reject <= n_chan, "n_chan_reject must be <= n_chan."

    # Safety check: sampling frequency and window
-    assert isinstance(sf, (int, float)), 'sf must be int or float'
-    assert isinstance(window, (int, float)), 'window must be int or float'
+    assert isinstance(sf, (int, float)), "sf must be int or float"
+    assert isinstance(window, (int, float)), "window must be int or float"
    if isinstance(sf, float):
-        assert sf.is_integer(), 'sf must be a whole number.'
+        assert sf.is_integer(), "sf must be a whole number."
        sf = int(sf)
    win_sec = window
    window = win_sec * sf  # Convert window to samples
    if isinstance(window, float):
-        assert window.is_integer(), 'window * sf must be a whole number.'
+        assert window.is_integer(), "window * sf must be a whole number."
        window = int(window)

    # Safety check: hypnogram
@@ -2565,24 +3115,27 @@
    # Safety checks: methods
    assert isinstance(method, str), "method must be a string."
    method = method.lower()
-    if method in ['cov', 'covar', 'covariance', 'riemann', 'potato']:
-        method = 'covar'
+    if method in ["cov", "covar", "covariance", "riemann", "potato"]:
+        method = "covar"
        is_pyriemann_installed()
        from pyriemann.estimation import Covariances, Shrinkage
        from pyriemann.clustering import Potato
+
        # Must have at least 4 channels to use method='covar'
        if n_chan <= 4:
-            logger.warning("Must have at least 4 channels for method='covar'. "
-                           "Automatically switching to method='std'.")
-            method = 'std'
+            logger.warning(
+                "Must have at least 4 channels for method='covar'. "
+                "Automatically switching to method='std'."
+            )
+            method = "std"

    ###########################################################################
    # START THE REJECTION
    ###########################################################################
    # Remove flat channels
-    isflat = (np.nanstd(data, axis=-1) == 0)
+    isflat = np.nanstd(data, axis=-1) == 0
    if isflat.any():
-        logger.warning('Flat channel(s) were found and removed in data.')
+        logger.warning("Flat channel(s) were found and removed in data.")
        data = data[~isflat]
        n_chan = data.shape[0]
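The flat-channel test above hinges on a per-channel standard deviation of exactly zero. A minimal standalone illustration of the same check (toy values, not taken from the function):

import numpy as np

data = np.array([
    [1.0, -2.0, 3.0, -1.0],  # normal channel
    [5.0, 5.0, 5.0, 5.0],    # flat channel: zero variance
])
isflat = np.nanstd(data, axis=-1) == 0  # [False, True]
data = data[~isflat]                    # keep only the non-flat channels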
@@ -2600,13 +3153,13 @@
Source code for yasa.detection
    n_flat_epochs = where_flat_epochs.size

    # Now let's make sure that we have a hypnogram and an include variable
-    if 'hypno_win' not in locals():
+    if "hypno_win" not in locals():
        # [-2, -2, -2, -2, ...], where -2 stands for unscored
-        hypno_win = -2 * np.ones(n_epochs, dtype='float')
-        include = np.array([-2], dtype='float')
+        hypno_win = -2 * np.ones(n_epochs, dtype="float")
+        include = np.array([-2], dtype="float")

    # We want to make sure that hypno_win and n_epochs have EXACTLY the same shape
-    assert n_epochs == hypno_win.shape[-1], 'Hypno and epochs do not match.'
+    assert n_epochs == hypno_win.shape[-1], "Hypno and epochs do not match."

    # Finally, we make sure not to include any flat epochs in calculation,
    # just using a random number that is unlikely to be picked by users
@@ -2614,19 +3167,19 @@
Source code for yasa.detection
    hypno_win[where_flat_epochs] = -111991

    # Add logger info
-    logger.info('Number of channels in data = %i', n_chan)
-    logger.info('Number of samples in data = %i', n_samples)
-    logger.info('Sampling frequency = %.2f Hz', sf)
-    logger.info('Data duration = %.2f seconds', n_samples / sf)
-    logger.info('Number of epochs = %i' % n_epochs)
-    logger.info('Artifact window = %.2f seconds' % win_sec)
-    logger.info('Method = %s' % method)
-    logger.info('Threshold = %.2f standard deviations' % threshold)
+    logger.info("Number of channels in data = %i", n_chan)
+    logger.info("Number of samples in data = %i", n_samples)
+    logger.info("Sampling frequency = %.2f Hz", sf)
+    logger.info("Data duration = %.2f seconds", n_samples / sf)
+    logger.info("Number of epochs = %i" % n_epochs)
+    logger.info("Artifact window = %.2f seconds" % win_sec)
+    logger.info("Method = %s" % method)
+    logger.info("Threshold = %.2f standard deviations" % threshold)

    # Create empty `hypno_art` vector (1 sample = 1 epoch)
-    epoch_is_art = np.zeros(n_epochs, dtype='int')
+    epoch_is_art = np.zeros(n_epochs, dtype="int")

-    if method == 'covar':
+    if method == "covar":
        # Calculate the covariance matrices, shape (n_epochs, n_chan, n_chan)
        covmats = Covariances().fit_transform(epochs)
@@ -2634,10 +3187,11 @@
Source code for yasa.detection
        covmats = Shrinkage().fit_transform(covmats)
        # Define Potato instance: 0 = clean, 1 = art
        # To increase speed we set the max number of iterations from 100 to 10
-        potato = Potato(metric='riemann', threshold=threshold, pos_label=0,
-                        neg_label=1, n_iter_max=10)
+        potato = Potato(
+            metric="riemann", threshold=threshold, pos_label=0, neg_label=1, n_iter_max=10
+        )
        # Create empty z-scores output (n_epochs)
-        zscores = np.zeros(n_epochs, dtype='float') * np.nan
+        zscores = np.zeros(n_epochs, dtype="float") * np.nan

        for stage in include:
            where_stage = np.where(hypno_win == stage)[0]
@@ -2646,9 +3200,11 @@
Source code for yasa.detection
            if where_stage.size < 30:
                if hypno is not None:
                    # Only show warning if user actually passed a hypnogram
-                    logger.warning(f"At least 30 epochs are required to "
-                                   f"calculate z-score. Skipping "
-                                   f"stage {stage}")
+                    logger.warning(
+                        f"At least 30 epochs are required to "
+                        f"calculate z-score. Skipping "
+                        f"stage {stage}"
+                    )
                continue
            # Apply Potato algorithm, extract z-scores and labels
            zs = potato.fit_transform(covmats[where_stage])
@@ -2656,20 +3212,22 @@
Source code for yasa.detection
            if hypno is not None:
                # Only shows if user actually passed a hypnogram
                perc_reject = 100 * (art.sum() / art.size)
-                text = (f"Stage {stage}: {art.sum()} / {art.size} "
-                        f"epochs rejected ({perc_reject:.2f}%)")
+                text = (
+                    f"Stage {stage}: {art.sum()} / {art.size} "
+                    f"epochs rejected ({perc_reject:.2f}%)"
+                )
                logger.info(text)
            # Append to global vector
            epoch_is_art[where_stage] = art
            zscores[where_stage] = zs

-    elif method in ['std', 'sd']:
+    elif method in ["std", "sd"]:
        # Calculate the log-transformed standard deviation in each epoch.
        # We add 1 to avoid a log warning if std is zero (e.g. flat line).
        # (n_epochs, n_chan)
        std_epochs = np.log(np.nanstd(epochs, axis=-1) + 1)
        # Create empty zscores output (n_epochs, n_chan)
-        zscores = np.zeros((n_epochs, n_chan), dtype='float') * np.nan
+        zscores = np.zeros((n_epochs, n_chan), dtype="float") * np.nan

        for stage in include:
            where_stage = np.where(hypno_win == stage)[0]
            # At least 30 epochs are required to calculate z-scores
@@ -2677,9 +3235,11 @@
Source code for yasa.detection
            if where_stage.size < 30:
                if hypno is not None:
                    # Only show warning if user actually passed a hypnogram
-                    logger.warning(f"At least 30 epochs are required to "
-                                   f"calculate z-score. Skipping "
-                                   f"stage {stage}")
+                    logger.warning(
+                        f"At least 30 epochs are required to "
+                        f"calculate z-score. Skipping "
+                        f"stage {stage}"
+                    )
                continue
            # Calculate z-scores of STD for each channel x stage
            c_mean = np.nanmean(std_epochs[where_stage], axis=0, keepdims=True)
@@ -2691,8 +3251,10 @@
Source code for yasa.detection
            if hypno is not None:
                # Only shows if user actually passed a hypnogram
                perc_reject = 100 * (art.sum() / art.size)
-                text = (f"Stage {stage}: {art.sum()} / {art.size} "
-                        f"epochs rejected ({perc_reject:.2f}%)")
+                text = (
+                    f"Stage {stage}: {art.sum()} / {art.size} "
+                    f"epochs rejected ({perc_reject:.2f}%)"
+                )
                logger.info(text)
            # Append to global vector
            epoch_is_art[where_stage] = art
@@ -2700,18 +3262,186 @@
Source code for yasa.detection
    # Mark flat epochs as artefacts
    if n_flat_epochs > 0:
-        logger.info(f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
-                    f"that are flat. Z-scores set to np.nan for these epochs.")
+        logger.info(
+            f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
+            f"that are flat. Z-scores set to np.nan for these epochs."
+        )
        epoch_is_art[where_flat_epochs] = 1

    # Log total percentage of epochs rejected
    perc_reject = 100 * (epoch_is_art.sum() / n_epochs)
-    text = (f"TOTAL: {epoch_is_art.sum()} / {n_epochs} epochs rejected ({perc_reject:.2f}%)")
+    text = f"TOTAL: {epoch_is_art.sum()} / {n_epochs} epochs rejected ({perc_reject:.2f}%)"
    logger.info(text)

    # Convert epoch_is_art to boolean [0, 0, 1] --> [False, False, True]
    epoch_is_art = epoch_is_art.astype(bool)

    return epoch_is_art, zscores
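As a rough standalone sketch of what the ``method='std'`` branch computes (random data and hypothetical settings; the real function additionally handles sleep stages and flat epochs):

import numpy as np

rng = np.random.default_rng(42)
epochs = rng.normal(size=(200, 6, 500))  # (n_epochs, n_chan, n_samples)
threshold, n_chan_reject = 3.0, 2        # hypothetical settings

# Log-transformed standard deviation of each epoch x channel
std_epochs = np.log(np.nanstd(epochs, axis=-1) + 1)

# Z-score across epochs, separately for each channel
c_mean = np.nanmean(std_epochs, axis=0, keepdims=True)
c_std = np.nanstd(std_epochs, axis=0, keepdims=True)
zscores = (std_epochs - c_mean) / c_std

# Flag an epoch when at least n_chan_reject channels exceed the threshold
n_bad_chan = (np.abs(zscores) > threshold).sum(axis=1)
epoch_is_art = n_bad_chan >= n_chan_reject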
[docs] def compare_detection(indices_detection, indices_groundtruth, max_distance=0):
+ """
+ Determine correctness of detected events against ground-truth events.
+
+ Parameters
+ ----------
+ indices_detection : array_like
+ Indices of the detected events. For example, this could be the indices of the
+ start of the spindles, or the negative peak of the slow-waves. The indices must be in
+ samples, and not in seconds.
+ indices_groundtruth : array_like
+ Indices of the ground-truth events, in samples.
+ max_distance : int, optional
+ Maximum distance between indices, in samples, to consider as the same event (default = 0).
+ For example, if the sampling frequency of the data is 100 Hz, using `max_distance=100` will
+ search for a matching event 1 second before or after the current event.
+
+ Returns
+ -------
+ results : dict
+ A dictionary with the comparison results:
+
+ * ``tp``: True positives, i.e. actual events detected as events.
+ * ``fp``: False positives, i.e. non-events detected as events.
+ * ``fn``: False negatives, i.e. actual events not detected as events.
+ * ``precision``: Precision score, aka positive predictive value (see Notes)
+ * ``recall``: Recall score, aka sensitivity (see Notes)
+ * ``f1``: F1-score (see Notes)
+
+ Notes
+ -----
+ * The precision score is calculated as TP / (TP + FP).
+ * The recall score is calculated as TP / (TP + FN).
+ * The F1-score is calculated as TP / (TP + 0.5 * (FP + FN)).
+
+ This function is inspired by the `sleepecg.compare_heartbeats
+ <https://sleepecg.readthedocs.io/en/stable/generated/sleepecg.compare_heartbeats.html>`_
+ function.
+
+ Examples
+ --------
+ A simple example. Here, `detected` refers to the indices (in the data) of the detected events.
+ These could be for example the index of the onset of each detected spindle. `grndtrth` refers
+ to the ground-truth (e.g. human-annotated) events.
+
+ >>> from yasa import compare_detection
+ >>> detected = [5, 12, 20, 34, 41, 57, 63]
+ >>> grndtrth = [5, 12, 18, 26, 34, 41, 55, 63, 68]
+ >>> compare_detection(detected, grndtrth)
+ {'tp': array([ 5, 12, 34, 41, 63]),
+ 'fp': array([20, 57]),
+ 'fn': array([18, 26, 55, 68]),
+ 'precision': 0.7142857142857143,
+ 'recall': 0.5555555555555556,
+ 'f1': 0.625}
+
+ There are 5 true positives, 2 false positives and 4 false negatives. This gives a precision
+ score of 0.71 (= 5 / (5 + 2)), a recall score of 0.56 (= 5 / (5 + 4)) and a F1-score of 0.625.
+ The F1-score is the harmonic average of precision and recall, and should be the preferred
+ metric when comparing the performance of a detection against a ground-truth.
+
+ Order matters! If we set `detected` as the ground-truth, FP and FN are inverted, and same for
+ precision and recall. The TP and F1-score remain the same though. Therefore, when comparing two
+ detections (and not a detection against a ground-truth), the F1-score is the preferred metric
+ because it is independent of the order.
+
+ >>> compare_detection(grndtrth, detected)
+ {'tp': array([ 5, 12, 34, 41, 63]),
+ 'fp': array([18, 26, 55, 68]),
+ 'fn': array([20, 57]),
+ 'precision': 0.5555555555555556,
+ 'recall': 0.7142857142857143,
+ 'f1': 0.625}
+
+ There might be some events that are very close to each other, and we would like to count them
+ as true positive even though they do not occur exactly at the same index. This is possible
+ with the `max_distance` argument, which defines the lookaround window (in samples) for
+ each event.
+
+ >>> compare_detection(detected, grndtrth, max_distance=2)
+ {'tp': array([ 5, 12, 20, 34, 41, 57, 63]),
+ 'fp': array([], dtype=int64),
+ 'fn': array([26, 68]),
+ 'precision': 1.0,
+ 'recall': 0.7777777777777778,
+ 'f1': 0.875}
+
+ Finally, if detected is empty, all performance metrics will be set to zero, and a copy of
+ the groundtruth array will be returned as false negatives.
+
+ >>> compare_detection([], grndtrth)
+ {'tp': array([], dtype=int64),
+ 'fp': array([], dtype=int64),
+ 'fn': array([ 5, 12, 18, 26, 34, 41, 55, 63, 68]),
+ 'precision': 0,
+ 'recall': 0,
+ 'f1': 0}
+ """
+    # Safety check
+    assert all([float(i).is_integer() for i in indices_detection])  # all([]) == True
+    assert all([float(i).is_integer() for i in indices_groundtruth])
+    indices_detection = np.array(indices_detection, dtype=int)  # Force copy
+    indices_groundtruth = np.array(indices_groundtruth, dtype=int)
+    assert indices_detection.ndim == 1, "detection indices must be a 1D list or array."
+    assert indices_groundtruth.ndim == 1, "groundtruth indices must be a 1D list or array."
+    assert max_distance >= 0, "max_distance must be 0 or a positive integer."
+    assert isinstance(max_distance, int), "max_distance must be 0 or a positive integer."
+
+    # Handle cases where indices_detection or indices_groundtruth is empty
+    if indices_detection.size == 0:
+        results = dict(
+            tp=np.array([], dtype=int),
+            fp=np.array([], dtype=int),
+            fn=indices_groundtruth.copy(),
+            precision=0,
+            recall=0,
+            f1=0,
+        )
+        return results
+
+    if indices_groundtruth.size == 0:
+        results = dict(
+            tp=np.array([], dtype=int),
+            fp=indices_detection.copy(),
+            fn=np.array([], dtype=int),
+            precision=0,
+            recall=0,
+            f1=0,
+        )
+        return results
+
+    # Create boolean masks
+    max_len = max(max(indices_detection), max(indices_groundtruth)) + 1
+    detection_mask = np.zeros(max_len, dtype=bool)
+    detection_mask[indices_detection] = 1
+    true_mask = np.zeros(max_len, dtype=bool)
+    true_mask[indices_groundtruth] = 1
+
+    # Create smoothed masks
+    fuzzy_filter = np.ones(max_distance * 2 + 1, dtype=bool)
+    if len(fuzzy_filter) >= max_len:
+        raise ValueError(
+            f"The convolution window is larger than the signal. `max_distance` should be between "
+            f"0 and {int(max_len / 2 - 1)} samples."
+        )
+    detection_mask_fuzzy = np.convolve(detection_mask, fuzzy_filter, mode="same")
+    true_mask_fuzzy = np.convolve(true_mask, fuzzy_filter, mode="same")
+
+    # Confusion matrix and performance metrics
+    results = {}
+    results["tp"] = np.where(detection_mask & true_mask_fuzzy)[0]
+    results["fp"] = np.where(detection_mask & ~true_mask_fuzzy)[0]
+    results["fn"] = np.where(~detection_mask_fuzzy & true_mask)[0]
+
+    n_tp, n_fp, n_fn = len(results["tp"]), len(results["fp"]), len(results["fn"])
+    results["precision"] = n_tp / (n_tp + n_fp)
+    results["recall"] = n_tp / (n_tp + n_fn)
+    results["f1"] = n_tp / (n_tp + 0.5 * (n_fp + n_fn))
+    return results
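The fuzzy matching implemented above amounts to dilating each boolean event mask by ``max_distance`` samples on each side before intersecting the masks. A compact sketch that reproduces the ``max_distance=2`` example from the docstring (the dilated masks are booleanized via ``> 0`` for clarity):

import numpy as np

detected = np.array([5, 12, 20, 34, 41, 57, 63])
grndtrth = np.array([5, 12, 18, 26, 34, 41, 55, 63, 68])
max_distance = 2

max_len = max(detected.max(), grndtrth.max()) + 1
det_mask = np.zeros(max_len, dtype=bool)
det_mask[detected] = True
true_mask = np.zeros(max_len, dtype=bool)
true_mask[grndtrth] = True

# Dilate each mask by max_distance samples on both sides
kernel = np.ones(2 * max_distance + 1)
det_fuzzy = np.convolve(det_mask, kernel, mode="same") > 0
true_fuzzy = np.convolve(true_mask, kernel, mode="same") > 0

tp = np.where(det_mask & true_fuzzy)[0]   # [ 5 12 20 34 41 57 63]
fp = np.where(det_mask & ~true_fuzzy)[0]  # []
fn = np.where(~det_fuzzy & true_mask)[0]  # [26 68]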
[docs] def hrv_stage(
+    data,
+    sf,
+    *,
+    hypno=None,
+    include=(2, 3, 4),
+    threshold="2min",
+    equal_length=False,
+    rr_limit=(400, 2000),
+    verbose=False,
+):
+ """Calculate heart rate and heart rate variability (HRV) features from an ECG.
+
+ By default, the cardiac features are calculated for each period of N2, N3 or REM sleep that
+ is longer than 2 minutes.
+
+ .. versionadded:: 0.6.2
+
+ Parameters
+ ----------
+ data : :py:class:`numpy.ndarray`
+ Single-channel ECG data. Must be a 1D NumPy array.
+ sf : float
+ The sampling frequency of the data.
+ hypno : array_like
+ Sleep stage (hypnogram). The heart rate calculation will be applied for each sleep stage
+ defined in ``include`` (default = N2, N3 and REM sleep separately).
+
+ The hypnogram must have the same number of samples as ``data``.
+ To upsample your hypnogram, please refer to
+ :py:func:`yasa.hypno_upsample_to_data`.
+
+ .. note::
+ The default hypnogram format in YASA is a 1D integer
+ vector where:
+
+ - -2 = Unscored
+ - -1 = Artefact / Movement
+ - 0 = Wake
+ - 1 = N1 sleep
+ - 2 = N2 sleep
+ - 3 = N3 sleep
+ - 4 = REM sleep
+ include : tuple, list or int
+ Values in ``hypno`` that will be included in the mask. The default is
+ (2, 3, 4), meaning that the detection is applied on N2, N3 and REM
+ sleep separately.
+ threshold : str
+ Only periods of a given stage that exceed the duration defined in ``threshold`` will be
+ kept in subsequent analysis. The default is 2 minutes ('2min'). Other possible values
+ include: '5min', '15min', '30sec', '1hour', etc. To disable thresholding, use '0min'.
+ equal_length : bool
+ If True, the periods will all have the exact duration defined in ``threshold``.
+ That is, periods that are longer than the duration threshold will be divided into
+ sub-periods of exactly the length of threshold.
+ rr_limit : tuple
+ Lower and upper limit for the RR interval. Default is 400 to 2000 ms, corresponding to a
+ heart rate of 30 to 150 bpm. RR intervals outside this range will be set to NaN and
+ filled with linear interpolation. Use ``rr_limit=(0, np.inf)`` to disable RR correction.
+ verbose : bool or str
+ Verbose level. Default (False) will only print warning and error
+ messages. The logging levels are 'debug', 'info', 'warning', 'error',
+ and 'critical'. For most users the choice is between 'info'
+ (or ``verbose=True``) and 'warning' (``verbose=False``). Set this to True if you are getting
+ invalid results and want to better understand what is happening.
+
+ Returns
+ -------
+ epochs : :py:class:`pandas.DataFrame`
+ Output dataframe with values (= the sleep stages defined in ``include``) and
+ epoch number as index. The columns are
+
+ * ``start`` : The start of the epoch, in seconds from the beginning of the recording.
+ * ``duration`` : The duration of the epoch, in seconds.
+ * ``hr_mean``: The mean heart rate (HR) across the epoch, in beats per minute (bpm).
+ * ``hr_std``: The standard deviation of the HR across the epoch, in bpm.
+ * ``hrv_rmssd``: Heart rate variability across the epoch (RMSSD), in milliseconds.
+ rpeaks : dict
+ A dictionary with the detected heartbeats (R-peaks) indices for each epoch of each stage.
+ Indices are expressed as samples from the beginning of the epoch. This can be used to
+ manually recalculate the RR intervals, apply a custom preprocessing on the RR intervals,
+ and/or calculate more advanced HRV metrics.
+
+ Notes
+ -----
+ This function returns three cardiac features for each epoch: the mean and standard deviation of
+ the heart rate, and the root mean square of successive differences between normal heartbeats
+ (RMSSD). The RMSSD reflects the beat-to-beat variance in HR and is the primary time-domain
+ measure used to estimate the vagally mediated changes reflected in heart rate variability.
+
+ Heartbeat detection is performed with the SleepECG library: https://github.com/cbrnr/sleepecg
+
+ For an example of this function, please see the `Jupyter notebook
+ <https://github.com/raphaelvallat/yasa/blob/master/notebooks/16_EEG-HRV_coupling.ipynb>`_
+
+ References
+ ----------
+ * Shaffer, F., & Ginsberg, J. P. (2017). An overview of heart rate variability metrics and
+ norms. Frontiers in public health, 258.
+ """
+ set_log_level(verbose)
+ is_sleepecg_installed()
+ from sleepecg import detect_heartbeats
+
+    if isinstance(hypno, type(None)):
+        logger.warning(
+            "No hypnogram was passed. The entire recording will be used, i.e. "
+            "hypno will be set to np.zeros(data.size) and include will be set to 0."
+        )
+        data = np.asarray(data, dtype=np.float64)
+        hypno = np.zeros(max(data.shape), dtype=int)
+        include = 0
+
+    # Safety check
+    (data, sf, _, hypno, include, _, n_chan, n_samples, _) = _check_data_hypno(
+        data, sf, None, hypno, include, check_amp=False
+    )
+    assert n_chan == 1, "data must be a 1D ECG array."
+    data = np.squeeze(data)
+
+    # Find periods of equal duration
+    epochs = hypno_find_periods(hypno, sf, threshold=threshold, equal_length=equal_length)
+    assert epochs.shape[0] > 0, f"No epochs longer than {threshold} found in hypnogram."
+    epochs = epochs[epochs["values"].isin(include)].reset_index(drop=True)
+    # Sort by stage and add epoch number
+    epochs = epochs.sort_values(by=["values", "start"])
+    epochs["epoch"] = epochs.groupby("values")["start"].transform(lambda x: range(len(x)))
+    epochs = epochs.set_index(["values", "epoch"])
+
+    # Loop over epochs
+    rpeaks = {}
+    for idx in epochs.index:
+        start = epochs.loc[idx, "start"]
+        duration = epochs.loc[idx, "length"]
+        end = int(epochs.loc[idx, "start"] + duration)
+        # Detect R-peaks
+        try:
+            pks = detect_heartbeats(data[start:end], fs=sf)
+        except Exception as e:
+            logger.info(f"Heartbeat detection failed for epoch {idx[1]} of stage {idx[0]}: {e}")
+            continue
+
+        # Save rpeaks to dict
+        rpeaks[idx] = pks
+
+        # If too few R-peaks were detected, skip the epoch and return NaN.
+        # Here, we assume a minimal HR of 30 bpm.
+        constant_hr = 60 * (pks.size / (duration / sf))
+        if constant_hr < 30:
+            logger.info(f"Too few detected heartbeats in epoch {idx[1]} of stage {idx[0]}.")
+            continue
+
+        # Find and correct RR intervals. Default is 400 ms (150 bpm) to 2000 ms (30 bpm)
+        rri = 1000 * np.diff(pks) / sf
+        rri = np.ma.masked_outside(rri, rr_limit[0], rr_limit[1]).filled(np.nan)
+        # Interpolate NaN values, but no more than 10 consecutive values
+        if np.isnan(rri).any():
+            rri = pd.Series(rri).interpolate(limit_direction="both", limit=10).to_numpy()
+            if np.isnan(rri).any():
+                # If there are still NaN present, skip current epoch
+                logger.info(f"Invalid RR intervals in epoch {idx[1]} of stage {idx[0]}.")
+                continue
+
+        # Heart rate
+        hr = 60000 / rri
+        epochs.loc[idx, "hr_mean"] = np.mean(hr)
+        epochs.loc[idx, "hr_std"] = np.std(hr, ddof=1)
+        epochs.loc[idx, "hrv_rmssd"] = np.sqrt(np.mean(np.diff(rri) ** 2))
+
+    # Convert start and duration to seconds
+    epochs["start"] /= sf
+    epochs["length"] /= sf
+    epochs = epochs.rename(columns={"length": "duration"})
+
+    return epochs, rpeaks
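A minimal sketch of the per-epoch cardiac features computed in the loop above, starting from hypothetical R-peak sample indices at 100 Hz:

import numpy as np

sf = 100
pks = np.array([50, 130, 215, 298, 380])  # hypothetical R-peak indices (samples)

rri = 1000 * np.diff(pks) / sf  # RR intervals in ms: [800. 850. 830. 820.]
hr = 60000 / rri                # instantaneous heart rate, in bpm

hr_mean = np.mean(hr)                            # mean heart rate
hr_std = np.std(hr, ddof=1)                      # HR standard deviation
hrv_rmssd = np.sqrt(np.mean(np.diff(rri) ** 2))  # RMSSD, in ms (~31.6 here)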
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/yasa/hypno.html b/docs/build/html/_modules/yasa/hypno.html
index 91488ae..71e2c5f 100644
--- a/docs/build/html/_modules/yasa/hypno.html
+++ b/docs/build/html/_modules/yasa/hypno.html
@@ -3,11 +3,14 @@
- yasa.hypno — yasa 0.6.1 documentation
+ yasa.hypno — yasa 0.6.2 documentation
+
+
+
@@ -41,7 +44,7 @@
yasa
- 0.6.1
+ 0.6.2
        The hypnogram, upsampled to ``sf_data``.
    """
    repeats = sf_data / sf_hypno
-    assert sf_hypno <= sf_data, 'sf_hypno must be less than sf_data.'
-    assert repeats.is_integer(), 'sf_hypno / sf_data must be a whole number.'
+    assert sf_hypno <= sf_data, "sf_hypno must be less than sf_data."
+    assert repeats.is_integer(), "sf_data / sf_hypno must be a whole number."
    assert isinstance(hypno, (list, np.ndarray, pd.Series))
    return np.repeat(np.asarray(hypno), repeats)
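For example, upsampling a standard 30-second hypnogram to 100 Hz data simply repeats each stage value 3000 times (a toy sketch of the logic above, not the full function):

import numpy as np

hypno = np.array([0, 2, 2, 3])  # one value per 30-s epoch
sf_hypno, sf_data = 1 / 30, 100

repeats = sf_data / sf_hypno    # 3000 samples per hypnogram value
assert repeats.is_integer()
hypno_up = np.repeat(hypno, int(repeats))  # 12000 samples = 4 epochs * 30 s * 100 Hz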
@@ -249,11 +277,11 @@
Source code for yasa.hypno
"""# Check if data is an MNE raw objectifisinstance(data,mne.io.BaseRaw):
- sf=data.info['sfreq']
+ sf=data.info["sfreq"]data=data.times# 1D array and does not require to preload datadata=np.asarray(data)hypno=np.asarray(hypno)
- asserthypno.ndim==1,'Hypno must be 1D.'
+ asserthypno.ndim==1,"Hypno must be 1D."npts_hyp=hypno.sizenpts_data=max(data.shape)# Support for 2D dataifnpts_hyp<npts_data:
@@ -261,22 +289,30 @@
Source code for yasa.hypno
        npts_diff = npts_data - npts_hyp
        if sf is not None:
            dur_diff = npts_diff / sf
-            logger.warning('Hypnogram is SHORTER than data by %.2f seconds. '
-                           'Padding hypnogram with last value to match data.size.' % dur_diff)
+            logger.warning(
+                "Hypnogram is SHORTER than data by %.2f seconds. "
+                "Padding hypnogram with last value to match data.size." % dur_diff
+            )
        else:
-            logger.warning('Hypnogram is SHORTER than data by %i samples. '
-                           'Padding hypnogram with last value to match data.size.' % npts_diff)
-        hypno = np.pad(hypno, (0, npts_diff), mode='edge')
+            logger.warning(
+                "Hypnogram is SHORTER than data by %i samples. "
+                "Padding hypnogram with last value to match data.size." % npts_diff
+            )
+        hypno = np.pad(hypno, (0, npts_diff), mode="edge")
    elif npts_hyp > npts_data:
        # Hypnogram is longer than data
        npts_diff = npts_hyp - npts_data
        if sf is not None:
            dur_diff = npts_diff / sf
-            logger.warning('Hypnogram is LONGER than data by %.2f seconds. '
-                           'Cropping hypnogram to match data.size.' % dur_diff)
+            logger.warning(
+                "Hypnogram is LONGER than data by %.2f seconds. "
+                "Cropping hypnogram to match data.size." % dur_diff
+            )
        else:
-            logger.warning('Hypnogram is LONGER than data by %i samples. '
-                           'Cropping hypnogram to match data.size.' % npts_diff)
+            logger.warning(
+                "Hypnogram is LONGER than data by %i samples. "
+                "Cropping hypnogram to match data.size." % npts_diff
+            )
        hypno = hypno[0:npts_data]

    return hypno
@@ -323,7 +359,7 @@
[docs] def hypno_find_periods(hypno, sf_hypno, threshold="5min", equal_length=False):
+ """Find sequences of consecutive values exceeding a certain duration in hypnogram.
+
+ .. versionadded:: 0.6.2
+
+ Parameters
+ ----------
+ hypno : array_like
+ A 1D array with the sleep stages (= hypnogram). The dtype can be anything (int, bool, str).
+ More generally, this can be any vector for which you wish to find runs of
+ consecutive items.
+ sf_hypno : float
+ The current sampling frequency of ``hypno``, in Hz, e.g. 1/30 = 1 value per each 30 seconds
+ of EEG data, 1 = 1 value per second of EEG data.
+ threshold : str
+ This function will only keep periods that exceed a certain duration (default '5min'), e.g.
+ '5min', '15min', '30sec', '1hour'. To disable thresholding, use '0sec'.
+ equal_length : bool
+ If True, the periods will all have the exact duration defined
+ in threshold. That is, periods that are longer than the duration threshold will be divided
+ into sub-periods of exactly the length of ``threshold``.
+
+ Returns
+ -------
+ periods : :py:class:`pandas.DataFrame`
+ Output dataframe
+
+ * ``values`` : The value in hypno of the current period
+ * ``start`` : The index of the start of the period in hypno
+ * ``length`` : The duration of the period, in number of samples
+
+ Examples
+ --------
+ Let's assume that we have an hypnogram where sleep = 1 and wake = 0. There is one value per
+ minute, and therefore the sampling frequency of the hypnogram is 1 / 60 Hz (~0.016 Hz).
+
+ >>> import yasa
+ >>> hypno = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
+ >>> yasa.hypno_find_periods(hypno, sf_hypno=1/60, threshold="0min")
+ values start length
+ 0 0 0 11
+ 1 1 11 3
+ 2 0 14 2
+ 3 1 16 9
+ 4 0 25 2
+
+ This gives us the start and duration of each sequence of consecutive values in the hypnogram.
+ For example, the first row tells us that there is a sequence of 11 consecutive 0 starting at
+ the first index of hypno.
+
+ Now, we may want to keep only periods that are longer than a specific threshold,
+ for example 5 minutes:
+
+ >>> yasa.hypno_find_periods(hypno, sf_hypno=1/60, threshold="5min")
+ values start length
+ 0 0 0 11
+ 1 1 16 9
+
+ Only the two sequences that are longer than 5 minutes (11 minutes and 9 minutes respectively)
+ are kept. Feel free to play around with different values of threshold!
+
+ This function is not limited to binary arrays, e.g.
+
+ >>> hypno = [0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 0, 0, 0, 1, 0, 1]
+ >>> yasa.hypno_find_periods(hypno, sf_hypno=1/60, threshold="2min")
+ values start length
+ 0 0 0 4
+ 1 2 5 6
+ 2 0 11 3
+
+ Lastly, using ``equal_length=True`` will further divide the periods into segments of the
+ same duration, i.e. the duration defined in ``threshold``:
+
+ >>> hypno = [0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 0, 0, 0, 1, 0, 1]
+ >>> yasa.hypno_find_periods(hypno, sf_hypno=1/60, threshold="2min", equal_length=True)
+ values start length
+ 0 0 0 2
+ 1 0 2 2
+ 2 2 5 2
+ 3 2 7 2
+ 4 2 9 2
+ 5 0 11 2
+
+ Here, the first period of 4 minutes of consecutive 0 is further divided into 2 periods of
+ exactly 2 minutes. Next, the sequence of 6 consecutive 2 is further divided into 3 periods of
+ 2 minutes. Lastly, the last value in the sequence of 3 consecutive 0 at the end of the array is
+ removed to keep only a segment of exactly 2 minutes. In other words, the remainder of the
+ division of a given segment by the desired duration is discarded.
+ """
+    # Convert the threshold to number of samples
+    assert isinstance(threshold, str), "Threshold must be a string, e.g. '5min', '30sec', '15min'"
+    thr_sec = pd.Timedelta(threshold).total_seconds()
+    thr_samp = sf_hypno * thr_sec
+    if float(thr_samp).is_integer():
+        thr_samp = int(thr_samp)
+    else:
+        raise ValueError(
+            f"The selected threshold does not result in a whole number of samples ("
+            f"{thr_sec:.3f} seconds * {sf_hypno:.3f} Hz = {thr_samp:.3f} samples)"
+        )
+
+    # Find run starts
+    # https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065
+    assert isinstance(hypno, (list, np.ndarray, pd.Series)), "hypno must be an array."
+    x = np.asarray(hypno)
+    n = x.shape[0]
+    loc_run_start = np.empty(n, dtype=bool)
+    loc_run_start[0] = True
+    loc_run_start[1:] = x[:-1] != x[1:]
+    run_starts = np.nonzero(loc_run_start)[0]
+    # Find run values
+    run_values = x[loc_run_start]
+    # Find run lengths
+    run_lengths = np.diff(np.append(run_starts, n))
+    seq = pd.DataFrame({"values": run_values, "start": run_starts, "length": run_lengths})
+
+    # Remove runs that are shorter than threshold
+    seq = seq[seq["length"] >= thr_samp].reset_index(drop=True)
+
+    if not equal_length:
+        return seq
+
+    # Divide into epochs of equal length
+    assert thr_samp > 0, "Threshold must be non-zero if using equal_length=True."
+    new_seq = {"values": [], "start": [], "length": []}
+
+    for i, row in seq.iterrows():
+        quotient, remainder = np.divmod(row["length"], thr_samp)
+        new_start = row["start"]
+        if quotient > 0:
+            while quotient != 0:
+                new_seq["values"].append(row["values"])
+                new_seq["start"].append(new_start)
+                new_seq["length"].append(thr_samp)
+                new_start += thr_samp
+                quotient -= 1
+        else:
+            new_seq["values"].append(row["values"])
+            new_seq["start"].append(row["start"])
+            new_seq["length"].append(row["length"])
+
+    new_seq = pd.DataFrame(new_seq)
+    return new_seq
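The run detection above follows the run-length encoding trick from the linked gist: a run starts at index 0 and wherever two consecutive values differ. A worked example with a tiny array:

import numpy as np
import pandas as pd

x = np.array([0, 0, 0, 1, 1, 2, 2, 2, 0])

loc_run_start = np.empty(x.size, dtype=bool)
loc_run_start[0] = True
loc_run_start[1:] = x[:-1] != x[1:]

run_starts = np.nonzero(loc_run_start)[0]             # [0 3 5 8]
run_values = x[loc_run_start]                         # [0 1 2 0]
run_lengths = np.diff(np.append(run_starts, x.size))  # [3 2 3 1]
seq = pd.DataFrame({"values": run_values, "start": run_starts, "length": run_lengths})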
    Original code imported from the Visbrain package.
    """
    # Convert min_distance_ms
-    min_distance = min_distance_ms / 1000. * sf
+    min_distance = min_distance_ms / 1000.0 * sf
    idx_diff = np.diff(index)
    condition = idx_diff > 1
    idx_distance = np.where(condition)[0]
@@ -136,8 +138,7 @@
Source code for yasa.others
    bad = idx_distance[np.where(distance < min_distance)[0]]
    # Fill gap between events separated with less than min_distance_ms
    if len(bad) > 0:
-        fill = np.hstack([np.arange(index[j] + 1, index[j + 1])
-                          for i, j in enumerate(bad)])
+        fill = np.hstack([np.arange(index[j] + 1, index[j + 1]) for i, j in enumerate(bad)])
        f_index = np.sort(np.append(index, fill))
        return f_index
    else:
@@ -170,8 +171,7 @@
[docs] def moving_transform(x, y=None, sf=100, window=0.3, step=0.1, method="corr", interp=False):
    """Moving transformation of one or two time-series.

    Parameters
@@ -220,8 +220,17 @@
           [-51, 3, 31, -99, 33, -47, 5, -97, -47, 90]]])
    """
    from numpy.lib.stride_tricks import as_strided
+
    assert axis <= data.ndim, "Axis value out of range."
-    assert isinstance(sf, (int, float)), 'sf must be int or float'
-    assert isinstance(window, (int, float)), 'window must be int or float'
-    assert isinstance(step, (int, float, type(None))), ('step must be int, '
-                                                        'float or None.')
+    assert isinstance(sf, (int, float)), "sf must be int or float"
+    assert isinstance(window, (int, float)), "window must be int or float"
+    assert isinstance(step, (int, float, type(None))), "step must be int, float or None."
    if isinstance(sf, float):
-        assert sf.is_integer(), 'sf must be a whole number.'
+        assert sf.is_integer(), "sf must be a whole number."
    sf = int(sf)
-    assert isinstance(axis, int), 'axis must be int.'
+    assert isinstance(axis, int), "axis must be int."

    # window and step in samples instead of points
    window *= sf
    step = window if step is None else step * sf
    if isinstance(window, float):
-        assert window.is_integer(), 'window * sf must be a whole number.'
+        assert window.is_integer(), "window * sf must be a whole number."
    window = int(window)
    if isinstance(step, float):
-        assert step.is_integer(), 'step * sf must be a whole number.'
+        assert step.is_integer(), "step * sf must be a whole number."
    step = int(step)
    assert step >= 1, "Stepsize may not be zero or negative."
-    assert window < data.shape[axis], ("Sliding window size may not exceed "
-                                       "size of selected axis")
+    assert window < data.shape[axis], "Sliding window size may not exceed size of selected axis"

    # Define output shape
    shape = list(data.shape)
-    shape[axis] = np.floor(data.shape[axis] / step - window / step + 1
-                           ).astype(int)
+    shape[axis] = np.floor(data.shape[axis] / step - window / step + 1).astype(int)
    shape.append(window)

    # Calculate strides and time vector
@@ -546,13 +561,13 @@
Source code for yasa.others
    npts_before = int(npts_before)
    npts_after = int(npts_after)
    data = np.asarray(data)
-    idx = np.asarray(idx, dtype='int')
+    idx = np.asarray(idx, dtype="int")
    assert idx.ndim == 1, "idx must be 1D."
    assert data.ndim == 1, "data must be 1D."

    def rng(x):
        """Create a range before and after a given value."""
-        return np.arange(x - npts_before, x + npts_after + 1, dtype='int')
+        return np.arange(x - npts_before, x + npts_after + 1, dtype="int")

    idx_ep = np.apply_along_axis(rng, 1, idx[..., np.newaxis])
    # We drop the events for which the indices exceed data
diff --git a/docs/build/html/_modules/yasa/plotting.html b/docs/build/html/_modules/yasa/plotting.html
index a773b95..a12daa2 100644
--- a/docs/build/html/_modules/yasa/plotting.html
+++ b/docs/build/html/_modules/yasa/plotting.html
@@ -3,11 +3,14 @@
- yasa.plotting — yasa 0.6.1 documentation
+ yasa.plotting — yasa 0.6.2 documentation
+
+
+
@@ -41,7 +44,7 @@
yasa
- 0.6.1
+ 0.6.2
[docs] def plot_hypnogram(hypno, sf_hypno=1/30, lw=1.5, figsize=(9, 3)):
    """
    Plot a hypnogram.
@@ -151,15 +154,15 @@
Source code for yasa.plotting
    >>> ax = yasa.plot_hypnogram(hypno)
    """
    # Increase font size while preserving original
-    old_fontsize = plt.rcParams['font.size']
-    plt.rcParams.update({'font.size': 18})
+    old_fontsize = plt.rcParams["font.size"]
+    plt.rcParams.update({"font.size": 18})

    # Safety checks
-    assert isinstance(hypno, (np.ndarray, pd.Series, list)), 'hypno must be an array.'
+    assert isinstance(hypno, (np.ndarray, pd.Series, list)), "hypno must be an array."
    hypno = np.asarray(hypno).astype(int)
    assert (hypno >= -2).all() and (hypno <= 4).all(), "hypno values must be between -2 to 4."
-    assert hypno.ndim == 1, 'hypno must be a 1D array.'
-    assert isinstance(sf_hypno, (int, float)), 'sf must be int or float.'
+    assert hypno.ndim == 1, "hypno must be a 1D array."
+    assert isinstance(sf_hypno, (int, float)), "sf must be int or float."
    t_hyp = np.arange(hypno.size) / (sf_hypno * 3600)
    # Make sure that REM is displayed after Wake
@@ -170,40 +173,51 @@
Source code for yasa.plotting
    fig, ax0 = plt.subplots(nrows=1, figsize=figsize)

    # Hypnogram (top axis)
-    ax0.step(t_hyp, -1 * hypno, color='k', lw=lw)
-    ax0.step(t_hyp, -1 * hypno_rem, color='red', lw=lw)
-    ax0.step(t_hyp, -1 * hypno_art_uns, color='grey', lw=lw)
+    ax0.step(t_hyp, -1 * hypno, color="k", lw=lw)
+    ax0.step(t_hyp, -1 * hypno_rem, color="red", lw=lw)
+    ax0.step(t_hyp, -1 * hypno_art_uns, color="grey", lw=lw)
    if -2 in hypno and -1 in hypno:
        # Both Unscored and Artefacts are present
        ax0.set_yticks([2, 1, 0, -1, -2, -3, -4])
-        ax0.set_yticklabels(['Uns', 'Art', 'W', 'R', 'N1', 'N2', 'N3'])
+        ax0.set_yticklabels(["Uns", "Art", "W", "R", "N1", "N2", "N3"])
        ax0.set_ylim(-4.5, 2.5)
    elif -2 in hypno and -1 not in hypno:
        # Only Unscored are present
        ax0.set_yticks([2, 0, -1, -2, -3, -4])
-        ax0.set_yticklabels(['Uns', 'W', 'R', 'N1', 'N2', 'N3'])
+        ax0.set_yticklabels(["Uns", "W", "R", "N1", "N2", "N3"])
        ax0.set_ylim(-4.5, 2.5)
    elif -2 not in hypno and -1 in hypno:
        # Only Artefacts are present
        ax0.set_yticks([1, 0, -1, -2, -3, -4])
-        ax0.set_yticklabels(['Art', 'W', 'R', 'N1', 'N2', 'N3'])
+        ax0.set_yticklabels(["Art", "W", "R", "N1", "N2", "N3"])
        ax0.set_ylim(-4.5, 1.5)
    else:
        # No artefacts or Unscored
        ax0.set_yticks([0, -1, -2, -3, -4])
-        ax0.set_yticklabels(['W', 'R', 'N1', 'N2', 'N3'])
+        ax0.set_yticklabels(["W", "R", "N1", "N2", "N3"])
        ax0.set_ylim(-4.5, 0.5)
    ax0.set_xlim(0, t_hyp.max())
-    ax0.set_ylabel('Stage')
-    ax0.set_xlabel('Time [hrs]')
-    ax0.spines['right'].set_visible(False)
-    ax0.spines['top'].set_visible(False)
+    ax0.set_ylabel("Stage")
+    ax0.set_xlabel("Time [hrs]")
+    ax0.spines["right"].set_visible(False)
+    ax0.spines["top"].set_visible(False)

    # Revert font-size
-    plt.rcParams.update({'font.size': old_fontsize})
+    plt.rcParams.update({"font.size": old_fontsize})
    return ax0
[docs] def plot_spectrogram(
+    data,
+    sf,
+    hypno=None,
+    win_sec=30,
+    fmin=0.5,
+    fmax=25,
+    trimperc=2.5,
+    cmap="RdBu_r",
+    vmin=None,
+    vmax=None,
+):
    """
    Plot a full-night multi-taper spectrogram, optionally with the hypnogram on top.
@@ -249,6 +263,10 @@
Source code for yasa.plotting
        are defined as the 2.5 and 97.5 percentiles of the spectrogram.
    cmap : str
        Colormap. Default to 'RdBu_r'.
+    vmin : int or float
+        The lower range of the color scale. Overwrites ``trimperc``.
+    vmax : int or float
+        The upper range of the color scale. Overwrites ``trimperc``.

    Returns
    -------
@@ -291,22 +309,28 @@
Source code for yasa.plotting
    >>> fig = yasa.plot_spectrogram(data, sf, hypno, cmap='Spectral_r')
    """
    # Increase font size while preserving original
-    old_fontsize = plt.rcParams['font.size']
-    plt.rcParams.update({'font.size': 18})
+    old_fontsize = plt.rcParams["font.size"]
+    plt.rcParams.update({"font.size": 18})

    # Safety checks
-    assert isinstance(data, np.ndarray), 'Data must be a 1D NumPy array.'
-    assert isinstance(sf, (int, float)), 'sf must be int or float.'
-    assert data.ndim == 1, 'Data must be a 1D (single-channel) NumPy array.'
-    assert isinstance(win_sec, (int, float)), 'win_sec must be int or float.'
-    assert isinstance(fmin, (int, float)), 'fmin must be int or float.'
-    assert isinstance(fmax, (int, float)), 'fmax must be int or float.'
-    assert fmin < fmax, 'fmin must be strictly inferior to fmax.'
-    assert fmax < sf / 2, 'fmax must be less than Nyquist (sf / 2).'
+    assert isinstance(data, np.ndarray), "Data must be a 1D NumPy array."
+    assert isinstance(sf, (int, float)), "sf must be int or float."
+    assert data.ndim == 1, "Data must be a 1D (single-channel) NumPy array."
+    assert isinstance(win_sec, (int, float)), "win_sec must be int or float."
+    assert isinstance(fmin, (int, float)), "fmin must be int or float."
+    assert isinstance(fmax, (int, float)), "fmax must be int or float."
+    assert fmin < fmax, "fmin must be strictly inferior to fmax."
+    assert fmax < sf / 2, "fmax must be less than Nyquist (sf / 2)."
+    assert isinstance(vmin, (int, float, type(None))), "vmin must be int, float, or None."
+    assert isinstance(vmax, (int, float, type(None))), "vmax must be int, float, or None."
+    if vmin is not None:
+        assert isinstance(vmax, (int, float)), "vmax must be int or float if vmin is provided"
+    if vmax is not None:
+        assert isinstance(vmin, (int, float)), "vmin must be int or float if vmax is provided"

    # Calculate multi-taper spectrogram
    nperseg = int(win_sec * sf)
-    assert data.size > 2 * nperseg, 'Data length must be at least 2 * win_sec.'
+    assert data.size > 2 * nperseg, "Data length must be at least 2 * win_sec."
    f, t, Sxx = spectrogram_lspopt(data, sf, nperseg=nperseg, noverlap=0)
    Sxx = 10 * np.log10(Sxx)  # Convert uV^2 / Hz --> dB / Hz
@@ -317,77 +341,94 @@
Source code for yasa.plotting
    t /= 3600  # Convert t to hours

    # Normalization
-    vmin, vmax = np.percentile(Sxx, [0 + trimperc, 100 - trimperc])
-    norm = Normalize(vmin=vmin, vmax=vmax)
+    if vmin is None:
+        vmin, vmax = np.percentile(Sxx, [0 + trimperc, 100 - trimperc])
+        norm = Normalize(vmin=vmin, vmax=vmax)
+    else:
+        norm = Normalize(vmin=vmin, vmax=vmax)

    if hypno is None:
        fig, ax = plt.subplots(nrows=1, figsize=(12, 4))
        im = ax.pcolormesh(t, f, Sxx, norm=norm, cmap=cmap, antialiased=True, shading="auto")
        ax.set_xlim(0, t.max())
-        ax.set_ylabel('Frequency [Hz]')
-        ax.set_xlabel('Time [hrs]')
+        ax.set_ylabel("Frequency [Hz]")
+        ax.set_xlabel("Time [hrs]")
        # Add colorbar
        cbar = fig.colorbar(im, ax=ax, shrink=0.95, fraction=0.1, aspect=25)
-        cbar.ax.set_ylabel('Log Power (dB / Hz)', rotation=270, labelpad=20)
+        cbar.ax.set_ylabel("Log Power (dB / Hz)", rotation=270, labelpad=20)
        return fig
    else:
        hypno = np.asarray(hypno).astype(int)
-        assert hypno.ndim == 1, 'Hypno must be 1D.'
-        assert hypno.size == data.size, 'Hypno must have the same sf as data.'
+        assert hypno.ndim == 1, "Hypno must be 1D."
+        assert hypno.size == data.size, "Hypno must have the same sf as data."
        t_hyp = np.arange(hypno.size) / (sf * 3600)
        # Make sure that REM is displayed after Wake
        hypno = pd.Series(hypno).map({-2: -2, -1: -1, 0: 0, 1: 2, 2: 3, 3: 4, 4: 1}).values
        hypno_rem = np.ma.masked_not_equal(hypno, 1)
        fig, (ax0, ax1) = plt.subplots(
-            nrows=2, figsize=(12, 6), gridspec_kw={'height_ratios': [1, 2]})
+            nrows=2, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 2]}
+        )
        plt.subplots_adjust(hspace=0.1)

        # Hypnogram (top axis)
-        ax0.step(t_hyp, -1 * hypno, color='k')
-        ax0.step(t_hyp, -1 * hypno_rem, color='r')
+        ax0.step(t_hyp, -1 * hypno, color="k")
+        ax0.step(t_hyp, -1 * hypno_rem, color="r")
        if -2 in hypno and -1 in hypno:
            # Both Unscored and Artefacts are present
            ax0.set_yticks([2, 1, 0, -1, -2, -3, -4])
-            ax0.set_yticklabels(['Uns', 'Art', 'W', 'R', 'N1', 'N2', 'N3'])
+            ax0.set_yticklabels(["Uns", "Art", "W", "R", "N1", "N2", "N3"])
            ax0.set_ylim(-4.5, 2.5)
        elif -2 in hypno and -1 not in hypno:
            # Only Unscored are present
            ax0.set_yticks([2, 0, -1, -2, -3, -4])
-            ax0.set_yticklabels(['Uns', 'W', 'R', 'N1', 'N2', 'N3'])
+            ax0.set_yticklabels(["Uns", "W", "R", "N1", "N2", "N3"])
            ax0.set_ylim(-4.5, 2.5)
        elif -2 not in hypno and -1 in hypno:
            # Only Artefacts are present
            ax0.set_yticks([1, 0, -1, -2, -3, -4])
-            ax0.set_yticklabels(['Art', 'W', 'R', 'N1', 'N2', 'N3'])
+            ax0.set_yticklabels(["Art", "W", "R", "N1", "N2", "N3"])
            ax0.set_ylim(-4.5, 1.5)
        else:
            # No artefacts or Unscored
            ax0.set_yticks([0, -1, -2, -3, -4])
-            ax0.set_yticklabels(['W', 'R', 'N1', 'N2', 'N3'])
+            ax0.set_yticklabels(["W", "R", "N1", "N2", "N3"])
            ax0.set_ylim(-4.5, 0.5)
        ax0.set_xlim(0, t_hyp.max())
-        ax0.set_ylabel('Stage')
+        ax0.set_ylabel("Stage")
        ax0.xaxis.set_visible(False)
-        ax0.spines['right'].set_visible(False)
-        ax0.spines['top'].set_visible(False)
+        ax0.spines["right"].set_visible(False)
+        ax0.spines["top"].set_visible(False)

        # Spectrogram (bottom axis)
        im = ax1.pcolormesh(t, f, Sxx, norm=norm, cmap=cmap, antialiased=True, shading="auto")
        ax1.set_xlim(0, t.max())
-        ax1.set_ylabel('Frequency [Hz]')
-        ax1.set_xlabel('Time [hrs]')
+        ax1.set_ylabel("Frequency [Hz]")
+        ax1.set_xlabel("Time [hrs]")

        # Revert font-size
-        plt.rcParams.update({'font.size': old_fontsize})
+        plt.rcParams.update({"font.size": old_fontsize})
        return fig
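The ``trimperc`` normalization used above simply clips the color scale at symmetric percentiles of the spectrogram, unless explicit ``vmin``/``vmax`` are given. A standalone illustration with stand-in data:

import numpy as np
from matplotlib.colors import Normalize

Sxx = np.random.default_rng(0).normal(size=(50, 120))  # stand-in spectrogram, in dB
trimperc = 2.5

# Color limits at the 2.5th and 97.5th percentiles of the values
vmin, vmax = np.percentile(Sxx, [0 + trimperc, 100 - trimperc])
norm = Normalize(vmin=vmin, vmax=vmax)  # passed to pcolormesh via norm=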
... cbar_title="Pearson correlation") """# Increase font size while preserving original
- old_fontsize=plt.rcParams['font.size']
- plt.rcParams.update({'font.size':fontsize})
- plt.rcParams.update({'savefig.bbox':'tight'})
- plt.rcParams.update({'savefig.transparent':'True'})
+ old_fontsize=plt.rcParams["font.size"]
+ plt.rcParams.update({"font.size":fontsize})
+ plt.rcParams.update({"savefig.bbox":"tight"})
+ plt.rcParams.update({"savefig.transparent":"True"})# Make sure we don't do any in-place modification
- assertisinstance(data,pd.Series),'Data must be a Pandas Series'
+ assertisinstance(data,pd.Series),"Data must be a Pandas Series"data=data.copy()# Add mask, if presentifmaskisnotNone:
- assertisinstance(mask,pd.Series),'mask must be a Pandas Series'
- assertmask.dtype.kindin'bi',"mask must be True/False or 0/1."
+ assertisinstance(mask,pd.Series),"mask must be a Pandas Series"
+ assertmask.dtype.kindin"bi","mask must be True/False or 0/1."else:mask=pd.Series(1,index=data.index,name="mask")
@@ -483,11 +524,11 @@
"""stats={}hypno=np.asarray(hypno)
- asserthypno.ndim==1,'hypno must have only one dimension.'
- asserthypno.size>1,'hypno must have at least two elements.'
+ asserthypno.ndim==1,"hypno must have only one dimension."
+ asserthypno.size>1,"hypno must have at least two elements."# TIB, first and last sleep
- stats['TIB']=len(hypno)
+ stats["TIB"]=len(hypno)first_sleep=np.where(hypno>0)[0][0]last_sleep=np.where(hypno>0)[0][-1]# Crop to SPT
- hypno_s=hypno[first_sleep:(last_sleep+1)]
- stats['SPT']=hypno_s.size
- stats['WASO']=hypno_s[hypno_s==0].size
+ hypno_s=hypno[first_sleep:(last_sleep+1)]
+ stats["SPT"]=hypno_s.size
+ stats["WASO"]=hypno_s[hypno_s==0].size# Before YASA v0.5.0, TST was calculated as SPT - WASO, meaning that Art# and Unscored epochs were included. TST is now restrained to sleep stages.
- stats['TST']=hypno_s[hypno_s>0].size
+ stats["TST"]=hypno_s[hypno_s>0].size# Duration of each sleep stages
- stats['N1']=hypno[hypno==1].size
- stats['N2']=hypno[hypno==2].size
- stats['N3']=hypno[hypno==3].size
- stats['REM']=hypno[hypno==4].size
- stats['NREM']=stats['N1']+stats['N2']+stats['N3']
+ stats["N1"]=hypno[hypno==1].size
+ stats["N2"]=hypno[hypno==2].size
+ stats["N3"]=hypno[hypno==3].size
+ stats["REM"]=hypno[hypno==4].size
+ stats["NREM"]=stats["N1"]+stats["N2"]+stats["N3"]# Sleep stage latencies -- only relevant if hypno is cropped to TIB
- stats['SOL']=first_sleep
- stats['Lat_N1']=np.where(hypno==1)[0].min()if1inhypnoelsenp.nan
- stats['Lat_N2']=np.where(hypno==2)[0].min()if2inhypnoelsenp.nan
- stats['Lat_N3']=np.where(hypno==3)[0].min()if3inhypnoelsenp.nan
- stats['Lat_REM']=np.where(hypno==4)[0].min()if4inhypnoelsenp.nan
+ stats["SOL"]=first_sleep
+ stats["Lat_N1"]=np.where(hypno==1)[0].min()if1inhypnoelsenp.nan
+ stats["Lat_N2"]=np.where(hypno==2)[0].min()if2inhypnoelsenp.nan
+ stats["Lat_N3"]=np.where(hypno==3)[0].min()if3inhypnoelsenp.nan
+ stats["Lat_REM"]=np.where(hypno==4)[0].min()if4inhypnoelsenp.nan# Convert to minutesforkey,valueinstats.items():stats[key]=value/(60*sf_hyp)# Percentage
- stats['%N1']=100*stats['N1']/stats['TST']
- stats['%N2']=100*stats['N2']/stats['TST']
- stats['%N3']=100*stats['N3']/stats['TST']
- stats['%REM']=100*stats['REM']/stats['TST']
- stats['%NREM']=100*stats['NREM']/stats['TST']
- stats['SE']=100*stats['TST']/stats['TIB']
- stats['SME']=100*stats['TST']/stats['SPT']
+ stats["%N1"]=100*stats["N1"]/stats["TST"]
+ stats["%N2"]=100*stats["N2"]/stats["TST"]
+ stats["%N3"]=100*stats["N3"]/stats["TST"]
+ stats["%REM"]=100*stats["REM"]/stats["TST"]
+ stats["%NREM"]=100*stats["NREM"]/stats["TST"]
+ stats["SE"]=100*stats["TST"]/stats["TIB"]
+ stats["SME"]=100*stats["TST"]/stats["SPT"]returnstats
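A worked toy example of the main quantities above, using one hypnogram value per 30 seconds (the numbers are illustrative only):

import numpy as np

sf_hyp = 1 / 30
hypno = np.array([0, 0, 1, 2, 2, 3, 0, 4, 4, 0])  # W W N1 N2 N2 N3 W R R W

tib = len(hypno)                               # time in bed = 10 epochs
first, last = np.where(hypno > 0)[0][[0, -1]]  # first/last sleep epoch: 2 and 8
spt = hypno[first:last + 1].size               # sleep period time = 7 epochs
waso = (hypno[first:last + 1] == 0).sum()      # wake after sleep onset = 1
tst = (hypno[first:last + 1] > 0).sum()        # total sleep time = 6

se = 100 * tst / tib   # sleep efficiency = 60.0%
sme = 100 * tst / spt  # sleep maintenance efficiency ~ 85.7%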
[docs] def bandpower(
+    data,
+    sf=None,
+    ch_names=None,
+    hypno=None,
+    include=(2, 3),
+    win_sec=4,
+    relative=True,
+    bandpass=False,
+    bands=[
+        (0.5, 4, "Delta"),
+        (4, 8, "Theta"),
+        (8, 12, "Alpha"),
+        (12, 16, "Sigma"),
+        (16, 30, "Beta"),
+        (30, 40, "Gamma"),
+    ],
+    kwargs_welch=dict(average="median", window="hamming"),
+):
    """
    Calculate the Welch bandpower for each channel and, if specified, for each sleep stage.
@@ -181,74 +197,85 @@
Source code for yasa.spectral
    https://github.com/raphaelvallat/yasa/blob/master/notebooks/08_bandpower.ipynb
    """
    # Type checks
-    assert isinstance(bands, list), 'bands must be a list of tuple(s)'
-    assert isinstance(relative, bool), 'relative must be a boolean'
-    assert isinstance(bandpass, bool), 'bandpass must be a boolean'
+    assert isinstance(bands, list), "bands must be a list of tuple(s)"
+    assert isinstance(relative, bool), "relative must be a boolean"
+    assert isinstance(bandpass, bool), "bandpass must be a boolean"

    # Check if input data is a MNE Raw object
    if isinstance(data, mne.io.BaseRaw):
-        sf = data.info['sfreq']  # Extract sampling frequency
+        sf = data.info["sfreq"]  # Extract sampling frequency
        ch_names = data.ch_names  # Extract channel names
-        data = data.get_data() * 1e6  # Convert from V to uV
+        data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
        _, npts = data.shape
    else:
        # Safety checks
-        assert isinstance(data, np.ndarray), 'Data must be a numpy array.'
+        assert isinstance(data, np.ndarray), "Data must be a numpy array."
        data = np.atleast_2d(data)
-        assert data.ndim == 2, 'Data must be of shape (nchan, n_samples).'
+        assert data.ndim == 2, "Data must be of shape (nchan, n_samples)."
        nchan, npts = data.shape
        # assert nchan < npts, 'Data must be of shape (nchan, n_samples).'
-        assert sf is not None, 'sf must be specified if passing a numpy array.'
+        assert sf is not None, "sf must be specified if passing a numpy array."
        assert isinstance(sf, (int, float))

    if ch_names is None:
-        ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)]
+        ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)]
    else:
        ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str))
-        assert ch_names.ndim == 1, 'ch_names must be 1D.'
-        assert len(ch_names) == nchan, 'ch_names must match data.shape[0].'
+        assert ch_names.ndim == 1, "ch_names must be 1D."
+        assert len(ch_names) == nchan, "ch_names must match data.shape[0]."

    if bandpass:
        # Apply FIR bandpass filter
        all_freqs = np.hstack([[b[0], b[1]] for b in bands])
        fmin, fmax = min(all_freqs), max(all_freqs)
-        data = mne.filter.filter_data(data.astype('float64'), sf, fmin, fmax, verbose=0)
+        data = mne.filter.filter_data(data.astype("float64"), sf, fmin, fmax, verbose=0)

    win = int(win_sec * sf)  # nperseg

    if hypno is None:
        # Calculate the PSD over the whole data
        freqs, psd = signal.welch(data, sf, nperseg=win, **kwargs_welch)
-        return bandpower_from_psd(
-            psd, freqs, ch_names, bands=bands, relative=relative).set_index('Chan')
+        return bandpower_from_psd(psd, freqs, ch_names, bands=bands, relative=relative).set_index(
+            "Chan"
+        )
    else:
        # Per each sleep stage defined in ``include``.
        hypno = np.asarray(hypno)
-        assert include is not None, 'include cannot be None if hypno is given'
+        assert include is not None, "include cannot be None if hypno is given"
        include = np.atleast_1d(np.asarray(include))
-        assert hypno.ndim == 1, 'Hypno must be a 1D array.'
-        assert hypno.size == npts, 'Hypno must have same size as data.shape[1]'
-        assert include.size >= 1, '`include` must have at least one element.'
-        assert hypno.dtype.kind == include.dtype.kind, 'hypno and include must have same dtype'
-        assert np.in1d(hypno, include).any(), (
-            'None of the stages specified in `include` are present in hypno.')
+        assert hypno.ndim == 1, "Hypno must be a 1D array."
+        assert hypno.size == npts, "Hypno must have same size as data.shape[1]"
+        assert include.size >= 1, "`include` must have at least one element."
+        assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype"
+        assert np.in1d(
+            hypno, include
+        ).any(), "None of the stages specified in `include` are present in hypno."
        # Initialize empty dataframe and loop over stages
        df_bp = pd.DataFrame([])
        for stage in include:
            if stage not in hypno:
                continue
            data_stage = data[:, hypno == stage]
-            freqs, psd = signal.welch(data_stage, sf, nperseg=win,
-                                      **kwargs_welch)
-            bp_stage = bandpower_from_psd(psd, freqs, ch_names, bands=bands,
-                                          relative=relative)
-            bp_stage['Stage'] = stage
+            freqs, psd = signal.welch(data_stage, sf, nperseg=win, **kwargs_welch)
+            bp_stage = bandpower_from_psd(psd, freqs, ch_names, bands=bands, relative=relative)
+            bp_stage["Stage"] = stage
            df_bp = pd.concat([df_bp, bp_stage], axis=0)
-        return df_bp.set_index(['Stage', 'Chan'])
+        return df_bp.set_index(["Stage", "Chan"])
[docs] def bandpower_from_psd(
+    psd,
+    freqs,
+    ch_names=None,
+    bands=[
+        (0.5, 4, "Delta"),
+        (4, 8, "Theta"),
+        (8, 12, "Alpha"),
+        (12, 16, "Sigma"),
+        (16, 30, "Beta"),
+        (30, 40, "Gamma"),
+    ],
+    relative=True,
+):
    """Compute the average power of the EEG in specified frequency band(s) given a pre-computed
    PSD.
@@ -277,27 +304,27 @@
Source code for yasa.spectral
        Bandpower dataframe, in which each row is a channel and each column a spectral band.
    """
    # Type checks
-    assert isinstance(bands, list), 'bands must be a list of tuple(s)'
-    assert isinstance(relative, bool), 'relative must be a boolean'
+    assert isinstance(bands, list), "bands must be a list of tuple(s)"
+    assert isinstance(relative, bool), "relative must be a boolean"

    # Safety checks
    freqs = np.asarray(freqs)
    assert freqs.ndim == 1
    psd = np.atleast_2d(psd)
-    assert psd.ndim == 2, 'PSD must be of shape (n_channels, n_freqs).'
+    assert psd.ndim == 2, "PSD must be of shape (n_channels, n_freqs)."

    all_freqs = np.hstack([[b[0], b[1]] for b in bands])
    fmin, fmax = min(all_freqs), max(all_freqs)
    idx_good_freq = np.logical_and(freqs >= fmin, freqs <= fmax)
    freqs = freqs[idx_good_freq]
    res = freqs[1] - freqs[0]

    nchan = psd.shape[0]
-    assert nchan < psd.shape[1], 'PSD must be of shape (n_channels, n_freqs).'
+    assert nchan < psd.shape[1], "PSD must be of shape (n_channels, n_freqs)."
    if ch_names is not None:
        ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str))
-        assert ch_names.ndim == 1, 'ch_names must be 1D.'
-        assert len(ch_names) == nchan, 'ch_names must match psd.shape[0].'
+        assert ch_names.ndim == 1, "ch_names must be 1D."
+        assert len(ch_names) == nchan, "ch_names must match psd.shape[0]."
    else:
-        ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)]
+        ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)]

    bp = np.zeros((nchan, len(bands)), dtype=np.float64)
    psd = psd[:, idx_good_freq]
    total_power = simps(psd, dx=res)
@@ -309,7 +336,8 @@
Source code for yasa.spectral
"There are negative values in PSD. This will result in incorrect ""bandpower values. We highly recommend working with an ""all-positive PSD. For more details, please refer to: "
- "https://github.com/raphaelvallat/yasa/issues/29")
+ "https://github.com/raphaelvallat/yasa/issues/29"
+ )logger.warning(msg)# Enumerate over the frequency bands
@@ -325,21 +353,30 @@
[docs] def bandpower_from_psd_ndarray(
+    psd,
+    freqs,
+    bands=[
+        (0.5, 4, "Delta"),
+        (4, 8, "Theta"),
+        (8, 12, "Alpha"),
+        (12, 16, "Sigma"),
+        (16, 30, "Beta"),
+        (30, 40, "Gamma"),
+    ],
+    relative=True,
+):
    """Compute bandpowers in N-dimensional PSD.

    This is a NumPy-only implementation of the :py:func:`yasa.bandpower_from_psd` function,
@@ -368,14 +405,14 @@
Source code for yasa.spectral
        Bandpower array of shape *(n_bands, ...)*.
    """
    # Type checks
-    assert isinstance(bands, list), 'bands must be a list of tuple(s)'
-    assert isinstance(relative, bool), 'relative must be a boolean'
+    assert isinstance(bands, list), "bands must be a list of tuple(s)"
+    assert isinstance(relative, bool), "relative must be a boolean"

    # Safety checks
    freqs = np.asarray(freqs)
    psd = np.asarray(psd)
-    assert freqs.ndim == 1, 'freqs must be a 1-D array of shape (n_freqs,)'
-    assert psd.shape[-1] == freqs.shape[-1], 'n_freqs must be last axis of psd'
+    assert freqs.ndim == 1, "freqs must be a 1-D array of shape (n_freqs,)"
+    assert psd.shape[-1] == freqs.shape[-1], "n_freqs must be last axis of psd"

    # Extract frequencies of interest
    all_freqs = np.hstack([[b[0], b[1]] for b in bands])
@@ -393,7 +430,8 @@
Source code for yasa.spectral
"There are negative values in PSD. This will result in incorrect ""bandpower values. We highly recommend working with an ""all-positive PSD. For more details, please refer to: "
- "https://github.com/raphaelvallat/yasa/issues/29")
+ "https://github.com/raphaelvallat/yasa/issues/29"
+ )logger.warning(msg)# Calculate total power
@@ -416,11 +454,35 @@
[docs] def irasa(
+ data,
+ sf=None,
+ ch_names=None,
+ band=(1,30),
+ hset=[
+ 1.1,
+ 1.15,
+ 1.2,
+ 1.25,
+ 1.3,
+ 1.35,
+ 1.4,
+ 1.45,
+ 1.5,
+ 1.55,
+ 1.6,
+ 1.65,
+ 1.7,
+ 1.75,
+ 1.8,
+ 1.85,
+ 1.9,
+ ],
+ return_fit=True,
+ win_sec=4,
+ kwargs_welch=dict(average="median",window="hamming"),
+ verbose=True,
+):r""" Separate the aperiodic (= fractal, or 1/f) and oscillatory component of the power spectra of EEG data using the IRASA method.
@@ -533,40 +595,41 @@
Source code for yasa.spectral
    [5] https://doi.org/10.1101/2021.10.15.464483
    """
    import fractions
+
    set_log_level(verbose)

    # Check if input data is a MNE Raw object
    if isinstance(data, mne.io.BaseRaw):
-        sf = data.info['sfreq']  # Extract sampling frequency
+        sf = data.info["sfreq"]  # Extract sampling frequency
        ch_names = data.ch_names  # Extract channel names
-        hp = data.info['highpass']  # Extract highpass filter
-        lp = data.info['lowpass']  # Extract lowpass filter
-        data = data.get_data() * 1e6  # Convert from V to uV
+        hp = data.info["highpass"]  # Extract highpass filter
+        lp = data.info["lowpass"]  # Extract lowpass filter
+        data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
    else:
        # Safety checks
-        assert isinstance(data, np.ndarray), 'Data must be a numpy array.'
+        assert isinstance(data, np.ndarray), "Data must be a numpy array."
        data = np.atleast_2d(data)
-        assert data.ndim == 2, 'Data must be of shape (nchan, n_samples).'
+        assert data.ndim == 2, "Data must be of shape (nchan, n_samples)."
        nchan, npts = data.shape
-        assert nchan < npts, 'Data must be of shape (nchan, n_samples).'
-        assert sf is not None, 'sf must be specified if passing a numpy array.'
+        assert nchan < npts, "Data must be of shape (nchan, n_samples)."
+        assert sf is not None, "sf must be specified if passing a numpy array."
        assert isinstance(sf, (int, float))
        if ch_names is None:
-            ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)]
+            ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)]
        else:
            ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str))
-            assert ch_names.ndim == 1, 'ch_names must be 1D.'
-            assert len(ch_names) == nchan, 'ch_names must match data.shape[0].'
+            assert ch_names.ndim == 1, "ch_names must be 1D."
+            assert len(ch_names) == nchan, "ch_names must match data.shape[0]."
        hp = 0  # Highpass filter unknown -> set to 0 Hz
        lp = sf / 2  # Lowpass filter unknown -> set to Nyquist

    # Check the other arguments
    hset = np.asarray(hset)
-    assert hset.ndim == 1, 'hset must be 1D.'
-    assert hset.size > 1, '2 or more resampling fators are required.'
+    assert hset.ndim == 1, "hset must be 1D."
+    assert hset.size > 1, "2 or more resampling factors are required."
    hset = np.round(hset, 4)  # avoid float precision error with np.arange
    band = sorted(band)
-    assert band[0] > 0, 'first element of band must be > 0.'
-    assert band[1] < (sf / 2), 'second element of band must be < (sf / 2).'
+    assert band[0] > 0, "first element of band must be > 0."
+    assert band[1] < (sf / 2), "second element of band must be < (sf / 2)."
    win = int(win_sec * sf)  # nperseg

    # Inform about maximum resampled fitting range
@@ -577,21 +640,27 @@
Source code for yasa.spectral
logging.info(f"Fitting range: {band[0]:.2f}Hz-{band[1]:.2f}Hz")logging.info(f"Evaluated frequency range: {band_evaluated[0]:.2f}Hz-{band_evaluated[1]:.2f}Hz")ifband_evaluated[0]<hp:
- logging.warning("The evaluated frequency range starts below the "
- f"highpass filter ({hp:.2f}Hz). Increase the lower band"
- f" ({band[0]:.2f}Hz) or decrease the maximum value of "
- f"the hset ({h_max:.2f}).")
+ logging.warning(
+ "The evaluated frequency range starts below the "
+ f"highpass filter ({hp:.2f}Hz). Increase the lower band"
+ f" ({band[0]:.2f}Hz) or decrease the maximum value of "
+ f"the hset ({h_max:.2f})."
+ )ifband_evaluated[1]>lpandlp<freq_Nyq_res:
- logging.warning("The evaluated frequency range ends after the "
- f"lowpass filter ({lp:.2f}Hz). Decrease the upper band"
- f" ({band[1]:.2f}Hz) or decrease the maximum value of "
- f"the hset ({h_max:.2f}).")
+ logging.warning(
+ "The evaluated frequency range ends after the "
+ f"lowpass filter ({lp:.2f}Hz). Decrease the upper band"
+ f" ({band[1]:.2f}Hz) or decrease the maximum value of "
+ f"the hset ({h_max:.2f})."
+ )ifband_evaluated[1]>freq_Nyq_res:
- logging.warning("The evaluated frequency range ends after the "
- "resampled Nyquist frequency "
- f"({freq_Nyq_res:.2f}Hz). Decrease the upper band "
- f"({band[1]:.2f}Hz) or decrease the maximum value "
- f"of the hset ({h_max:.2f}).")
+ logging.warning(
+ "The evaluated frequency range ends after the "
+ "resampled Nyquist frequency "
+ f"({freq_Nyq_res:.2f}Hz). Decrease the upper band "
+ f"({band[1]:.2f}Hz) or decrease the maximum value "
+ f"of the hset ({h_max:.2f})."
+ )# Calculate the original PSD over the whole datafreqs,psd=signal.welch(data,sf,nperseg=win,**kwargs_welch)
@@ -628,6 +697,7 @@
Source code for yasa.spectral
    if return_fit:
        # Aperiodic fit in semilog space for each channel
        from scipy.optimize import curve_fit
+
        intercepts, slopes, r_squared = [], [], []

        def func(t, a, b):
@@ -638,26 +708,31 @@
             y_log = np.log(y)
             # Note that here we define bounds for the slope but not for the
             # intercept.
-            popt, pcov = curve_fit(func, freqs, y_log, p0=(2, -1),
-                                   bounds=((-np.inf, -10), (np.inf, 2)))
+            popt, pcov = curve_fit(
+                func, freqs, y_log, p0=(2, -1), bounds=((-np.inf, -10), (np.inf, 2))
+            )
             intercepts.append(popt[0])
             slopes.append(popt[1])
             # Calculate R^2: https://stackoverflow.com/q/19189362/10581531
             residuals = y_log - func(freqs, *popt)
             ss_res = np.sum(residuals**2)
-            ss_tot = np.sum((y_log - np.mean(y_log))**2)
+            ss_tot = np.sum((y_log - np.mean(y_log)) ** 2)
             r_squared.append(1 - (ss_res / ss_tot))

         # Create fit parameters dataframe
-        fit_params = {'Chan': ch_names, 'Intercept': intercepts,
-                      'Slope': slopes, 'R^2': r_squared,
-                      'std(osc)': np.std(psd_osc, axis=-1, ddof=1)}
+        fit_params = {
+            "Chan": ch_names,
+            "Intercept": intercepts,
+            "Slope": slopes,
+            "R^2": r_squared,
+            "std(osc)": np.std(psd_osc, axis=-1, ddof=1),
+        }
         return freqs, psd_aperiodic, psd_osc, pd.DataFrame(fit_params)
     else:
         return freqs, psd_aperiodic, psd_osc
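For orientation, a minimal usage sketch of this IRASA implementation on synthetic data (the keyword names band, win_sec and return_fit follow the code above):

import numpy as np
import yasa

sf = 256
rng = np.random.RandomState(42)
data = rng.normal(size=(2, 60 * sf))  # 2 channels, 60 seconds of synthetic data

# Separate the aperiodic (1/f) and oscillatory components of the PSD
freqs, psd_aperiodic, psd_osc, fit_params = yasa.irasa(
    data, sf, band=(1, 30), win_sec=4, return_fit=True
)
print(fit_params)  # per-channel Intercept, Slope, R^2 and std(osc), as built above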
[docs]
def stft_power(data, sf, window=2, step=0.2, band=(1, 30), interp=True, norm=False):
    """Compute the pointwise power via STFT and interpolation.

    Parameters
@@ -712,7 +787,8 @@
     # Compute STFT and remove the last epoch
     f, t, Sxx = signal.stft(
-        data, sf, nperseg=nperseg, noverlap=noverlap, detrend=False, padded=True)
+        data, sf, nperseg=nperseg, noverlap=noverlap, detrend=False, padded=True
+    )

     # Let's keep only the frequency of interest
     if band is not None:
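A hedged usage sketch of stft_power, assuming it returns the frequency vector, time vector and (interpolated) power array as suggested by the docstring and the call to signal.stft above:

import numpy as np
import yasa

sf = 200
data = np.random.RandomState(0).normal(size=30 * sf)  # 30 seconds, single channel

f, t, Sxx = yasa.stft_power(data, sf, window=2, step=0.2, band=(1, 30), interp=True)
print(Sxx.shape)  # (n_freqs, n_times); with interp=True, n_times matches data.size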
diff --git a/docs/build/html/_modules/yasa/staging.html b/docs/build/html/_modules/yasa/staging.html
index 10a445b..9dcef5b 100644
--- a/docs/build/html/_modules/yasa/staging.html
+++ b/docs/build/html/_modules/yasa/staging.html
@@ -3,11 +3,14 @@
- yasa.staging — yasa 0.6.1 documentation
+ yasa.staging — yasa 0.6.2 documentation
        # Validate metadata
        if isinstance(metadata, dict):
-            if 'age' in metadata.keys():
-                assert 0 < metadata['age'] < 120, 'age must be between 0 and 120.'
-            if 'male' in metadata.keys():
-                metadata['male'] = int(metadata['male'])
-                assert metadata['male'] in [0, 1], 'male must be 0 or 1.'
+            if "age" in metadata.keys():
+                assert 0 < metadata["age"] < 120, "age must be between 0 and 120."
+            if "male" in metadata.keys():
+                metadata["male"] = int(metadata["male"])
+                assert metadata["male"] in [0, 1], "male must be 0 or 1."

        # Validate Raw instance and load data
-        assert isinstance(raw, mne.io.BaseRaw), 'raw must be a MNE Raw object.'
-        sf = raw.info['sfreq']
+        assert isinstance(raw, mne.io.BaseRaw), "raw must be a MNE Raw object."
+        sf = raw.info["sfreq"]
         ch_names = np.array([eeg_name, eog_name, emg_name])
-        ch_types = np.array(['eeg', 'eog', 'emg'])
+        ch_types = np.array(["eeg", "eog", "emg"])
         keep_chan = []
         for c in ch_names:
             if c is not None:
-                assert c in raw.ch_names, '%s does not exist' % c
+                assert c in raw.ch_names, "%s does not exist" % c
                 keep_chan.append(True)
             else:
                 keep_chan.append(False)
@@ -284,17 +287,17 @@
        raw_pick = raw.copy().pick_channels(ch_names, ordered=True)

        # Downsample if sf != 100
-        assert sf > 80, 'Sampling frequency must be at least 80 Hz.'
+        assert sf > 80, "Sampling frequency must be at least 80 Hz."
        if sf != 100:
            raw_pick.resample(100, npad="auto")
-            sf = raw_pick.info['sfreq']
+            sf = raw_pick.info["sfreq"]

        # Get data and convert to microVolts
-        data = raw_pick.get_data() * 1e6
+        data = raw_pick.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))

        # Extract duration of recording in minutes
        duration_minutes = data.shape[1] / sf / 60
-        assert duration_minutes >= 5, 'At least 5 minutes of data is required.'
+        assert duration_minutes >= 5, "At least 5 minutes of data is required."

        # Add to self
        self.sf = sf
@@ -320,10 +323,14 @@
        # Save features to dataframe
        features = pd.concat(features, axis=1)
-        features.index.name = 'epoch'
+        features.index.name = "epoch"

        # Apply centered rolling average (15 epochs = 7 min 30)
        # Triang: [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.,
        #          0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125]
-        rollc = features.rolling(
-            window=15, center=True, min_periods=1, win_type='triang').mean()
+        rollc = features.rolling(window=15, center=True, min_periods=1, win_type="triang").mean()
        rollc[rollc.columns] = robust_scale(rollc, quantile_range=(5, 95))
-        rollc = rollc.add_suffix('_c7min_norm')
+        rollc = rollc.add_suffix("_c7min_norm")

        # Now look at the past 2 minutes
        rollp = features.rolling(window=4, min_periods=1).mean()
        rollp[rollp.columns] = robust_scale(rollp, quantile_range=(5, 95))
-        rollp = rollp.add_suffix('_p2min_norm')
+        rollp = rollp.add_suffix("_p2min_norm")

        # Add to current set of features
        features = features.join(rollc).join(rollp)
@@ -413,8 +418,8 @@
        #######################################################################
        # Add temporal features
-        features['time_hour'] = times / 3600
-        features['time_norm'] = times / times[-1]
+        features["time_hour"] = times / 3600
+        features["time_norm"] = times / times[-1]

        # Add metadata if present
        if self.metadata is not None:
@@ -425,10 +430,10 @@
        cols_float = features.select_dtypes(np.float64).columns.tolist()
        features[cols_float] = features[cols_float].astype(np.float32)
        # Make sure that age and sex are encoded as int
-        if 'age' in features.columns:
-            features['age'] = features['age'].astype(int)
-        if 'male' in features.columns:
-            features['male'] = features['male'].astype(int)
+        if "age" in features.columns:
+            features["age"] = features["age"].astype(int)
+        if "male" in features.columns:
+            features["male"] = features["male"].astype(int)

        # Sort the column names here (same behavior as lightGBM)
        features.sort_index(axis=1, inplace=True)
@@ -445,7 +450,7 @@
        features : :py:class:`pandas.DataFrame`
            Feature dataframe.
        """
-        if not hasattr(self, '_features'):
+        if not hasattr(self, "_features"):
            self.fit()
        return self._features.copy()
@@ -455,22 +460,32 @@
        # Note that clf.feature_name_ is only available in lightgbm>=3.0
        f_diff = np.setdiff1d(clf.feature_name_, self.feature_name_)
        if len(f_diff):
-            raise ValueError("The following features are present in the "
-                             "classifier but not in the current features set:", f_diff)
-        f_diff = np.setdiff1d(self.feature_name_, clf.feature_name_,)
+            raise ValueError(
+                "The following features are present in the "
+                "classifier but not in the current features set:",
+                f_diff,
+            )
+        f_diff = np.setdiff1d(
+            self.feature_name_,
+            clf.feature_name_,
+        )
        if len(f_diff):
-            raise ValueError("The following features are present in the "
-                             "current feature set but not in the classifier:", f_diff)
+            raise ValueError(
+                "The following features are present in the "
+                "current feature set but not in the classifier:",
+                f_diff,
+            )

    def _load_model(self, path_to_model):
        """Load the relevant trained classifier."""
        if path_to_model == "auto":
            from pathlib import Path
-            clf_dir = os.path.join(str(Path(__file__).parent), 'classifiers/')
-            name = 'clf_eeg'
-            name = name + '+eog' if 'eog' in self.ch_types else name
-            name = name + '+emg' if 'emg' in self.ch_types else name
-            name = name + '+demo' if self.metadata is not None else name
+
+            clf_dir = os.path.join(str(Path(__file__).parent), "classifiers/")
+            name = "clf_eeg"
+            name = name + "+eog" if "eog" in self.ch_types else name
+            name = name + "+emg" if "emg" in self.ch_types else name
+            name = name + "+demo" if self.metadata is not None else name
            # e.g. clf_eeg+eog+emg+demo_lgb_0.4.0.joblib
            all_matching_files = glob.glob(clf_dir + name + "*.joblib")
            # Find the latest file
@@ -503,7 +518,7 @@
        pred : :py:class:`numpy.ndarray`
            The predicted sleep stages.
        """
-        if not hasattr(self, '_features'):
+        if not hasattr(self, "_features"):
            self.fit()
        # Load and validate pre-trained classifier
        clf = self._load_model(path_to_model)
@@ -512,7 +527,7 @@
        # Predict the sleep stages and probabilities
        self._predicted = clf.predict(X)
        proba = pd.DataFrame(clf.predict_proba(X), columns=clf.classes_)
-        proba.index.name = 'epoch'
+        proba.index.name = "epoch"
        self._proba = proba
        return self._predicted.copy()
@@ -535,13 +550,16 @@
        proba : :py:class:`pandas.DataFrame`
            The predicted probability for each sleep stage for each 30-sec epoch of data.
        """
-        if not hasattr(self, '_proba'):
+        if not hasattr(self, "_proba"):
            self.predict(path_to_model)
        return self._proba.copy()
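Putting the class together, a minimal end-to-end sketch (the file and channel names are hypothetical and depend on your montage):

import mne
import yasa

raw = mne.io.read_raw_edf("night.edf", preload=True)  # hypothetical recording
sls = yasa.SleepStaging(
    raw,
    eeg_name="C4-A1",          # central EEG (required)
    eog_name="EOG1-A2",        # EOG (optional)
    emg_name="EMG1-EMG2",      # chin EMG (optional)
    metadata=dict(age=25, male=1),
)
hypno_pred = sls.predict()     # one stage label per 30-sec epoch
proba = sls.predict_proba()    # per-epoch probability of each stage
sls.plot_predict_proba()       # stacked area plot of the probabilities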
[docs]
    def plot_predict_proba(
+        self,
+        proba=None,
+        majority_only=False,
+        palette=["#99d7f1", "#009DDC", "xkcd:twilight blue", "xkcd:rich purple", "xkcd:sunflower"],
+    ):
        """
        Plot the predicted probability for each sleep stage for each 30-sec epoch of data.
@@ -552,16 +570,16 @@
Added the yasa.hypno_find_periods() function to find sequences of consecutive values in hypnogram that are longer than a certain duration. This is a flexible function that can be used to detect NREM/REM periods.
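A hedged sketch of yasa.hypno_find_periods (the threshold string format and the output columns are assumptions based on the public docs):

import numpy as np
import yasa

# Hypnogram with one value per 30-sec epoch (0=Wake, 2=N2, 4=REM)
hypno = np.array([0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4, 0, 0])
sf_hypno = 1 / 30  # one value every 30 seconds

# All sequences of consecutive identical values lasting at least 2 minutes
periods = yasa.hypno_find_periods(hypno, sf_hypno, threshold="2min")
print(periods)  # DataFrame with the value, start and length of each period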
Added the yasa.hrv_stage() function, which calculates heart rate (HR) and heart rate variability (HRV) by stage and periods.
Added a new dataset containing 8 hours of ECG data. The dataset is in compressed NumPy format and can be found in notebooks/data_ECG_8hrs_200Hz.npz. The dataset also includes an upsampled hypnogram.
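A heavily hedged sketch combining the two entries above (the .npz key names and the hrv_stage keyword are assumptions; check the dataset and the API docs):

import numpy as np
import yasa

f = np.load("notebooks/data_ECG_8hrs_200Hz.npz")
ecg, hypno = f["data"], f["hypno"]  # hypothetical key names
sf = 200

# HR and HRV summarized per sleep stage and period
hrv = yasa.hrv_stage(ecg, sf, hypno=hypno)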
Added the yasa.compare_detection() function to determine the correctness of detected events against ground-truth events. It calculates the true positive, false positives and false negatives, and from those, the precision, recall and F1-scores. The input should be the indices of the onset of the event, in samples. It includes a max_distance argument which specifies the tolerance window (in number of samples) for two events to be considered the same.
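A minimal sketch of yasa.compare_detection (max_distance is documented above; the exact return structure is described in the API docs):

import yasa

detected = [100, 220, 500, 960]  # onsets of detected events, in samples
truth = [102, 510, 800, 962]     # onsets of ground-truth events, in samples

# Events within 10 samples of each other are considered the same event
res = yasa.compare_detection(detected, truth, max_distance=10)
print(res)  # true positives, false positives, false negatives, precision, recall, F1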
Added the yasa.SpindlesResults.compare_detection() and yasa.SWResults.compare_detection() method. This is a powerful and flexible function that allows to calculate the performance of the current detection against a) another detection or b) ground-truth annotations. For example, we can compare the output of the spindles detection with different thresholds.
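A hedged sketch of the method form, e.g. comparing a strict against a lenient spindle detection (data and sf stand for real sleep EEG; the threshold values are arbitrary, and max_distance is assumed to behave as in yasa.compare_detection):

sp_strict = yasa.spindles_detect(data, sf, thresh={"rel_pow": 0.25, "corr": 0.70, "rms": 2.0})
sp_lenient = yasa.spindles_detect(data, sf, thresh={"rel_pow": 0.10, "corr": 0.60, "rms": 1.0})

# Performance of the strict detection against the lenient one (0.5-sec tolerance)
perf = sp_strict.compare_detection(sp_lenient, max_distance=int(0.5 * sf))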
Better handling of flat data in yasa.spindles_detect(). The function previously returned a division by zero error if part of the data was flat. See issue 85
When using an MNE.Raw object, conversion of the data from Volts to micro-Volts is now performed within MNE. PR 70
This release fixes a CRITICAL BUG with the spindles detection. Specifically, yasa.spindles_detect() could return different results depending on the sampling rate of the data.
For example, downsampling the data from 256 Hz to 128 Hz may have significantly reduced the number of detected spindles. As explained in issue 54, this bug was caused by a floating-point error
in numpy.convolve() when calculating the soft spindle threshold. Tests seem to indicate that only certain sampling frequencies were impacted, such as 200 Hz, 256 Hz or 400 Hz. Other sampling frequencies, such as 100 Hz and 500 Hz, were seemingly not affected by this bug. Please double-check any results obtained with yasa.spindles_detect()!
Warning
We recommend that all users upgrade to this new version ASAP and check any results obtained with the yasa.spindles_detect() function!
This is a bugfix release. The latest pre-trained classifiers for yasa.SleepStaging were accidentally missing from the previous release. They have now been included in this release.
This is a major release with an important bugfix for the slow-waves detection as well as API-breaking changes in the automatic sleep staging module. We recommend that all users upgrade to this version with pip install --upgrade yasa.
Slow-waves detection
We have fixed a critical bug in yasa.sw_detect() in which the detection could keep slow-waves with invalid duration (e.g. several tens of seconds). We have now added extra safety checks to make sure that the total duration of the slow-waves does not exceed the maximum duration allowed by the dur_neg and dur_pos parameters (default = 2.5 seconds).
This is a major release with several new functions, the biggest of which is the addition of an automatic sleep staging module (yasa.SleepStaging). This means that YASA can now automatically score the sleep stages of your raw EEG data. The classifier was trained and validated on more than 3000 nights from the National Sleep Research Resource (NSRR) website.
Briefly, the algorithm works by calculating a set of features for each 30-sec epoch from a central EEG channel (required), as well as an EOG channel (optional) and an EMG channel (optional). For best performance, users can also specify the age and the sex of the participants. Pre-trained classifiers are already included in YASA. The automatic sleep staging algorithm requires the LightGBM and antropy packages.
This is a major release with several API-breaking changes in the spindles, slow-waves and REMs detection.
First, the yasa.spindles_detect_multi() and yasa.sw_detect_multi() have been removed. Instead, the yasa.spindles_detect() and yasa.sw_detect() functions can now handle both single and multi-channel data.
Second, I was getting some feedback that it was difficult to get summary statistics from the detection dataframe. For instance, how can you get the average duration of the detected spindles, per channel and/or per stage? Similarly, how can you get the slow-waves count and density per stage and channel? To address these issues, I’ve now modified the output of the yasa.spindles_detect(), yasa.sw_detect() and yasa.rem_detect() functions, which is now a class (= object) and not a simple Pandas DataFrame. The advantage is that the new output allows you to quickly get the raw data or summary statistics grouped by channel and/or sleep stage using the .summary() method.
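A hedged sketch of the new results API (the grouping keyword names are assumptions based on the public docs; data, sf and hypno stand for real sleep EEG and its upsampled hypnogram):

sw = yasa.sw_detect(data, sf, hypno=hypno)
sw.summary()                               # one row per detected slow-wave
sw.summary(grp_chan=True, grp_stage=True)  # count, density and mean features per channel/stage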
The coupling argument has been removed from the yasa.spindles_detect() function. Instead, slow-oscillations / sigma coupling can only be calculated from the slow-waves detection, which is 1) the most standard way, and 2) better because PAC assumptions require a strong oscillatory component in the lower frequency range (slow-oscillations). This also avoids unnecessary confusion between spindles-derived coupling and slow-waves-derived coupling. For more details, refer to the Jupyter notebooks.
Downsampling of data in detection functions has been removed. In other words, YASA will no longer downsample the data to 100 / 128 Hz before applying the events detection. If the detection is too slow, we recommend that you manually downsample your data before applying the detection. See for example mne.filter.resample().
yasa.trimbothstd() can now work with multi-dimensional arrays. The trimmed standard deviation will always be calculated on the last axis of the array.
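A minimal sketch of the multi-dimensional behavior (the cut parameter name is an assumption):

import numpy as np
import yasa

x = np.random.RandomState(123).normal(size=(4, 1000))  # e.g. (n_epochs, n_samples)
sd = yasa.trimbothstd(x, cut=0.025)  # trimmed std along the last axis
print(sd.shape)  # (4,)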
Filtering and Hilbert transform are now applied at once on all channels (instead of looping across individual channels) in the yasa.spindles_detect() and yasa.sw_detect() functions. This should lead to some improvements in computation time.
This is a major release with several new functions, bugfixes and miscellaneous enhancements in existing functions.
Bugfixes

Sleep efficiency in the yasa.sleep_statistics() is now calculated using time in bed (TIB) as the denominator instead of sleep period time (SPT), in agreement with the AASM guidelines. The old way of computing the efficiency (TST / SPT) has now been renamed Sleep Maintenance Efficiency (SME).

The yasa.sliding_window() now always returns an array of shape (n_epochs, …, n_samples), i.e. the epochs are now always the first dimension of the epoched array. This is consistent with the default shape of mne.Epochs objects. A hedged sketch follows this entry.
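A minimal sketch of the new output shape, assuming yasa.sliding_window returns a (times, epoched data) tuple:

import numpy as np
import yasa

sf = 100
data = np.random.RandomState(0).normal(size=(2, 300 * sf))  # (n_chan, n_samples), 5 minutes

# 30-sec non-overlapping epochs; the epochs are the first output dimension
times, epochs = yasa.sliding_window(data, sf, window=30)
print(epochs.shape)  # (10, 2, 3000) -> (n_epochs, n_chan, n_samples)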
New functions

Added yasa.art_detect() to automatically detect artefacts on single or multi-channel EEG data.

Added yasa.bandpower_from_psd_ndarray() to calculate band power from a multi-dimensional PSD. This is a NumPy-only implementation and this function will return a np.array and not a pandas DataFrame. This function is useful if you need to calculate the bandpower from a 3-D PSD array, e.g. of shape (n_epochs, n_chan, n_freqs). A hedged sketch follows this list.

Added yasa.get_centered_indices() to extract indices in data centered around specific events or peaks.
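A hedged sketch of the 3-D PSD workflow described above (the band-first output axis is an assumption):

import numpy as np
import yasa
from scipy import signal

sf = 100
rng = np.random.RandomState(42)
data = rng.normal(size=(20, 3, 30 * sf))  # (n_epochs, n_chan, n_samples)

# Welch PSD along the last axis -> psd has shape (n_epochs, n_chan, n_freqs)
freqs, psd = signal.welch(data, sf, nperseg=4 * sf, axis=-1)
bandpower = yasa.bandpower_from_psd_ndarray(psd, freqs)
print(bandpower.shape)  # assumed (n_bands, n_epochs, n_chan)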
yasa.sleep_statistics() now also returns the sleep onset latency, i.e. the latency to the first epoch of any sleep. A hedged sketch follows this list.

Added the bandpass argument to yasa.bandpower() to apply a FIR bandpass filter using the lowest and highest frequencies defined in bands. This is useful if you work with absolute power and want to remove contributions from frequency bands of non-interest.

The yasa.bandpower_from_psd() now always returns the total absolute physical power (TotalAbsPow) of the signal, in units of uV^2 / Hz. This allows one to quickly calculate the absolute bandpower from the relative bandpower.

Added the coupling and freq_sp keyword-arguments to the yasa.sw_detect() function. If coupling=True, the function will return the phase of the slow-waves (in radians) at the most prominent peak of the sigma-filtered band (PhaseAtSigmaPeak), as well as the normalized mean vector length (ndPAC).

Added a section in the 06_sw_detection.ipynb notebook on how to use relative amplitude thresholds (e.g. z-scores or percentiles) instead of absolute thresholds in the slow-waves detection.

The upper frequency band for yasa.sw_detect() has been changed from freq_sw=(0.3, 3.5) to freq_sw=(0.3, 2) Hz to comply with AASM guidelines.

Added the verbose parameter to all detection functions.

Added -2 to the default hypnogram format to denote unscored data.
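A minimal sketch of the sleep statistics, assuming the stage coding and the "SOL" output key used in the YASA docs:

import numpy as np
import yasa

# One value per 30-sec epoch: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM
hypno = np.array([0, 0, 1, 2, 2, 3, 3, 2, 4, 4, 0])
stats = yasa.sleep_statistics(hypno, sf_hyp=1 / 30)
print(stats["SOL"])  # sleep onset latency in minutes; 1.0 here (two epochs of wake)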
Added the coupling and freq_so keyword-arguments to the yasa.spindles_detect() function. If coupling=True, the function will also return the phase of the slow-waves (in radians) at the most prominent peak of the spindles. This can be used to perform spindles-SO coupling, as explained in the new Jupyter notebooks on PAC and spindles-SO coupling.
Enhancements

It is now possible to disable one or two out of the three thresholds in the yasa.spindles_detect(). This allows users to run a simpler detection (for example, focusing exclusively on the moving root mean square signal). A hedged sketch follows this list.

The yasa.spindles_detect() now returns the timing (in seconds) of the most prominent peak of each spindle (Peak).

The yasa.get_sync_sw has been renamed to yasa.get_sync_events() and is now compatible with the spindles detection. This can be used, for instance, to plot the peak-locked grand averaged spindle.
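A hedged sketch of disabling thresholds (the thresh dict keys are taken from the public docs; real N2/N3 EEG should replace the placeholder noise):

import numpy as np
import yasa

sf = 200
data = np.random.RandomState(7).normal(size=60 * sf)  # placeholder; use real sleep EEG in uV

# Keep only the moving RMS threshold; disable relative power and correlation
sp = yasa.spindles_detect(data, sf, thresh={"rel_pow": None, "corr": None, "rms": 3})
# Note: the function returns None (with a warning) if no spindle is detected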
Code testing

Removed Travis and AppVeyor testing for Python 3.5.
Added bandpower function

One can now directly pass a raw MNE object in several multi-channel functions of YASA, instead of manually passing data, sf, and ch_names. YASA will automatically convert MNE data from Volts to uV and extract the sampling frequency and channel names. Examples of this can be found in the Jupyter notebooks examples.
Added REM detection (rem_detect) on LOC and ROC EOG channels + example notebook.

Added yasa/hypno.py file, with several functions to load and upsample sleep stage vector (hypnogram).

Added yasa/spectral.py file, which includes the bandpower_from_psd function to calculate the single or multi-channel spectral power in specified bands from a pre-computed PSD (see example notebook at notebooks/10_bandpower.ipynb).
Added get_sync_sw function to get the synchronized timings of landmark timepoints in slow-wave sleep. This can be used in combination with seaborn.lineplot to plot an average template of the detected slow-wave, per channel.