"""Interactive annotation of inter-pulse intervals (IPI) of clicks in an audio recording.

The top plot shows a 20 s window of the band-passed recording, resampled to ``FSSR``.
Clicking it selects the largest sample within ``FSPK`` seconds of the mouse position as a
click, which is then displayed in the current view (waveform, spectrogram, autocorrelation
and cepstrum) where P1, P2 and the IPI can be picked by hand. Annotations are written to an
HDF5 table with one row per click, indexed by the click time in seconds.
"""
import argparse
import os
import sys
from fractions import Fraction
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.signal as sg
import soundfile as sf
from matplotlib.colors import to_rgba
from matplotlib.widgets import AxesWidget, Button, Cursor, RadioButtons
from pydub import AudioSegment
from pydub.playback import play

FSSR = 48_000  # Sampling rate (Hz) of the full-signal (overview) plot
FSPK = 0.1  # Max distance (s) from the mouse position at which a click peak is searched
IPIPK = 0.15  # Half-width (ms) of the window searched for the automatic IPI around a manual pick
SPSC = 80  # Colour range of the spectrogram below its maximum (20*log10 units)
VIEW_COLORS = 'rgbkcmyrgbkcmy'  # Colour of each view; the index is the view number
EMLN = {'p1_pos': np.nan, 'ipi_sig': np.nan, 'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan,
        'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan, 'ind_number': np.nan}  # Empty data line


def read(file_path, always_2d=True):
    """Read an audio file and return ``(samples, sampling_rate)``.

    soundfile is tried first; if it fails for any reason the file is decoded with pydub
    instead (see `load_anysound`), which always returns a 2-D array.
    """
    try:
        return sf.read(file_path, always_2d=always_2d)
    except Exception:  # deliberate best effort: fall back to pydub for other formats
        return load_anysound(file_path)


def load_anysound(file_path):
    """Decode ``file_path`` with pydub.

    Returns
    -------
    samples : ndarray
        Raw integer samples shaped ``(n_samples, n_channels)``.
    frame_rate : int
    """
    segment = AudioSegment.from_file(file_path)
    samples = np.array(segment.get_array_of_samples()).reshape(-1, segment.channels)
    return samples, segment.frame_rate


def load_file(in_path, channel, low, high):
    """Load one channel of ``in_path``, band-pass it and resample it for the overview plot.

    Parameters
    ----------
    in_path : str
        Audio file path.
    channel : int
        0-based channel index.
    low, high : float
        Band-pass edges in Hz (3rd order Butterworth, zero-phase). ``high`` is clamped just
        below the Nyquist frequency when it is not below it.

    Returns
    -------
    song : ndarray
        Filtered signal at the native rate, zero-padded at the end to at least 20 s.
    sr : int
        Native sampling rate.
    song_resample : ndarray
        ``song`` resampled to ``FSSR``.
    """
    print(f'Loading and processing {in_path}')
    song, sr = read(in_path, always_2d=True)
    song = song[:, channel]
    nyquist = sr / 2
    if high >= nyquist:
        # butter() rejects cut-offs at or above Nyquist; keep the band just below it.
        new_high = 0.99 * nyquist
        print(f'Upper frequency {high} Hz is not below Nyquist ({nyquist} Hz), using {new_high} Hz.')
        high = new_high
    sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
    song = sg.sosfiltfilt(sos, song)
    if len(song) < 20 * sr:
        # The overview always shows 20 s; pad at the end only so sample times are unchanged.
        song = np.pad(song, (0, 20 * sr - len(song)), mode='constant')
    frac = Fraction(FSSR, sr)
    song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
    print('Done processing')
    return song, sr, song_resample


def norm(x):
    """Scale ``x`` so that its maximum absolute value is 1 (guarded against division by 0)."""
    return x / (np.abs(x).max() + 1e-10)


def norm_std(x, alpha=1.5):
    """Scale ``x`` by ``alpha`` times its standard deviation (guarded against division by 0)."""
    return x / (alpha * np.std(x) + 1e-10)


def is_empty_row(row):
    """Return True if the data line ``row`` (a dict shaped like EMLN) holds no annotation.

    Every value must be null, except ``onaxis`` which may also hold its "unknown" sentinel
    ``-1`` (the default that `main` adds to EMLN).
    """
    return all(pd.isnull(value) or (key == 'onaxis' and value == -1)
               for key, value in row.items())


class MyRadioButtons(RadioButtons):
    """RadioButtons laid out as a legend (allows a horizontal layout and per-button colours).

    NOTE(review): this bypasses ``RadioButtons.__init__`` and relies on the attributes
    (``observers``, ``cnt``, ``circles``) used by older matplotlib releases; confirm the
    supported matplotlib version before upgrading.
    """

    def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                 orientation="vertical", **kwargs):
        """
        Add radio buttons to an `~.axes.Axes`.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to add the buttons to.
        labels : list of str
            The button labels.
        active : int
            The index of the initially selected button.
        activecolor : color or list of colors
            The color of the selected button. With a list, the color used is the entry
            indexed by the last character of the selected label.
        size : float
            Size of the radio buttons
        orientation : str
            The orientation of the buttons: 'vertical' (default), or 'horizontal'.

        Further parameters are passed on to `Legend`.
        """
        AxesWidget.__init__(self, ax)
        self._activecolor = activecolor
        axcolor = ax.get_facecolor()
        self.value_selected = None

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_navigate(False)

        circles = []
        for i, label in enumerate(labels):
            if i == active:
                self.value_selected = label
                facecolor = self.activecolor
            else:
                facecolor = axcolor
            p = ax.scatter([], [], s=size, marker="o", edgecolor='black', facecolor=facecolor)
            circles.append(p)
        if orientation == "horizontal":
            kwargs.update(ncol=len(labels), mode="expand")
        kwargs.setdefault("frameon", False)
        self.box = ax.legend(circles, labels, loc="center", **kwargs)
        self.labels = self.box.texts
        self.circles = self.box.legendHandles
        for c in self.circles:
            c.set_picker(5)
        self.cnt = 0
        self.observers = {}

        self.connect_event('pick_event', self._clicked)

    def _clicked(self, event):
        """Activate the button whose circle was picked with the left mouse button."""
        if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
            return
        if event.artist in self.circles:
            self.set_active(self.circles.index(event.artist))

    @property
    def activecolor(self):
        """Colour of the selected button.

        ``activecolor`` itself when it is a single colour (including a colour string);
        otherwise the entry indexed by the last character of the selected label.
        """
        try:
            to_rgba(self._activecolor)
        except (ValueError, TypeError):
            return self._activecolor[int(self.value_selected[-1])]
        return self._activecolor


