diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 2c8855d..0000000
Binary files a/.DS_Store and /dev/null differ
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 6da4e34..12e3474 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -13,7 +13,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest", "windows-latest"]
+        os: ["ubuntu-latest"]
         python-version: ["3.9", "3.10"]
     steps:
       - name: Checkout ibl-neuropixel repo
diff --git a/release_notes.md b/release_notes.md
index f03c5d6..1a802e4 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -1,3 +1,13 @@
+# 1.4
+
+## 1.4.0 2024-10-05
+- Waveform extraction:
+  - Optimization of the waveform extractor, outputs flattened waveforms
+  - Refactoring of the waveform loader with backward compatibility
+- Bad channel detector:
+  - The bad channel detector has a plot option to visualize the bad channels and thresholds
+  - The default low-cut filters are set to 300 Hz for the AP band and 2 Hz for the LF band
+
 # 1.3
 
 ## 1.3.2 2024-09-18
diff --git a/setup.py b/setup.py
index 70a90ed..34daece 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 setuptools.setup(
     name="ibl-neuropixel",
-    version="1.3.2",
+    version="1.4.0",
     author="The International Brain Laboratory",
     description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
     long_description=long_description,
diff --git a/src/.DS_Store b/src/.DS_Store
deleted file mode 100644
index b21c9d1..0000000
Binary files a/src/.DS_Store and /dev/null differ
diff --git a/src/ibldsp/plots.py b/src/ibldsp/plots.py
index 9ab2487..4434b6a 100644
--- a/src/ibldsp/plots.py
+++ b/src/ibldsp/plots.py
@@ -1,9 +1,8 @@
 import numpy as np
 import matplotlib.pyplot as plt
-import scipy.signal
 
 
-def show_channels_labels(raw, fs, channel_labels, xfeats):
+def show_channels_labels(raw, fs, channel_labels, xfeats, similarity_threshold, psd_hf_threshold=0.02):
     """
     Shows the features side by side a snippet of raw data
     :param sr:
@@ -11,26 +10,24 @@ def show_channels_labels(raw, fs, channel_labels, xfeats):
     """
     nc, ns = raw.shape
     ns_plot = np.minimum(ns, 3000)
-    vaxis_uv = 75
-    sos_hp = scipy.signal.butter(**{"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}, output="sos")
-    butt = scipy.signal.sosfiltfilt(sos_hp, raw)
+    vaxis_uv = 250 if fs < 2600 else 75
     fig, ax = plt.subplots(1, 5, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 1, 1, 8, .2]})
     ax[0].plot(xfeats['xcor_hf'], np.arange(nc))
-    ax[0].plot(xfeats['xcor_hf'][(iko := channel_labels == 1)], np.arange(nc)[iko], 'r*')
-    ax[0].plot([- .5, -.5], [0, nc], 'r--')
+    ax[0].plot(xfeats['xcor_hf'][(iko := channel_labels == 1)], np.arange(nc)[iko], 'k*')
+    ax[0].plot(similarity_threshold[0] * np.ones(2), [0, nc], 'k--')
+    ax[0].plot(similarity_threshold[1] * np.ones(2), [0, nc], 'r--')
     ax[0].set(ylabel='channel #', xlabel='high coherence', ylim=[0, nc], title='a) dead channel')
     ax[1].plot(xfeats['psd_hf'], np.arange(nc))
     ax[1].plot(xfeats['psd_hf'][(iko := channel_labels == 2)], np.arange(nc)[iko], 'r*')
-    ax[1].plot([.02, .02], [0, nc], 'r--')
-
+    ax[1].plot(psd_hf_threshold * np.array([1, 1]), [0, nc], 'r--')
     ax[1].set(yticklabels=[], xlabel='PSD', ylim=[0, nc], title='b) noisy channel')
     ax[1].sharey(ax[0])
     ax[2].plot(xfeats['xcor_lf'], np.arange(nc))
-    ax[2].plot(xfeats['xcor_lf'][(iko := channel_labels == 3)], np.arange(nc)[iko], 'r*')
-    ax[2].plot([-.75, -.75], [0, nc], 'r--')
-    ax[2].set(yticklabels=[], xlabel='low coherence', ylim=[0, nc], title='c) outside')
+    ax[2].plot(xfeats['xcor_lf'][(iko := channel_labels == 3)], np.arange(nc)[iko], 'y*')
+    ax[2].plot([-.75, -.75], [0, nc], 'y--')
+    ax[2].set(yticklabels=[], xlabel='LF coherence', ylim=[0, nc], title='c) outside')
     ax[2].sharey(ax[0])
-    im = ax[3].imshow(butt[:, :ns_plot] * 1e6, origin='lower', cmap='PuOr', aspect='auto',
+    im = ax[3].imshow(raw[:, :ns_plot] * 1e6, origin='lower', cmap='PuOr', aspect='auto',
                       vmin=-vaxis_uv, vmax=vaxis_uv, extent=[0, ns_plot / fs * 1e3, 0, nc])
     ax[3].set(yticklabels=[], title='d) Raw data', xlabel='time (ms)', ylim=[0, nc])
     ax[3].grid(False)
diff --git a/src/ibldsp/utils.py b/src/ibldsp/utils.py
index 1b96b6b..6ad9719 100644
--- a/src/ibldsp/utils.py
+++ b/src/ibldsp/utils.py
@@ -267,7 +267,39 @@ def __init__(self, ns, nswin, overlap):
         self.iw = None
 
     @property
-    def firstlast(self):
+    def firstlast_splicing(self):
+        """
+        Generator that yields the indices as well as an amplitude function that can be used
+        to splice the windows together.
+        In the overlap, the amplitude function gradually transitions the amplitude from one window
+        to the next. The amplitudes always sum to one (i.e. the windows are symmetric).
+
+        :return: tuple of (first_index, last_index, amplitude_vector)
+        """
+        w = scipy.signal.windows.hann((self.overlap + 1) * 2 + 1, sym=True)[1:self.overlap + 1]
+        assert np.all(np.isclose(w + np.flipud(w), 1))
+
+        for first, last in self.firstlast:
+            amp = np.ones(last - first)
+            amp[:self.overlap] = 1 if first == 0 else w
+            amp[-self.overlap:] = 1 if last == self.ns else np.flipud(w)
+            yield (first, last, amp)
+
+    @property
+    def firstlast_valid(self):
+        """
+        Generator that yields a tuple of first, last, first_valid, last_valid index of windows.
+        The valid indices span up to half of the overlap
+        :return:
+        """
+        assert self.overlap % 2 == 0, "Overlap must be even"
+        for first, last in self.firstlast:
+            first_valid = 0 if first == 0 else first + self.overlap // 2
+            last_valid = last if last == self.ns else last - self.overlap // 2
+            yield (first, last, first_valid, last_valid)
+
+    @property
+    def firstlast(self):
         """
         Generator that yields first and last index of windows
diff --git a/src/ibldsp/voltage.py b/src/ibldsp/voltage.py
index e59496d..82b04d6 100644
--- a/src/ibldsp/voltage.py
+++ b/src/ibldsp/voltage.py
@@ -142,40 +142,28 @@ def fk(
     return xf * gain
 
 
-def car(x, collection=None, lagc=300, butter_kwargs=None, **kwargs):
+def car(x, collection=None, operator='median', **kwargs):
     """
-    Applies common average referencing with optional automatic gain control
-    :param x: the input array to be filtered. dimension, the filtering is considering
+    Applies common average referencing
+    :param x: np.array(nc, ns) the input array to be de-referenced; the referencing considers
         axis=0: spatial dimension, axis=1 temporal dimension. (ntraces, ns)
-    :param collection:
-    :param lagc: window size for time domain automatic gain control (no agc otherwise)
-    :param butter_kwargs: filtering parameters: defaults: {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
+    :param collection: vector length ntraces. Each unique value set of traces is a collection and will be handled
+        separately. Useful for shanks.
+    :param operator: 'median' or 'average'
     :return:
     """
-    if butter_kwargs is None:
-        butter_kwargs = {"N": 3, "Wn": 0.1, "btype": "highpass"}
     if collection is not None:
         xout = np.zeros_like(x)
         for c in np.unique(collection):
             sel = collection == c
-            xout[sel, :] = kfilt(
-                x=x[sel, :],
-                ntr_pad=0,
-                ntr_tap=None,
-                collection=None,
-                butter_kwargs=butter_kwargs,
-            )
+            xout[sel, :] = car(x=x[sel, :], collection=None, **kwargs)
         return xout
-    # apply agc and keep the gain in handy
-    if not lagc:
-        xf = np.copy(x)
-        gain = 1
-    else:
-        xf, gain = agc(x, wl=lagc, si=1.0)
-    # apply CAR and then un-apply the gain
-    xf = xf - np.median(xf, axis=0)
-    return xf * gain
+    if operator == 'median':
+        x = x - np.median(x, axis=0)
+    elif operator == 'average':
+        x = x - np.mean(x, axis=0)
+    return x
 
 
 def kfilt(
@@ -390,21 +378,18 @@ def destripe(
     return x
 
 
-def destripe_lfp(x, fs, channel_labels=None, **kwargs):
+def destripe_lfp(x, fs, channel_labels=None, butter_kwargs=None, k_filter=False):
     """
-    Wrapper around the destipe function with some default parameters to destripe the LFP band
+    Wrapper around the destripe function with some default parameters to destripe the LFP band
     See help destripe function for documentation
-    :param x:
-    :param fs:
-    :return:
+    :param x: demultiplexed array (nc, ns)
+    :param fs: sampling frequency
+    :param channel_labels: see destripe
     """
-    kwargs["butter_kwargs"] = {"N": 3, "Wn": 2 / fs * 2, "btype": "highpass"}
-    kwargs["k_filter"] = False
+    butter_kwargs = {"N": 3, "Wn": [0.5, 300], "btype": "bandpass", "fs": fs} if butter_kwargs is None else butter_kwargs
     if channel_labels is True:
-        kwargs["channel_labels"], _ = detect_bad_channels(
-            x, fs=fs, psd_hf_threshold=1.4
-        )
-    return destripe(x, fs, **kwargs)
+        channel_labels, _ = detect_bad_channels(x, fs=fs, psd_hf_threshold=1.4)
+    return destripe(x, fs, butter_kwargs=butter_kwargs, k_filter=k_filter, channel_labels=channel_labels)
 
 
 def decompress_destripe_cbin(
@@ -632,11 +617,9 @@ def my_function(i_chunk, n_chunk):
         saturation_data = np.load(file_saturation)
         assert rms_data.shape[0] == time_data.shape[0] * ncv
         rms_data = rms_data.reshape(time_data.shape[0], ncv)
-        output_qc_path = output_qc_path or output_file.parent
+        output_qc_path = output_file.parent if output_qc_path is None else output_qc_path
        np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.rms.npy"), rms_data)
-        np.save(
-            output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data
-        )
+        np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data)
         np.save(output_qc_path.joinpath("_iblqc_ephysSaturation.samples.npy"), saturation_data)
 
 
@@ -715,15 +698,18 @@ def nxcor(x, ref):
     raw = raw - np.mean(raw, axis=-1)[:, np.newaxis]  # removes DC offset
     xcor = channels_similarity(raw)
     fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
-    if psd_hf_threshold is None:
-        # the LFP band data is obviously much stronger so auto-adjust the default threshold
-        psd_hf_threshold = 1.4 if fs < 5000 else 0.02
-    sos_hp = scipy.signal.butter(
-        **{"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}, output="sos"
-    )
+    # auto-detection of the band with which we are working
+    band = 'ap' if fs > 2600 else 'lf'
+    # the LFP band data is obviously much stronger so auto-adjust the default threshold
+    if band == 'ap':
+        psd_hf_threshold = 0.02 if psd_hf_threshold is None else psd_hf_threshold
+        filter_kwargs = {"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}
+    elif band == 'lf':
+        psd_hf_threshold = 1.4 if psd_hf_threshold is None else psd_hf_threshold
+        filter_kwargs = {"N": 3, "Wn": 1 / fs * 2, "btype": "highpass"}
+    sos_hp = scipy.signal.butter(**filter_kwargs, output="sos")
     hf = scipy.signal.sosfiltfilt(sos_hp, raw)
     xcorf = channels_similarity(hf)
-
     xfeats = {
         "ind": np.arange(nc),
         "rms_raw": utils.rms(raw),  # very similar to the rms avfter butterworth filter
@@ -754,7 +740,8 @@ def nxcor(x, ref):
     # from ibllib.plots.figures import ephys_bad_channels
     # ephys_bad_channels(x, 30000, ichannels, xfeats)
     if display:
-        ibldsp.plots.show_channels_labels(raw, fs, ichannels, xfeats)
+        ibldsp.plots.show_channels_labels(
+            raw, fs, ichannels, xfeats, similarity_threshold=similarity_threshold, psd_hf_threshold=psd_hf_threshold)
     return ichannels, xfeats
 
 
diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py
index 9075cb0..5041480 100644
--- a/src/ibldsp/waveform_extraction.py
+++ b/src/ibldsp/waveform_extraction.py
@@ -1,9 +1,9 @@
 import logging
+from pathlib import Path
 
 import scipy
 import pandas as pd
 import numpy as np
-from pathlib import Path
 from numpy.lib.format import open_memmap
 from joblib import Parallel, delayed, cpu_count
 
@@ -11,10 +11,25 @@
 from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car, kfilt
 from ibldsp.fourier import fshift
 from ibldsp.utils import make_channel_index
+from iblutil.numerical import ismember
 
 logger = logging.getLogger(__name__)
 
 
+def aggregate_by_clusters(df_wavs):
+    """
+    Groups the waveform dataframe by cluster
+    :param df_wavs:
+    :return:
+    """
+    df_clusters = df_wavs.loc[df_wavs['sample'] >= 0, :].groupby('cluster').aggregate(
+        count=pd.NamedAgg(column="cluster", aggfunc="count"),
+        first_index=pd.NamedAgg(column="waveform_index", aggfunc="min"),
+        last_index=pd.NamedAgg(column="waveform_index", aggfunc="max"),
+    )
+    return df_clusters
+
+
 def extract_wfs_array(
     arr,
     df,
@@ -41,15 +56,15 @@
     """
     # This is to do fast index assignment to assign missing channels (out of the probe) to NaN
     if add_nan_trace:
-        newcol = np.empty((arr.shape[0], 1))
+        newcol = np.empty((1, arr.shape[1]))
         newcol[:] = np.nan
-        arr = np.hstack([arr, newcol])
+        arr = np.vstack([arr, newcol])
 
     # check that the spike window is included in the recording:
     last_idx = df["sample"].iloc[-1]
     assert (
-        last_idx + (spike_length_samples - trough_offset) < arr.shape[0]
-    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[0]} samples)."
+        last_idx + (spike_length_samples - trough_offset) < arr.shape[1]
+    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[1]} samples)."
 
     nwf = len(df)
 
@@ -62,7 +77,7 @@ def extract_wfs_array(
     )
     nchan = cind.shape[1]
 
-    wfs = np.zeros((nwf, spike_length_samples, nchan), arr.dtype)
+    wfs = np.zeros((nwf, nchan, spike_length_samples), arr.dtype)
     fun = range
     if verbose:
         try:
@@ -72,9 +87,9 @@ def extract_wfs_array(
         except ImportError:
             pass
     for i in fun(nwf):
-        wfs[i, :, :] = arr[sind[i], :][:, cind[i]]
+        wfs[i, :, :] = arr[:, sind[i]][cind[i], :]
 
-    return wfs.swapaxes(1, 2), cind, trough_offset
+    return wfs, cind, trough_offset
 
 
 def _get_channel_labels(sr, num_snippets=20, verbose=True):
@@ -155,20 +170,25 @@
     wf_flat = pd.DataFrame(
         {
             "index": np.arange(wf_idx.shape[0]),
-            "sample": spike_samples[wf_idx].astype(int),
+            "sample": spike_samples[wf_idx].astype(np.int64),
             "cluster": spike_clusters[wf_idx].astype(int),
             "peak_channel": spike_channels[wf_idx].astype(int),
+            "waveform_index": np.zeros(wf_idx.shape[0], int),
         }
     )
+    # we pre-compute the final absolute indices of each waveform
+    unique_clusters, cluster_index, cluster_counts = np.unique(
+        wf_flat["cluster"], return_inverse=True, return_counts=True)
+    index_order_clusters = np.argsort(cluster_index, kind='stable')
+    wf_flat.loc[index_order_clusters, 'waveform_index'] = np.arange(wf_flat.shape[0])
 
     # 3d "flat" version
     return wf_flat, unit_ids
 
 
 def write_wfs_chunk(
     i_chunk,
     cbin,
-    wfs_fn,
-    mmap_shape,
+    wfs_mmap,
     geom_dict,
     channel_labels,
     channel_neighbors,
@@ -190,8 +210,6 @@
     my_sr = spikeglx.Reader(cbin, **reader_kwargs)
     s0, s1 = sr_sl
 
-    wfs_mmap = open_memmap(wfs_fn, shape=mmap_shape, mode="r+", dtype=np.float32)
-
     if i_chunk == 0:
         offset = 0
     else:
@@ -204,15 +222,15 @@
 
     snip = my_sr[
         s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
-    ]
+    ].T
 
     if "butterworth" in preprocess_steps:
         butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"}
         sos = scipy.signal.butter(**butter_kwargs, output="sos")
-        snip = scipy.signal.sosfiltfilt(sos, snip.T).T
+        snip = scipy.signal.sosfiltfilt(sos, snip)
 
     if "phase_shift" in preprocess_steps:
-        snip = fshift(snip, geom_dict["sample_shift"], axis=0)
+        snip = fshift(snip, geom_dict["sample_shift"], axis=-1)
 
     if "bad_channel_interpolation" in preprocess_steps:
         snip = interpolate_bad_channels(
@@ -225,7 +243,7 @@
     k_kwargs = {
         "ntr_pad": 60,
         "ntr_tap": 0,
-        "lagc": int(my_sr.fs / 10),
+        "lagc": 0,  # no agc for the median estimator of common reference channel
         "butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"},
     }
     if "car" in preprocess_steps:
@@ -235,12 +253,8 @@
     if "kfilt" in preprocess_steps:
         kfilt_func = lambda dat: kfilt(dat, **k_kwargs)  # noqa: E731
         snip = kfilt_func(snip)
-
-    wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
-        snip, df, channel_neighbors, add_nan_trace=True
-    )[0]
-
-    wfs_mmap.flush()
+    iw = wf_flat['waveform_index'].values
+    wfs_mmap[iw, :, :] = extract_wfs_array(snip, df, channel_neighbors, add_nan_trace=True)[0]
 
 
 def extract_wfs_cbin(
@@ -255,11 +269,12 @@
     trough_offset=42,
     spike_length_samples=128,
     chunksize_samples=int(3000),
-    reader_kwargs={},
+    reader_kwargs=None,
     n_jobs=None,
     wfs_dtype=np.float32,
-    preprocess_steps=[],
-    seed=None
+    preprocess_steps=None,
+    seed=None,
+    scratch_dir=None,
 ):
     """
     Given a bin file and locations of spikes, extract waveforms for each unit, compute
@@ -269,7 +284,7 @@
     reference procedure is applied in the spatial dimension.
 
     The following files will be generated:
-    - waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
+    - waveforms.traces.npy: `(total_waveforms, nc, spike_length_samples)`
         This file contains the lightly processed waveforms indexed by cluster in the
         first dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.
 
@@ -307,8 +322,11 @@
     :param wfs_dtype: Data type of raw waveforms saved (default np.float32)
     :param preprocess: Preprocessing options to apply, list which must be a subset of
         ["phase_shift", "bad_channel_interpolation", "butterworth", "car", "kfilt"]
+        By default a Butterworth 300 Hz high-pass and the rephasing of the channels are performed
     """
     n_jobs = n_jobs or int(cpu_count() / 2)
+    preprocess_steps = ['butterworth', 'phase_shift'] if preprocess_steps is None else preprocess_steps
+    reader_kwargs = {} if reader_kwargs is None else reader_kwargs
 
     assert set(preprocess_steps).issubset(
         {
@@ -327,6 +345,13 @@
     if h is None:
         h = sr.geometry
 
+    if sr.is_mtscomp:
+        bin_file = sr.decompress_to_scratch(scratch_dir=scratch_dir)
+        sr = spikeglx.Reader(bin_file, **reader_kwargs)
+        file_to_unlink = bin_file
+    else:
+        file_to_unlink = None
+
     s0_arr = np.arange(0, sr.ns, chunksize_samples)
     s1_arr = s0_arr + chunksize_samples
     s1_arr[-1] = sr.ns
@@ -353,7 +378,7 @@
     elif channel_labels is None:
         channel_labels = np.zeros(sr.nc - sr.nsync)
 
-    nwf = len(wf_flat)
+    nwf = wf_flat.shape[0]
     nu = unit_ids.shape[0]
     logger.info(f"Extracting {nwf} waveforms from {nu} units")
 
@@ -365,9 +390,9 @@
     # this intermediate memmap is written to in parallel
     # the waveforms are ordered only by their chronological position
     # in the recording, as we are reading them in time chunks
-    int_fn = output_dir.joinpath("_wf_extract_intermediate.npy")
+    traces_fn = output_dir.joinpath("waveforms.traces.npy")
     wfs = open_memmap(
-        int_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
+        traces_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
     )
 
     slices = [
@@ -379,8 +404,7 @@
         delayed(write_wfs_chunk)(
             i,
             bin_file,
-            int_fn,
-            wfs.shape,
+            wfs,
             h,
             channel_labels,
             channel_neighbors,
@@ -396,86 +420,46 @@
     )
 
     # output files
-    traces_fn = output_dir.joinpath("waveforms.traces.npy")
     templates_fn = output_dir.joinpath("waveforms.templates.npy")
     table_fn = output_dir.joinpath("waveforms.table.pqt")
     channels_fn = output_dir.joinpath("waveforms.channels.npz")
 
-    ## rearrange and save traces by unit
+    ## rearrange dataframe: sort waveforms by cluster and aggregate by cluster
+    wf_flat.sort_values(by=["cluster", "sample"], inplace=True)
+    df_clusters = aggregate_by_clusters(wf_flat)
+
+    # we want to store the index of the waveform within each cluster to facilitate loading later
+    wf_flat['index_within_clusters'] = np.ones(wf_flat.shape[0])
+    inewc = np.diff(wf_flat['cluster'].values, prepend=wf_flat['cluster'].values[0]) != 0
+    wf_flat.loc[inewc, 'index_within_clusters'] = - df_clusters['count'].values[:-1] + 1
+    wf_flat['index_within_clusters'] = np.cumsum(wf_flat['index_within_clusters'].values).astype(int) - 1
+    # store medians across waveforms
     wfs_templates = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
-    # create waveform output file (~2-3 GB)
-    traces_by_unit = open_memmap(
-        traces_fn,
-        mode="w+",
-        shape=(nu, max_wf, nc, spike_length_samples),
-        dtype=wfs_dtype,
-    )
     logger.info("Writing to output files")
-
-    for i, u in enumerate(unit_ids):
-        idx = np.where(wf_flat["cluster"] == u)[0]
-        nwf_u = idx.shape[0]
-        # reopening these memmaps on each iteration
-        # forces Python to clean up each large array it loads
-        # and prevent a memory leak
-        wfs = open_memmap(
-            int_fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
-        )
-        traces_by_unit = open_memmap(
-            traces_fn,
-            mode="r+",
-            shape=(nu, max_wf, nc, spike_length_samples),
-            dtype=wfs_dtype,
-        )
-        # write up to 256 waveforms and leave the rest of dimensions 1-3 as NaNs
-        traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(wfs_dtype)
-        traces_by_unit.flush()
-        # populate this array in memory as it's 256x smaller
-        wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0)
-
-    # cleanup intermediate file
-    int_fn.unlink()
-
+    wfs = open_memmap(traces_fn)
+    for i, rec in enumerate(df_clusters.itertuples()):
+        wfs_templates[i] = np.nanmedian(wfs[rec.first_index:rec.last_index + 1], axis=0)
     # save templates
     np.save(templates_fn, wfs_templates)
+
     # save the waveform table
-    # add in dummy rows and order by unit, and then sample
-    unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count")
-    unit_counts["missing"] = max_wf - unit_counts["count"]
-    missing_wf = unit_counts[unit_counts["missing"] > 0]
-    total_missing = sum(missing_wf.missing)
-    extra_rows = pd.DataFrame(
-        {
-            "sample": [np.nan] * total_missing,
-            "peak_channel": [np.nan] * total_missing,
-            "index": [np.nan] * total_missing,
-            "cluster": sum(
-                [[row["cluster"]] * row["missing"] for _, row in missing_wf.iterrows()],
-                [],
-            ),
-        }
-    )
-    save_df = pd.concat([wf_flat, extra_rows])
-    # now the waveforms are arranged by cluster, and then in time
-    # these match dimensions 0 and 1 of waveforms.traces.npy
-    save_df.sort_values(["cluster", "sample"], inplace=True)
-    save_df.to_parquet(table_fn)
+    wf_flat.to_parquet(table_fn)
 
     # save channel map for each waveform
     # these values are now reordered so that they match the pqt
     # and the traces file
-    peak_channel = np.nan_to_num(save_df["peak_channel"].to_numpy(), nan=-1).astype(
-        np.int16
-    )
-    dummy_idx = np.where(peak_channel >= 0)[0]
-    # leave "missing" waveforms as -1 since we can't have NaN with int dtype
-    chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
-    chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
+    peak_channel = np.nan_to_num(wf_flat["peak_channel"].to_numpy(), nan=-1).astype(np.int16)
+    chan_map = channel_neighbors[peak_channel.astype(int)]
     np.savez(channels_fn, channels=chan_map)
 
+    # clean up the cached bin file
+    if file_to_unlink is not None:
+        file_to_unlink.with_suffix(".meta").unlink()
+        file_to_unlink.unlink()
+
 
 class WaveformsLoader:
     """
     Interface to the output of `extract_wfs_cbin`. Requires the following four files to
@@ -502,26 +486,17 @@
     WaveformsLoader.load_waveforms() and random_waveforms() allow selection of a subset
     of waveforms.
-
     """
+    data_version = None
 
     def __init__(
         self,
         data_dir,
-        max_wf=256,
         trough_offset=42,
-        spike_length_samples=128,
-        num_channels=40,
-        wfs_dtype=np.float32
+        **kwargs,
     ):
         self.data_dir = Path(data_dir)
-        self.max_wf = max_wf
         self.trough_offset = trough_offset
-        self.spike_length_samples = spike_length_samples
-        self.num_channels = num_channels
-        self.wfs_dtype = wfs_dtype
-
         self.traces_fp = self.data_dir.joinpath("waveforms.traces.npy")
         self.templates_fp = self.data_dir.joinpath("waveforms.templates.npy")
         self.table_fp = self.data_dir.joinpath("waveforms.table.pqt")
@@ -532,89 +507,92 @@ def __init__(
         assert self.table_fp.exists(), "waveforms.table.pqt file missing!"
         assert self.channels_fp.exists(), "waveforms.channels.npz file missing!"
 
-        # ingest parquet table
-        self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"])
-        self.table["sample"] = self.table["sample"].astype("Int64")
-        self.table["peak_channel"] = self.table["peak_channel"].astype("Int64")
-        self.num_labels = self.table["cluster"].nunique()
-        self.labels = np.array(self.table["cluster"].unique())
-        self.total_wfs = sum(~self.table["peak_channel"].isna())
-        self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels)
-        self.table["linear_index"] = np.arange(len(self.table))
-
-        traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples)
-        templates_shape = (self.num_labels, num_channels, spike_length_samples)
+        self.traces = np.lib.format.open_memmap(self.traces_fp)
+        self.df_wav = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"])
+        if len(self.traces.shape) == 4:
+            self.data_version = 1
+            self.df_wav["sample"] = self.df_wav["sample"].astype('Int64')
+            self.df_wav["peak_channel"] = self.df_wav["peak_channel"].astype('Int64')
+            self.df_wav['waveform_index'] = np.arange(self.df_wav.shape[0], dtype=np.int64)
+            self.df_wav['index_within_cluster'] = np.tile(np.arange(self.traces.shape[1]), self.traces.shape[0])
+            self.total_wfs = sum(~self.df_wav["peak_channel"].isna())
+        else:
+            self.data_version = 2
 
-        self.traces = np.lib.format.open_memmap(self.traces_fp, dtype=wfs_dtype, shape=traces_shape)
-        self.templates = np.lib.format.open_memmap(self.templates_fp, dtype=np.float32, shape=templates_shape)
-        self.channels = np.load(self.channels_fp, allow_pickle="True")["channels"]
+        self.df_clusters = aggregate_by_clusters(self.df_wav)
+        self.templates = np.lib.format.open_memmap(self.templates_fp, dtype=np.float32)
+        self.channels = np.load(self.channels_fp)["channels"]
 
     def __repr__(self):
-        s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n"
-        s2 = f"Data path: {self.data_dir}\n"
-        s3 = f"{self.spike_length_samples} samples, {self.num_channels} channels, {self.max_wf} max waveforms per label\n"
+        return f"""
+        WaveformsLoader data version {self.data_version}
+        {self.nw:_} total waveforms {self.ns} samples, {self.nc} channels
+        {self.nu:_} units, {self.max_wf:_} max waveforms per label
+        dtype: {self.wfs_dtype}
+        data path: {self.data_dir}
+        """
 
-        return s1 + s2 + s3
+    @property
+    def max_wf(self):
+        return self.df_clusters['count'].max()
 
     @property
-    def wf_counts(self):
-        """
-        pandas Series containing number of (non-NaN) waveforms for each label.
-        """
-        return self.table.groupby("cluster").count()["sample"].rename("num_wfs")
+    def wfs_dtype(self):
+        return self.traces.dtype
+
+    @property
+    def nu(self):
+        return self.df_clusters.shape[0]
+
+    @property
+    def ns(self):
+        return self.traces.shape[-1]
+
+    @property
+    def nc(self):
+        return self.traces.shape[-2]
+
+    @property
+    def nw(self):
+        return self.df_wav.shape[0]
 
     def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
         """
         Returns a specified subset of waveforms from the dataset.
 
         :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
-        :param indices: (list, NumPy array) Waveform indices to grab for each waveform.
-            Can be 1D in which case the same indices are returned for each waveform, or
-            2D with first dimension == len(labels) to return a specific set of indices for
-            each waveform.
+        :param indices: (list, NumPy array) 1D array of waveform indices to grab for each label.
         :param return_info: If True, returns waveforms, table, channels, where table is a DF
             containing information about the waveforms returned, and channels is the channel map
             for each waveform.
         :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise
             returns array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)
-
         """
-        if labels is None:
-            labels = self.labels
-        if indices is None:
-            indices = np.arange(self.max_wf)
-
-        labels = np.array(labels)
-        label_idx = np.array([np.where(self.labels == label)[0][0] for label in labels])
-        indices = np.array(indices)
-
-        num_labels = labels.shape[0]
-
-        if indices.ndim == 1:
-            indices = np.tile(indices, (num_labels, 1))
-
-        wfs = self.traces[label_idx[:, None], indices].astype(np.float32)
-
-        if flatten:
-            wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples)
-
-        info = self.table[self.table["cluster"].isin(labels)].copy()
-        dfs = []
-        for i, l in enumerate(labels):
-            _idx = indices[i]
-            dfs.append(info[(info["wf_number"].isin(_idx)) & (info["cluster"] == l)])
-        info = pd.concat(dfs).reset_index(drop=True)
-
-        channels = self.channels[info["linear_index"].to_numpy()].astype(int)
-
+        labels = np.array(self.df_clusters.index if labels is None else labels)
+        iw, _ = ismember(self.df_wav['cluster'], labels)
+        if self.data_version == 1:
+            indices = np.array(np.arange(self.max_wf) if indices is None else indices)
+            indices = np.tile(indices, (labels.size, 1)) if indices.ndim < 2 else indices
+            assert indices.shape[0] == labels.size, \
+                "If indices is a 2D-array, the first dimension must match the number of clusters."
+            _, iu, _ = np.intersect1d(self.df_clusters.index, labels, return_indices=True)
+            assert iu.size == labels.size, "Not all labels found in dataset."
+            wfs = self.traces[iu[:, np.newaxis], indices].astype(np.float32)
+            if flatten:
+                wfs = wfs.reshape(-1, self.nc, self.ns)
+        elif self.data_version == 2:
+            if indices is not None:
+                iw = np.where(iw)[0]
+                iw = iw[self.df_wav.loc[iw, 'index_within_clusters'].isin(np.atleast_1d(np.array(indices)))]
+            wfs = self.traces[iw].astype(np.float32)
+        info = self.df_wav.loc[iw, :].copy()
+        channels = self.channels[iw].astype(int)
         n_nan = sum(info["sample"].isna())
         if n_nan > 0:
-            logger.warning(f"{n_nan} NaN waveforms included in result.")
+            logger.info(f"{n_nan} NaN waveforms included in result.")
         if return_info:
             return wfs, info, channels
-
-        logger.info("Use return_info=True and check the table for details.")
-
-        return wfs
+        else:
+            return wfs
 
     def random_waveforms(
         self,
diff --git a/src/ibldsp/waveforms.py b/src/ibldsp/waveforms.py
index b05882d..edae543 100644
--- a/src/ibldsp/waveforms.py
+++ b/src/ibldsp/waveforms.py
@@ -6,7 +6,9 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
+import matplotlib as mpl
 import scipy
+
 from ibldsp.utils import parabolic_max
 from ibldsp.fourier import fshift
 
@@ -231,16 +233,24 @@ def find_tip_trough(arr_peak, arr_peak_real, df):
     return df, arr_peak
 
 
-def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs):
+def plot_wiggle(wav, fs=1, ax=None, scale=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None):
     """
     Displays a multi-trace waveform in a wiggle traces with negative amplitudes filled
     :param wav: (nchannels, nsamples)
-    :param axkwargs: keyword arguments to feed to ax.set()
-    :return:
+    :param fs: sampling rate
+    :param ax: axis to plot on
+    :param scale: waveform amplitude displayed as one inter-trace distance: with scale=20e-6, one inter-trace is 20 uV
+    :param clip: maximum amplitude, in inter-trace units, at which the traces are clipped
+    :param fill_sign: -1 to fill negative amplitudes (default for spikes), 1 to fill positive amplitudes
+    :param plot_kwargs: kwargs for the line plot
+    :param fill_kwargs: kwargs for the fill
+    :return: axis
     """
     if ax is None:
         fig, ax = plt.subplots()
+    plot_kwargs = {'color': 'k', 'linewidth': 0.5} | (plot_kwargs or {})
+    fill_kwargs = {'color': 'k', 'aa': True} | (fill_kwargs or {})
     nc, ns = wav.shape
     vals = np.c_[wav, wav[:, :1] * np.nan].ravel()  # flat view of the 2d array.
     vect = np.arange(vals.size).astype(
@@ -255,22 +265,51 @@
         m = (y2 - y1) / (x2 - x1)
         c = y1 - m * x1
     # tack these values onto the end of the existing data
-    x = np.hstack([vals, np.zeros_like(c)]) * scalar
+    x = np.hstack([vals, np.zeros_like(c)]) / scale
     x = np.maximum(np.minimum(x, clip), -clip)
     y = np.hstack([vect, c])
     # resort the data
     order = np.argsort(y)
     # shift from amplitudes to plotting coordinates
     x_shift, y = y[order].__divmod__(ns + 1)
-    ax.plot(y / fs, x[order] + x_shift + 1, 'k', linewidth=.5)
-    x[x > 0] = np.nan
+    ax.plot(y / fs, x[order] + x_shift + 1, **plot_kwargs)
+    if fill_sign < 0:
+        x[x > 0] = np.nan
+    else:
+        x[x < 0] = np.nan
     x = x[order] + x_shift + 1
-    ax.fill(y / fs, x, 'k', aa=True)
-    ax.set(xlim=[0, ns / fs], ylim=[0, nc], xlabel='sample', ylabel='trace')
+    ax.fill(y / fs, x, **fill_kwargs)
+    ax.set(xlim=[0, ns / fs], ylim=[0, nc])
     plt.tight_layout()
     return ax
 
 
+def double_wiggle(wav, fs=1, ax=None, colors=None, **kwargs):
+    """
+    Double trouble: this wiggle colours both the negative and the positive values
+    :param wav: (nchannels, nsamples)
+    :param fs: sampling rate
+    :param ax: axis to plot on
+    :param colors: 2-tuple of fill colours for the negative and positive lobes
+    :param kwargs: passed through to plot_wiggle (e.g. scale, clip)
+    :return: axis
+    """
+    if colors is None:
+        cmap = 'PuOr'
+        _cmap = mpl.colormaps.get_cmap(cmap)
+        colors = _cmap(np.linspace(0, 1, 256))
+        colors = [colors[50], colors[-50]]
+    if ax is None:
+        fig, ax = plt.subplots()
+    plot_wiggle(wav, fs=fs / 1e3, ax=ax, plot_kwargs={'linewidth': 0}, fill_kwargs={'color': colors[0]}, **kwargs)
+    plot_wiggle(wav, fs=fs / 1e3, ax=ax, fill_sign=1, plot_kwargs={'linewidth': 0.5}, fill_kwargs={'color': colors[1]}, **kwargs)
+    return ax
+
+
 def plot_peaktiptrough(df, arr, ax, nth_wav=0, plot_grey=True, fs=30000):
     # Time axix
     nech, ntr = arr[nth_wav].shape
diff --git a/src/neurodsp/__init__.py b/src/neurodsp/__init__.py
deleted file mode 100644
index 99f527a..0000000
--- a/src/neurodsp/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import ibldsp, sys
-from warnings import warn
-
-sys.modules["neurodsp"] = ibldsp
-warn(
-    "neurodsp has been renamed to ibldsp and the old name will be deprecated on 01-Oct-2024.",
-    FutureWarning,
-)
diff --git a/src/spikeglx.py b/src/spikeglx.py
index f6a3c55..b7cfa3c 100644
--- a/src/spikeglx.py
+++ b/src/spikeglx.py
@@ -2,6 +2,8 @@
 import logging
 from pathlib import Path
 import re
+import shutil
+import time
 
 import numpy as np
 
@@ -14,7 +16,7 @@
 SAMPLE_SIZE = 2  # int16
 DEFAULT_BATCH_SIZE = 1e6
 
-_logger = logging.getLogger("ibllib")
+_logger = logging.getLogger(__name__)
 
 
 def _get_companion_file(sglx_file, pattern='.meta'):
@@ -374,6 +376,27 @@ def compress_file(self, keep_original=True, **kwargs):
         self.file_bin = file_out
         return file_out
 
+    def decompress_to_scratch(self, scratch_dir=None):
+        """
+        Decompresses the file to a temporary directory and copies over the metadata file
+        """
+        if scratch_dir is None:
+            bin_file = Path(self.file_bin).with_suffix('.bin')
+        else:
+            scratch_dir.mkdir(exist_ok=True, parents=True)
+            bin_file = Path(scratch_dir).joinpath(self.file_bin.name).with_suffix('.bin')
+        shutil.copy(self.file_meta_data, bin_file.with_suffix('.meta'))
+        if not bin_file.exists():
+            t0 = time.time()
+            _logger.info('File is compressed, decompressing to a temporary file...')
+            self.decompress_file(
+                keep_original=True, out=bin_file.with_suffix('.bin_temp'), check_after_decompress=False, overwrite=True
+            )
+            shutil.move(bin_file.with_suffix('.bin_temp'), bin_file)
+            _logger.info(f"Decompression complete: {time.time() - t0:.2f}s")
+        return bin_file
+
     def decompress_file(self, keep_original=True, **kwargs):
         """
         Decompresses a mtscomp file
diff --git a/src/tests/unit/cpu/test_ibldsp.py b/src/tests/unit/cpu/test_ibldsp.py
index a8214c9..16e4391 100644
--- a/src/tests/unit/cpu/test_ibldsp.py
+++ b/src/tests/unit/cpu/test_ibldsp.py
@@ -356,6 +356,22 @@ def test_firstlast_slices(self):
             my_rms_[wg.iw] = utils.rms(my_sig[sl])
         self.assertTrue(np.all(my_rms_ == my_rms))
 
+    def test_firstlast_splicing(self):
+        sig_in = np.random.randn(600)
+        sig_out = np.zeros_like(sig_in)
+        wg = utils.WindowGenerator(ns=600, nswin=100, overlap=20)
+        for first, last, amp in wg.firstlast_splicing:
+            sig_out[first:last] = sig_out[first:last] + amp * sig_in[first:last]
+        np.testing.assert_allclose(sig_out, sig_in)
+
+    def test_firstlast_valid(self):
+        sig_in = np.random.randn(600)
+        sig_out = np.zeros_like(sig_in)
+        wg = utils.WindowGenerator(ns=600, nswin=100, overlap=20)
+        for first, last, first_valid, last_valid in wg.firstlast_valid:
+            sig_out[first_valid:last_valid] = sig_in[first_valid:last_valid]
+        np.testing.assert_array_equal(sig_out, sig_in)
+
     def test_tscale(self):
         wg = utils.WindowGenerator(ns=500, nswin=100, overlap=50)
         ts = wg.tscale(fs=1000)
@@ -626,25 +642,3 @@ def test_compute_features(self):
         self.assertEqual(multi_index, list(df.index))
         self.assertEqual(["snippet_id", "channel_id"], list(df.index.names))
         self.assertEqual(num_snippets * (self.nc - 1), len(df))
-
-
-class TestNameDeprecationDate(unittest.TestCase):
-    def test_neurodsp_import(self):
-        # Check that the old import still works and gives the same package.
-        # (ibldsp.voltage is imported at the top of this file.)
-        with self.assertWarnsRegex(FutureWarning, "01-Oct-2024"):
-            import neurodsp
-        self.assertEqual(neurodsp.voltage, voltage)
-
-    def test_deprecation_countdown(self):
-        # Fail on 01-Sep-2024, when `neurodsp` will be retired.
-        # When this test fails, remove the entire dummy
-        # `neurodsp` package at the top level of the ibl-neuropixel
-        # repository
-        import datetime
-
-        if datetime.datetime.now() > datetime.datetime(2024, 10, 1):
-            raise NotImplementedError(
-                "neurodsp will not longer be supported. "
-                "Change all references to ibldsp."
-            )
diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py
index 4638788..d77dba6 100644
--- a/src/tests/unit/cpu/test_waveforms.py
+++ b/src/tests/unit/cpu/test_waveforms.py
@@ -1,9 +1,12 @@
 from pathlib import Path
+import shutil
+import tempfile
+import unittest
 
 import numpy as np
 import pandas as pd
-import tempfile
-import shutil
+import matplotlib.pyplot as plt
+import scipy
 
 import ibldsp.utils as utils
 import ibldsp.waveforms as waveforms
@@ -11,9 +14,7 @@
 from neurowaveforms.model import generate_waveform
 from neuropixel import trace_header
 from ibldsp.fourier import fshift
-import scipy
-import unittest
 
 TEST_PATH = Path(__file__).parent.joinpath("fixtures")
 
@@ -188,6 +189,7 @@ class TestWaveformExtractorArray(unittest.TestCase):
     channel_neighbors = utils.make_channel_index(geom, radius=200.0)  # radius = 200um, 38 chans
     num_channels = 40
+    arr = arr.T
 
     def test_extract_waveforms_array(self):
         wfs, _, _ = waveform_extraction.extract_wfs_array(
@@ -307,7 +309,7 @@ class TestWaveformExtractorBin(unittest.TestCase):
     max_wf = 25
 
     # 2 clusters
-    spike_samples = np.repeat(np.arange(0, ns, 1600), 2)  # 50 spikes
+    spike_samples = np.repeat(np.arange(0, ns, 1600), 2)  # 50 spikes, 2 of which fall on sample 0
     spike_channels = np.tile(np.array([100, 368]), 25)
     spike_clusters = np.tile(np.array([1, 2]), 25)
 
@@ -327,19 +329,20 @@ def tearDown(self):
         shutil.rmtree(self.tmpdir)
 
     def _ground_truth_values(self):
-
+        # here we have to hard-code 48 and 24 because the first 2 spikes are rejected, as they fall on sample 0
         nc_extract = self.chan_map.shape[1]
         gt_templates = np.ones((self.n_clusters, nc_extract, self.ns_extract), np.float32) * np.nan
-        gt_waveforms = np.ones((self.n_clusters, self.max_wf, nc_extract, self.ns_extract), np.float32) * np.nan
+        gt_waveforms = np.ones((48, nc_extract, self.ns_extract), np.float32) * np.nan
 
         c0_chans = self.chan_map[100].astype(np.float32)
         gt_templates[0, :, :] = np.tile(c0_chans, (self.ns_extract, 1)).T
-        gt_waveforms[0, :self.max_wf - 1, :, :] = np.tile(c0_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
+
+        gt_waveforms[:24, :, :] = gt_templates[0]
 
         c1_chans = self.chan_map[368].astype(np.float32)
         c1_chans[c1_chans == 384] = np.nan
         gt_templates[1, :, :] = np.tile(c1_chans, (self.ns_extract, 1)).T
-        gt_waveforms[1, :self.max_wf - 1, :, :] = np.tile(c1_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
+        gt_waveforms[24:, :, :] = gt_templates[1]
 
         return gt_templates, gt_waveforms
 
@@ -357,16 +360,20 @@ def test_extract_waveforms_bin(self):
         )
         templates = np.load(self.tmpdir.joinpath("waveforms.templates.npy"))
         waveforms = np.load(self.tmpdir.joinpath("waveforms.traces.npy"))
+        table = pd.read_parquet(self.tmpdir.joinpath("waveforms.table.pqt"))
 
-        for u in [0, 1]:
-            assert np.allclose(np.nan_to_num(templates[u]), np.nanmedian(waveforms[u], axis=0))
+        cluster_ids = table.cluster.unique()
+
+        for i, u in enumerate(cluster_ids):
+            inds = table[table.cluster == u].waveform_index.to_numpy()
+            assert np.allclose(templates[i], np.nanmedian(waveforms[inds], axis=0), equal_nan=True)
 
         gt_templates, gt_waveforms = self._ground_truth_values()
 
         assert np.allclose(np.nan_to_num(gt_templates), np.nan_to_num(templates))
         assert np.allclose(np.nan_to_num(gt_waveforms), np.nan_to_num(waveforms))
 
-        wfl = waveform_extraction.WaveformsLoader(self.tmpdir, max_wf=self.max_wf)
+        wfl = waveform_extraction.WaveformsLoader(self.tmpdir)
 
         wfs = wfl.load_waveforms(return_info=False)
         assert np.allclose(np.nan_to_num(waveforms), np.nan_to_num(wfs))
 
@@ -374,16 +381,20 @@ def test_extract_waveforms_bin(self):
         labels = np.array([1, 2])
         indices = np.arange(10)
 
+        # test the waveform loader
         wfs, info, channels = wfl.load_waveforms(labels=labels, indices=indices)
+
         # right waveforms
-        assert np.allclose(np.nan_to_num(waveforms[:, :10]), np.nan_to_num(wfs))
+        assert np.allclose(np.nan_to_num(waveforms[:10, :]), np.nan_to_num(wfs[info['cluster'] == 1, :, :]))
+        assert np.allclose(np.nan_to_num(waveforms[25:35, :]), np.nan_to_num(wfs[info['cluster'] == 2, :, :]))
         # right channels
         assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])
 
-        wfs, info, channels = wfl.load_waveforms(labels=labels, indices=np.array([[1, 2, 3], [5, 6, 7]]))
-        # right waveforms
-        assert np.allclose(np.nan_to_num(waveforms[0, [1, 2, 3]]), np.nan_to_num(wfs[0]))
-        assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1]))
-        # right channels
-        assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])
+
+def test_wiggle():
+    wav = generate_waveform()
+    wav = wav / np.max(np.abs(wav)) * 120 * 1e-6
+    fig, ax = plt.subplots(1, 2)
+    waveforms.plot_wiggle(wav, scale=40 * 1e-6, ax=ax[0])
+    waveforms.double_wiggle(wav, scale=40 * 1e-6, fs=30_000, ax=ax[1])
+    plt.close('all')
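
Usage sketch (illustrative, not part of the patch): the two WindowGenerator generators added in src/ibldsp/utils.py can be exercised exactly as in the new unit tests, either splicing overlapping windows back together with tapered amplitudes or keeping only the valid span of each window.

import numpy as np
from ibldsp.utils import WindowGenerator

sig_in = np.random.randn(600)
wg = WindowGenerator(ns=600, nswin=100, overlap=20)

# firstlast_splicing: the Hann-tapered amplitude vectors sum to one in the
# overlaps, so the windows splice back into the original signal exactly
sig_out = np.zeros_like(sig_in)
for first, last, amp in wg.firstlast_splicing:
    sig_out[first:last] += amp * sig_in[first:last]

# firstlast_valid: keep only the central, valid span of each window
# (requires an even overlap, as asserted by the property)
sig_valid = np.zeros_like(sig_in)
for first, last, first_valid, last_valid in wg.firstlast_valid:
    sig_valid[first_valid:last_valid] = sig_in[first_valid:last_valid]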
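A minimal sketch of the reworked voltage helpers in src/ibldsp/voltage.py, assuming a (nc, ns) float array in volts; the data here is made up, and the 384-channel shape is chosen to match the default Neuropixel geometry that destripe expects.

import numpy as np
from ibldsp.voltage import car, destripe_lfp

x = np.random.randn(384, 2500) * 1e-5  # made-up (nc, ns) data in volts
x_car = car(x, operator='median')      # subtracts the per-sample median across traces
# LFP destriping now defaults to a 0.5-300 Hz bandpass and no k-filter
x_clean = destripe_lfp(x, fs=2500)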
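And a sketch of reading back the flattened outputs with the refactored loader, mirroring the updated test; `data_dir` is a placeholder for a folder written by extract_wfs_cbin.

import numpy as np
from pathlib import Path
from ibldsp.waveform_extraction import WaveformsLoader

data_dir = Path("path/to/waveforms_output")  # placeholder: holds the four waveforms.* files
wfl = WaveformsLoader(data_dir)
print(wfl)  # reports data version (1: legacy 4D traces, 2: flattened), counts and dtype
# first 10 waveforms of clusters 1 and 2, with their table rows and channel maps
wfs, info, channels = wfl.load_waveforms(labels=np.array([1, 2]), indices=np.arange(10))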