class MyMultiCursor(AxesWidget):
    """Hover cursors drawn simultaneously on the four axes of an `AxesGroup`.

    While the first line is dashed, hovering shows where P1 would be placed. Otherwise
    the lines show the candidate P2 (and its multiples) relative to the P1 cursor ``p1``,
    plus the matching IPI on the correlation and cepstrum axes.

    Parameters
    ----------
    axes : AxesGroup
        The group of axes to draw on.
    p1 : Line2D
        Dashed P1 cursor of the signal axes; its x position anchors the IPI lines.
    num_cur : int
        Number of hover lines on the signal and on the spectrogram axes.
    vertOn : bool
        Whether the lines are shown.
    useblit : bool
        Use blitting when the canvas supports it.
    **lineprops
        Forwarded to ``axvline``.
    """

    def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False, **lineprops):
        # axes[0] is an AxesWithCursor; its ``figure`` property lets AxesWidget find the canvas.
        AxesWidget.__init__(self, axes[0])
        self.axes = axes
        self.p1 = p1
        self.num_cur = num_cur
        self.connect_event('motion_notify_event', self.onmove)
        self.connect_event('draw_event', self.clear)

        self.visible = True
        self.vertOn = vertOn
        self.useblit = useblit and self.canvas.supports_blit

        if self.useblit:
            lineprops['animated'] = True

        # One hover line per solid cursor of every axes (dashed P1 cursors are skipped):
        # num_cur on the signal, num_cur on the spectrogram, then one each on the
        # correlation and cepstrum axes.
        self.linev = []
        for ax in self.axes:
            for line in ax.cursors:
                if line.get_linestyle() == '--':
                    continue
                self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
        self.linev[0].set_linestyle('--')
        self.linev[self.num_cur].set_linestyle('--')
        self.n_axes = len(self.linev)  # number of hover lines (name kept for compatibility)

        self.background = len(self.axes) * [None]
        self.needclear = False

    def clear(self, event):
        """Internal event handler to clear the cursor.

        On every full redraw, store each axes' background (for blitting) and hide all lines.
        """
        if self.ignore(event):
            return
        for i, ax in enumerate(self.axes):
            if self.useblit:
                self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
        for line in self.linev:
            line.set_visible(False)

    def onmove(self, event):
        """Internal event handler to draw the cursor when the mouse moves."""
        if self.ignore(event):
            return
        if not self.canvas.widgetlock.available(self):
            return
        if event.inaxes not in self.axes:
            for line in self.linev:
                line.set_visible(False)
            if self.needclear:
                self.canvas.draw()
                self.needclear = False
            return
        self.needclear = True
        if not self.visible:
            return
        for ax_idx, ax in enumerate(self.axes):
            if ax.axes == event.inaxes:
                break
        show = self.visible and self.vertOn
        n = self.num_cur
        if ax_idx < 2:
            # Signal x is in ms, spectrogram x in s: work in ms.
            pos = event.xdata * 1e3 if ax_idx == 1 else event.xdata
            if self.linev[0].get_linestyle() == '--':
                # Placing P1: one line on the signal and one on the spectrogram.
                self.linev[0].set_xdata((pos, pos))
                self.linev[n].set_xdata((pos / 1e3, pos / 1e3))
                self.linev[0].set_visible(show)
                self.linev[n].set_visible(show)
                self._update()
                return
            # Placing P2: lines at pos + i*IPI, IPI being the distance to P1.
            ipi_pos = abs(pos - self.p1.get_xdata()[0])
            for i, line in enumerate(self.linev):
                if i < n:
                    line.set_xdata(2 * (pos + i * ipi_pos,))
                elif i < 2 * n:
                    line.set_xdata(2 * ((pos + (i - n) * ipi_pos) / 1e3,))
                else:
                    line.set_xdata((ipi_pos, ipi_pos))
                line.set_visible(show)
            self._update()
            return

        # Correlation / cepstrum axes: x is directly an IPI in ms.
        ipi_pos = event.xdata
        pos = self.p1.get_xdata()[0] + ipi_pos
        restore = self.linev[0].get_linestyle() == '--'
        if restore:
            # Draw the P2 lines solid for this frame, then restore the P1-placement style.
            self.linev[0].set_linestyle('-')
            self.linev[n].set_linestyle('-')
        p1_shown = self.p1.get_visible()
        for i, line in enumerate(self.linev):
            if i >= 2 * n:
                line.set_xdata((ipi_pos, ipi_pos))
                line.set_visible(show)
            elif p1_shown:
                if i < n:
                    line.set_xdata(2 * (pos + i * ipi_pos,))
                else:
                    line.set_xdata(2 * ((pos + (i - n) * ipi_pos) / 1e3,))
                line.set_visible(show)
        self._update()
        if restore:
            self.linev[0].set_linestyle('--')
            self.linev[n].set_linestyle('--')

    def _update(self):
        """Redraw the hover lines, by blitting each axes when possible."""
        if not self.useblit:
            self.canvas.draw_idle()
            return False
        start = 0
        for idx, (bk, ax) in enumerate(zip(self.background, self.axes)):
            # Signal and spectrogram own num_cur lines each, correlation and cepstrum one each.
            stop = start + (self.num_cur if idx < 2 else 1)
            if bk is not None:
                self.canvas.restore_region(bk)
                for line in self.linev[start:stop]:
                    ax.axes.draw_artist(line)
                self.canvas.blit(ax.axes.bbox)
            start = stop
        return False


class Callback(object):
    """Hold the annotation state and implement the widget, mouse and keyboard callbacks.

    Parameters
    ----------
    line : Line2D
        Line of the overview plot.
    song : ndarray
        Filtered signal at the native rate ``sr``.
    song_resample : ndarray
        ``song`` resampled to FSSR (drives the overview plot).
    sr : int
        Native sampling rate.
    full_ax : Axes
        Overview axes.
    num_view : int
        Number of click views.
    after_length : float
        Seconds displayed after P1 in each view.

    ``df`` maps a click time (s) to its data line (a copy of EMLN). ``offset`` holds one
    ``[time, amplitude]`` row per click, in insertion order. ``curr_ind[v]`` is the row of
    ``offset`` shown in view ``v`` (None when the view is empty).
    """

    def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length):
        self.p = 0  # first sample (at FSSR) of the overview window
        self.df = dict()
        self.line = line
        self.song = song
        self.song_resample = song_resample
        self.sr = sr
        self.fax = full_ax
        self.num_view = num_view
        self.after_length = after_length
        self.view_data: Union[List['AxesGroup'], None] = None
        self.curr = 0  # current view selected
        self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
        self.offset = np.zeros((0, 2))
        self.curr_ind: List[Union[int, None]] = num_view * [None]  # Indices of click for each plot
        self.curr_vert = num_view * [0]  # Current vertical line of sig/spec for each plot
        self.cursor = None
        self.f_cursor = None
        self.reset_b = None
        self.r_button = None
        self.ind_b = None
        self.onaxis_b = None
        self.spec_b = None
        self.nfft = 128
        self.ind_select = False  # True while the individual number is being typed

    def _row(self, view=None):
        """Return the data line of the click shown in ``view`` (default: current view)."""
        if view is None:
            view = self.curr
        return self.df[self.offset[self.curr_ind[view], 0]]

    def shift_left(self, event):
        """Move the overview window 13 s to the left (not before the start)."""
        self.p = max(0, self.p - FSSR * 13)
        self._shift()

    def shift_right(self, event):
        """Move the overview window 13 s to the right (not past the end)."""
        self.p = min(len(self.song_resample) - FSSR * 20, self.p + FSSR * 13)
        self._shift()

    def _shift(self):
        """Redraw the 20 s overview window starting at sample ``self.p``."""
        segment = self.song_resample[self.p:self.p + FSSR * 20]
        self.line.set_ydata(segment)
        lim = np.abs(segment).max() * 1.2 or 1  # silent window: avoid a (0, 0) y-range
        self.fax.set_ylim(-lim, lim)
        self.line.set_xdata(np.linspace(self.p / FSSR, self.p / FSSR + 20, FSSR * 20, False))
        self.fax.set_xlim(self.p / FSSR, self.p / FSSR + 20)
        plt.draw()

    def _extract_click(self, mpos):
        """Return ``song`` from 10 ms before to ``after_length`` after resampled index ``mpos``.

        The result has exactly ``int((10e-3 + after_length) * sr)`` samples (the length of
        the signal plot). It is zero-padded at the end near the file edges and trimmed
        when rounding yields an extra sample.
        """
        n_samples = int((10e-3 + self.after_length) * self.sr)
        centre = mpos * self.sr / FSSR
        start = max(int(centre - 10e-3 * self.sr), 0)
        click = self.song[start:int(centre + self.after_length * self.sr)][:n_samples]
        if len(click) < n_samples:
            click = np.pad(click, (0, n_samples - len(click)), mode='constant')
        return click

    @staticmethod
    def _restore_cursors(ax_group, row):
        """Show the saved P1/P2/IPI cursors of ``row`` in ``ax_group`` and hide the unset ones."""
        if np.isnan(row['p1_pos']):
            ax_group.signal.set_visible(False)
            ax_group.spectrogram.set_visible(False)
        else:
            p1 = row['p1_pos']
            ax_group.signal.cursors[0].set_xdata((p1, p1))
            ax_group.spectrogram.cursors[0].set_xdata((p1 * 1e-3, p1 * 1e-3))
            ax_group.signal.cursors[0].set_visible(True)
            ax_group.spectrogram.cursors[0].set_visible(True)
            if np.isnan(row['ipi_sig']):
                ax_group.signal.cursors[1].set_visible(False)
                ax_group.spectrogram.cursors[1].set_visible(False)
            else:
                p2 = p1 + row['ipi_sig']
                ax_group.signal.cursors[1].set_xdata((p2, p2))
                ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3))
                ax_group.signal.cursors[1].set_visible(True)
                ax_group.spectrogram.cursors[1].set_visible(True)
        for cell, key in ((ax_group.correlation, 'ipi_corr_man'), (ax_group.cepstrum, 'ipi_ceps_man')):
            if np.isnan(row[key]):
                cell.set_visible(False)
            else:
                cell.cursors[0].set_xdata((row[key], row[key]))
                cell.set_visible(True)

    def _on_clicked_fax(self, event):
        """Select (or create) the click closest to a mouse click on the overview plot."""
        pos = int(FSSR * event.xdata)
        start = max(pos - int(FSSR * FSPK), 0)
        mpos = np.argmax(self.song_resample[start:pos + int(FSSR * FSPK)]) + start
        t_click = mpos / FSSR
        ax_group = self.view_data[self.curr]
        cursor = self.cursor[self.curr]

        if self.curr_ind[self.curr] is not None:
            # Restore the marker of the click previously shown in this view:
            # magenta when it holds no annotation, black otherwise.
            prev = self.curr_ind[self.curr]
            self.scat._facecolors[prev] = [0.7, 0.2, 0.5, 1] if is_empty_row(self._row()) else [0, 0, 0, 1]

        if t_click not in self.offset[:, 0]:
            # New click: register it with an empty data line.
            self.offset = np.concatenate([self.offset, [[t_click, self.song_resample[mpos]]]], axis=0)
            self.scat.set_offsets(self.offset)
            self.df[t_click] = EMLN.copy()
            c = [list(to_rgba(VIEW_COLORS[self.curr]))]
            if len(self.offset) == 1:
                self.scat.set_color(c)  # replaces the placeholder colour given at creation
            else:
                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
            self.curr_ind[self.curr] = len(self.offset) - 1
            self.curr_vert[self.curr] = 0
            cursor.linev[0].set_linestyle('--')
            cursor.linev[cursor.num_cur].set_linestyle('--')
            ax_group.set_visible(False)
        else:
            # Existing click: its marker takes the mean colour of every view showing it.
            self.curr_ind[self.curr] = np.argmax(t_click == self.offset[:, 0])
            shown_in = [i for i, v in enumerate(self.curr_ind) if v == self.curr_ind[self.curr]]
            self.scat._facecolors[self.curr_ind[self.curr]] = np.mean(
                [to_rgba(VIEW_COLORS[i]) for i in shown_in], axis=0)
            self.scat.set_color(self.scat._facecolors)
            self._restore_cursors(ax_group, self.df[t_click])

        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text(f'Current individual:\n{self.df[t_click]["ind_number"]}')

        click = self._extract_click(mpos)
        ax_group.signal.im.set_ydata(norm(click))
        self._update_spectrogram(ax_group)
        n_ipi = int(10e-3 * self.sr)
        # Autocorrelation for lags 0 .. 10 ms: in 'full' mode lag 0 sits at index len(click)-1.
        corr = np.correlate(click, click, 'full')[len(click) - 1:len(click) - 1 + n_ipi]
        ax_group.correlation.im.set_ydata(norm(corr))
        ceps = np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click)))))[:n_ipi]
        ax_group.cepstrum.im.set_ydata(norm_std(ceps))
        self._set_label()
        plt.draw()

    def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
        """Handle a click in cell ``cell_num`` (0 signal, 1 spectrogram, 2 correlation, 3 cepstrum)."""
        row = self._row(group_num)
        if cell_num < 2:
            # Signal x is in ms, spectrogram x in s: work in ms.
            pos = event.xdata if cell_num == 0 else event.xdata * 1000
            cur_type = self.curr_vert[group_num]  # 0: placing P1, 1: placing P2
            ax_group.signal.cursors[cur_type].set_xdata(2 * (pos,))
            ax_group.spectrogram.cursors[cur_type].set_xdata(2 * (pos * 10 ** -3,))
            ax_group.signal.cursors[cur_type].set_visible(True)
            ax_group.spectrogram.cursors[cur_type].set_visible(True)
            if cur_type == 0:
                row['p1_pos'] = pos
            # Alternate between placing P1 and P2; the hover cursor is dashed while placing P1.
            self.curr_vert[group_num] ^= 1
            cur_type ^= 1
            style = ['--', '-'][cur_type]
            cursor = self.cursor[group_num]
            cursor.linev[0].set_linestyle(style)
            cursor.linev[cursor.num_cur].set_linestyle(style)
            if ax_group.signal.cursors[1].get_visible():
                ipi_man = ax_group.signal_ipi
                row['ipi_sig'] = ipi_man
                ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}')
        else:
            cell = ax_group[cell_num]
            cell.cursors[0].set_xdata((event.xdata, event.xdata))
            cell.cursors[0].set_visible(True)
            # Automatic IPI: maximum of the displayed curve within +/- IPIPK ms of the pick.
            lim_min = max(int(self.sr / 1e3 * (event.xdata - IPIPK)), 0)
            lim_max = int(self.sr / 1e3 * (event.xdata + IPIPK))
            ipi_auto = np.argmax(cell.im.get_data()[1][lim_min:lim_max]) + lim_min
            ipi_auto_ms = ipi_auto * 1e3 / self.sr
            col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
            row[col + '_man'] = event.xdata
            row[col + '_auto'] = ipi_auto_ms
            cell.axes.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} '
                                 f'auto:{ipi_auto_ms:.3f}')
        plt.draw()

    def on_clicked(self, event):
        """Dispatch a mouse-button release to the overview or to the view that was clicked."""
        if event.inaxes is None:
            return
        if event.inaxes == self.fax:
            self._on_clicked_fax(event)
            return
        for i, ax_group in enumerate(self.view_data):  # Look if a click plot was clicked and which one
            for j in range(len(ax_group)):
                if event.inaxes == ax_group[j].axes:
                    if self.curr_ind[i] is None:
                        return  # no click displayed in this view yet: nothing to annotate
                    self._on_clicked_ax_group(event, ax_group, i, j)
                    return

    def change_curr(self, label):
        """Make view ``int(label[-1])`` current and recolour the overview cursor and reset button."""
        self.curr = int(label[-1])
        self.f_cursor.linev.set_color(VIEW_COLORS[self.curr])
        self.reset_b.label.set_c(VIEW_COLORS[self.curr])
        plt.draw()

    def toggle_ind(self, event):
        """Toggle typing of the individual number of the current click (clears it when enabled)."""
        if self.curr_ind[self.curr] is None:
            return
        self.ind_select = not self.ind_select
        self.ind_b.color = 'limegreen' if self.ind_select else '0.85'
        self.ind_b.hovercolor = 'lime' if self.ind_select else '0.95'
        row = self._row()
        if self.ind_select:
            row['ind_number'] = ''
        self.ind_b.label.set_text(f'Current individual:\n'
                                  f'{row["ind_number"]}')
        plt.draw()

    def key_pressed(self, event):
        """Edit the individual number while typing is enabled; otherwise digit keys switch view."""
        key = event.key
        if not key:
            return
        if self.ind_select:
            if self.curr_ind[self.curr] is None:
                return
            row = self._row()
            current = row['ind_number'] if isinstance(row['ind_number'], str) else ''
            if key == 'backspace':
                row['ind_number'] = current[:-1]
            elif len(key) == 1:
                row['ind_number'] = current + key
            else:
                return  # modifiers and special keys (shift, control, enter, arrows...) are ignored
            self.ind_b.label.set_text(f'Current individual:\n{row["ind_number"]}')
            plt.draw()
        elif len(key) == 1 and key in '0123456789' and int(key) < self.num_view:
            self.change_curr(key)
            self.r_button.set_active(int(key))

    def play(self, event):
        """Play the 20 s overview window (normalised, 16-bit, FSSR); Ctrl-C stops playback."""
        sound = (norm(self.song_resample[self.p:self.p + FSSR * 20]) * (2 ** 15 - 1)).astype(np.int16)
        try:
            play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
        except KeyboardInterrupt:
            pass

    def resize(self, event):
        """Re-run the constrained layout once (e.g. after the window was resized)."""
        self.fax.get_figure().set_constrained_layout(True)
        plt.draw()
        plt.pause(0.2)
        self.fax.get_figure().set_constrained_layout(False)

    def onaxis(self, event):
        """Toggle the on/off-axis flag of the current click (unknown, -1, becomes on-axis)."""
        if self.curr_ind[self.curr] is None:
            return
        row = self._row()
        if row['onaxis'] == -1:
            row['onaxis'] = 1
        else:
            row['onaxis'] = 1 - int(row['onaxis'])
        self.onaxis_b.label.set_text('On-axis' if row['onaxis'] else 'Off-axis')
        plt.draw()

    def increase_freq(self, event):
        """Raise the spectrogram upper limit of every view by 1 kHz (up to Nyquist)."""
        lim = min(self.view_data[0].spectrogram.axes.get_ylim()[1] + 1e3, self.sr / 2)
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def decrease_freq(self, event):
        """Lower the spectrogram upper limit of every view by 1 kHz (not below 1 kHz)."""
        lim = max(self.view_data[0].spectrogram.axes.get_ylim()[1] - 1e3, 1e3)
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def _update_spectrogram(self, ax_group):
        """Recompute the spectrogram of ``ax_group`` from its displayed signal with ``self.nfft``."""
        click = ax_group.signal.im.get_ydata()
        spec, _, t = plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft - 1,
                                       pad_to=2 * self.nfft)
        spec = np.flipud(20 * np.log10(spec))
        ax_group.spectrogram.im.set_data(spec)
        ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max())
        ax_group.spectrogram.im.set_extent([t[0] - 1 / self.sr / 2, t[-1] + 1 / self.sr / 2, 0., self.sr / 2])

    def increase_res(self, event):
        """Double the spectrogram FFT size (stops once it exceeds 10 ms of samples)."""
        limit = int(10e-3 * self.sr)
        if self.nfft > limit:
            return
        self.nfft *= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        if self.nfft > limit:
            self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
        else:
            self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def decrease_res(self, event):
        """Halve the spectrogram FFT size (stops below 8)."""
        if self.nfft < 8:
            return
        self.nfft //= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        if self.nfft < 8:
            self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
        else:
            self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def _set_label(self, ind=None, dic=None):
        """Write the values of ``dic`` (default: the click of view ``ind``) into the axis labels and buttons."""
        if ind is None:
            ind = self.curr
        if dic is None:
            dic = self._row(ind)
        ax_group = self.view_data[ind]
        ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
        ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
        ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
        self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
        if self.onaxis_b is not None:
            # onaxis: 0 off, 1 on, -1 unknown (indexes '?').
            self.onaxis_b.label.set_text(['Off', 'On', '?'][int(dic['onaxis'])] + '-axis')

    def _set_visible(self, ind=None, state=False):
        """Show or hide every cursor of view ``ind`` (default: current view)."""
        if ind is None:
            ind = self.curr
        self.view_data[ind].set_visible(state)

    def reset_curr(self, event):
        """Erase the annotations of the click shown in the current view."""
        if self.curr_ind[self.curr] is None:
            return
        self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
        self._set_label()
        self._set_visible()
        self.curr_vert[self.curr] = 0
        cursor = self.cursor[self.curr]
        cursor.linev[0].set_linestyle('--')
        cursor.linev[cursor.num_cur].set_linestyle('--')
        plt.draw()

    def reset(self, song, sr, song_resample):
        """Start over with a new signal: clear every click, annotation and view."""
        self.p = 0
        self.df = dict()
        self.song = song
        self.song_resample = song_resample
        self._shift()
        sr_update = False
        if self.sr != sr:
            self.sr = sr
            sr_update = True
        self.change_curr('0')  # reset current view to 0
        self.r_button.set_active(0)
        self.offset = np.zeros((0, 2))
        self.scat.set_offsets(self.offset)
        self.scat.set_color([[0, 0, 0, 1]])
        self.curr_ind = self.num_view * [None]  # Ind of click for each plot
        self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each plot
        n_sig = int((10e-3 + self.after_length) * sr)
        n_ipi = int(10e-3 * sr)
        for ax_group in self.view_data:
            ax_group.signal.im.set_ydata(np.zeros(n_sig))
            ax_group.correlation.im.set_ydata(np.zeros(n_ipi))
            ax_group.cepstrum.im.set_ydata(np.zeros(n_ipi))
            if sr_update:
                ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length) * 1e3, n_sig, False))
                ax_group.correlation.im.set_xdata(np.linspace(0, 10, n_ipi, False))
                ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, n_ipi, False))
            ax_group.spectrogram.im.set_clim(2000, 2100)
        for i in range(self.num_view):
            self._set_label(i, EMLN)
            self._set_visible(i)
        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text('Current individual:\nnan')
        plt.draw()


class AxesWithCursor:
    """An axes, its main artist (``im``) and its vertical cursor lines."""
    __slots__ = ['axes', 'im', 'cursors']

    def __init__(self, axes: plt.Axes, im=None, cursors=None):
        self.axes = axes
        self.im = im
        self.cursors = list() if cursors is None else cursors

    def set_visible(self, state):
        """Show or hide every cursor line."""
        for cursor in self.cursors:
            cursor.set_visible(state)

    @property
    def figure(self):
        """Figure of the wrapped axes (lets AxesWidget treat this object like an Axes)."""
        return self.axes.figure


class AxesGroup:
    """The four cells of one click view: signal (ms), spectrogram (s), correlation and cepstrum (ms).

    The signal and spectrogram each get one dashed P1 cursor followed by ``num_cur`` solid
    cursors; the correlation and cepstrum each get one cursor.
    """
    __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']

    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3):
        n_sig = int((10e-3 + after_length) * sr)
        n_ipi = int(10e-3 * sr)
        self.signal = AxesWithCursor(
            ax_sig,
            ax_sig.plot(np.linspace(0, 10 + after_length * 1e3, n_sig, False), np.zeros(n_sig))[0],
            (ax_sig.axvline(10, c='k', linestyle='--'),) + tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur)))
        ax_sig.set_xlim(0, 10 + after_length * 1e3)
        ax_sig.set_xlabel('IPI man:None')
        ax_sig.set_ylim(-1, 1)
        ax_sig.set_yticks([])
        self.spectrogram = AxesWithCursor(
            ax_spec,
            ax_spec.specgram(np.random.normal(0, 1e-6, n_sig), Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1],
            (ax_spec.axvline(0.01, c='k', linestyle='--'),) + tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur)))
        self.spectrogram.im.set_clim(2000, 2100)
        ax_spec.set_ylim(0, min(20e3, sr / 2))
        ax_spec.set_yticks(ax_spec.get_yticks())  # Needed, otherwise updating ylim doesn't update ticks properly
        ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))
        self.correlation = AxesWithCursor(
            ax_corr, ax_corr.plot(np.linspace(0, 10, n_ipi, False), np.zeros(n_ipi))[0],
            (ax_corr.axvline(10, c='k'),))
        ax_corr.set_xlim(0, 10)
        ax_corr.set_xlabel('IPI man:None auto:None')
        ax_corr.set_ylim(-1, 1)
        ax_corr.set_yticks([])
        self.cepstrum = AxesWithCursor(
            ax_cep, ax_cep.plot(np.linspace(0, 10, n_ipi, False), np.zeros(n_ipi))[0],
            (ax_cep.axvline(10, c='k'),))
        ax_cep.set_xlim(0, 10)
        ax_cep.set_xlabel('IPI man:None auto:None')
        ax_cep.set_ylim(0, 0.33)
        ax_cep.set_yticks([])
        self.set_visible(False)

    def __getitem__(self, item):
        return (self.signal, self.spectrogram, self.correlation, self.cepstrum)[item]

    def __len__(self):
        return 4

    def __contains__(self, item):
        """Accept either an AxesWithCursor or a matplotlib Axes."""
        if isinstance(item, AxesWithCursor):
            return any(item == ax for ax in self)
        return any(item == ax.axes for ax in self)

    def set_visible(self, state):
        """Show or hide the cursors of all four cells."""
        self.signal.set_visible(state)
        self.spectrogram.set_visible(state)
        self.correlation.set_visible(state)
        self.cepstrum.set_visible(state)

    @property
    def signal_ipi(self):
        """Distance (ms) between the first P2 cursor and the P1 cursor of the signal cell."""
        return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]


def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3):
    """Load ``in_path`` and build the annotation figure.

    Parameters
    ----------
    in_path : str
        Audio file.
    channel : int
        0-based channel.
    low, high : float
        Band-pass edges in Hz.
    num_cur : int
        Number of P2 cursors per signal/spectrogram cell.
    num_graph : int
        Number of click views.
    after_length : float
        Seconds shown after P1.

    Returns
    -------
    dict
        ``callback``, ``fig``, ``gridspec`` and ``buttons`` (widgets and connection ids,
        kept referenced so they stay alive).
    """
    song, sr, song_resample = load_file(in_path, channel, low, high)

    fig = plt.figure('IPI of ' + os.path.basename(in_path), figsize=[16, 9], constrained_layout=True)
    gs = fig.add_gridspec(12, 20)

    full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
    callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                        song, song_resample, sr, full_sig, num_graph, after_length)
    full_sig.set_xlim(0, 20)
    lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2 or 1  # silent window: avoid a (0, 0) y-range
    full_sig.set_ylim(-lim, lim)
    full_sig.set_yticks([])
    callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
    cid1 = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)

    b_left_ax = plt.subplot(gs[:2, 0])
    b_right_ax = plt.subplot(gs[:2, -1])
    b_left = Button(b_left_ax, '<|')
    b_right = Button(b_right_ax, '|>')
    b_left.on_clicked(callback.shift_left)
    b_right.on_clicked(callback.shift_right)

    r_button_ax = plt.subplot(gs[10, 1:-1])
    # The last character of each label is used as the view index in the rest of the code.
    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
                              size=128, activecolor=list(VIEW_COLORS[:num_graph]))
    r_button_ax.axis('off')
    r_button.on_clicked(callback.change_curr)
    for i, c in enumerate(VIEW_COLORS[:num_graph]):
        r_button.labels[i].set_c(c)
    callback.r_button = r_button

    # Disable matplotlib's default keyboard shortcuts so typed keys only reach key_pressed.
    # Some keys (e.g. 'all_axes') do not exist in every matplotlib version.
    for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
        if f'keymap.{v}' in plt.rcParams:
            plt.rcParams[f'keymap.{v}'] = []
    cid2 = fig.canvas.mpl_connect('key_press_event', callback.key_pressed)

    vfs = 4  # rows per cell
    vs = 2  # first row of the views
    hs = 1  # first column of the views
    hfs = (20 - hs) // (2 * num_graph)  # columns per cell
    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * i * hfs:hs + (2 * i + 1) * hfs]),
                          plt.subplot(gs[vs:vs + vfs, hs + (2 * i + 1) * hfs:hs + (2 * i + 2) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * i * hfs:hs + (2 * i + 1) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2 * i + 1) * hfs:hs + (2 * i + 2) * hfs]),
                          sr, num_cur, after_length) for i in range(num_graph)]

    play_b_ax = plt.subplot(gs[-1, :2])
    play_b = Button(play_b_ax, 'Play\ncurrent segment')
    play_b.on_clicked(callback.play)

    resize_b_ax = plt.subplot(gs[-1, 2:4])
    resize_b = Button(resize_b_ax, 'Resize plot')
    resize_b.on_clicked(callback.resize)

    reset_b_ax = plt.subplot(gs[-1, 4:6])
    reset_b = Button(reset_b_ax, 'Reset current')
    reset_b.label.set_c('r')
    reset_b.on_clicked(callback.reset_curr)
    callback.reset_b = reset_b

    ind_b_ax = plt.subplot(gs[-1, 6:8])
    ind_b = Button(ind_b_ax, 'Current individual:\nnan')
    ind_b.on_clicked(callback.toggle_ind)
    callback.ind_b = ind_b

    freq_p_b_ax = plt.subplot(gs[2, 0])
    freq_m_b_ax = plt.subplot(gs[3, 0])
    freq_res_p_b_ax = plt.subplot(gs[4, 0])
    freq_res_m_b_ax = plt.subplot(gs[5, 0])
    freq_p_b = Button(freq_p_b_ax, '+\n1kHz')
    freq_p_b.on_clicked(callback.increase_freq)
    freq_m_b = Button(freq_m_b_ax, '-\n1kHz')
    freq_m_b.on_clicked(callback.decrease_freq)
    freq_res_p_b = Button(freq_res_p_b_ax, '256\nbins')
    freq_res_p_b.on_clicked(callback.increase_res)
    freq_res_m_b = Button(freq_res_m_b_ax, '64\nbins')
    freq_res_m_b.on_clicked(callback.decrease_res)
    spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
    callback.spec_b = spec_button

    m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
    callback.cursor = m_cursor
    for i, group in enumerate(ax_group):
        m_cursor[i] = MyMultiCursor(group, group.signal.cursors[0], num_cur, useblit=True, c='k')
    callback.view_data = ax_group
    return {'callback': callback, 'fig': fig, 'gridspec': gs,
            'buttons': {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b,
                        'r_button': r_button, 'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2,
                        'reset_b': reset_b, 'spec_button': spec_button}}


def reset(callback, in_path, channel, low=2e3, high=20e3):
    """Load a new file into an existing figure through ``callback.reset``."""
    song, sr, song_resample = load_file(in_path, channel, low, high)
    callback.reset(song, sr, song_resample)


def main(args):
    """Run the annotation GUI and save the annotations when the window is closed.

    Returns
    -------
    int
        0 on success, 1 when the arguments are inconsistent.
    """
    if args.out == '':
        outpath = os.path.splitext(args.input)[0] + '.pred.h5'
    else:
        outpath = args.out
    if os.path.isfile(outpath):
        if not (args.erase or args.resume):
            print(f'Out file {outpath} already exist and erase or resume option isn\'t set.')
            return 1
    elif args.resume:
        print(f'Out file {outpath} does not already exist and resume option is set.')
        return 1
    if args.pulse < 1:
        print(f'{args.pulse} is an invalid number of pulses.')
        return 1
    elif args.click < 1:
        print(f'{args.click} is an invalid number of clicks.')
        return 1
    EMLN['onaxis'] = -1  # -1: on/off-axis not decided yet
    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after * 1e-3)
    callback = ref_dict['callback']
    if args.resume:
        df = pd.read_hdf(outpath)
        if 'onaxis' not in df.columns:
            df['onaxis'] = -1
        callback.df = df.to_dict(orient='index')
        times = np.array(list(callback.df.keys()), dtype=float)
        callback.offset = np.tile(times[:, np.newaxis], (1, 2))
        # Amplitude of each click in the resampled signal (rounded to the nearest sample).
        callback.offset[:, 1] = callback.song_resample[np.rint(times * FSSR).astype(int)]
        callback.scat.set_offsets(callback.offset)
        colors = np.tile([0.7, 0.2, 0.5, 1], (len(callback.offset), 1))
        for i, p in enumerate(callback.offset[:, 0]):
            if not is_empty_row(callback.df[p]):
                colors[i] = [0, 0, 0, 1]
        callback.scat.set_color(colors)
    onaxis_b_ax = plt.subplot(ref_dict['gridspec'][-1, 8:10])
    onaxis_b = Button(onaxis_b_ax, '?-axis')
    onaxis_b.on_clicked(callback.onaxis)
    callback.onaxis_b = onaxis_b

    plt.draw()
    plt.pause(0.2)
    ref_dict['fig'].set_constrained_layout(False)
    plt.show()
    df = pd.DataFrame.from_dict(callback.df, orient='index')
    df.to_hdf(outpath, key='df', format='table')
    return 0


if __name__ == '__main__':
    class ArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that raises ValueError on missing required arguments (to offer a prompt)."""

        def error(self, message):
            if message.startswith('the following arguments are required:'):
                raise ValueError(message)
            super(ArgumentParser, self).error(message)

    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("input", type=str, help="Input file")
    parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
    parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
    parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be at least 1.")
    parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
    parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.")
    parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
    parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                            " the computation will be halted")
    group.add_argument("--resume", action='store_true', help="If out file exist and this option is given,"
                                                             " the previous annotation file will be loaded")
    try:
        args = parser.parse_args()
    except ValueError as e:
        print(f'Error while parsing the command line arguments: {e}')

        def ask(string):
            """Ask a yes/no question until a valid answer (English or French) is given."""
            y = {'y', 'yes', 'o', 'oui'}
            a = {'y', 'yes', 'o', 'oui', 'n', 'no', 'non'}
            while True:
                ans = input(string + ' [y/n] ').strip().lower()
                if ans in a:
                    return ans in y

        if not ask('Do you want to manually specify them?'):
            sys.exit(2)  # exit code of invalid argparse

        class VirtualArgParse(object):
            """Interactive stand-in for the parsed command line arguments."""

            def __init__(self):
                self.input = input("What is the input file path? ")
                self.out = input("What is the out file path? (Leave empty for default)")
                self.channel = int(input("Which channel do you want to use starting from 0? "))
                self.pulse = int(input("How many pulses do you want to display (>= 1) "))
                self.click = int(input("How many clicks do you want to display (>= 1) "))
                self.after = float(input("What duration do you want after P1 (in ms)? "))
                self.low = float(input("What is the lower frequency of the bandpass? "))
                self.up = float(input("What is the upper frequency of the bandpass? "))
                self.erase = ask("Do you want to erase the out file if it already exist?")
                if not self.erase:
                    self.resume = ask("Do you want to resume the out file if it already exist?")
                else:
                    self.resume = False

        args = VirtualArgParse()
    sys.exit(main(args))