Skip to content
Snippets Groups Projects
Select Git revision
  • 3c3acb33e9df3fd5b68f039f0d6efb2a790cb152
  • master default protected
  • loss
  • producer
4 results

gettingStarted.md

Blame
  • ipi_extract.py 37.21 KiB
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import to_rgba
    import scipy.signal as sg
    import soundfile as sf
    import os
    import sys
    from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget
    from fractions import Fraction
    from pydub import AudioSegment
    from pydub.playback import play
    import pandas as pd
    from typing import List, Union
    
    FSSR = 48_000  # Sampling rate (Hz) of the resampled signal used by the full-signal plot and by playback
    FSPK = 0.1     # Half-width (s) of the window searched for the click peak around a click on the full-signal plot
    IPIPK = 0.15   # Half-width (ms) of the window searched for the automatic IPI peak around a manual pick
    SPSC = 80      # Spectrogram colour range, in dB below the maximum
    EMLN = {'p1_pos': np.nan, 'ipi_sig': np.nan,
            'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan,
            'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan,
            'ind_number': np.nan}  # Empty data line: all-NaN template copied for every newly picked click
    
    
    def read(file_path, always_2d=True):
        """Read an audio file and return ``(samples, sample_rate)``.

        soundfile is tried first. If it raises anything, decoding falls back to
        pydub through ``load_anysound``, which ignores ``always_2d`` and always
        returns a 2-D ``(frames, channels)`` array.
        """
        try:
            decoded = sf.read(file_path, always_2d=always_2d)
        except Exception:
            # Deliberately broad: any soundfile failure means "try the other decoder".
            decoded = load_anysound(file_path)
        return decoded
    
    
    def load_anysound(file_path):
        """Decode *file_path* with pydub and return ``(samples, frame_rate)``.

        The interleaved sample values reported by pydub are reshaped to
        ``(frames, channels)``.
        """
        segment = AudioSegment.from_file(file_path)
        interleaved = np.array(segment.get_array_of_samples())
        frames = interleaved.reshape(-1, segment.channels)
        return frames, segment.frame_rate
    
    
    def load_file(in_path, channel, low, high):
        """Load one channel of an audio file, band-pass it and build the plotting copy.

        Parameters
        ----------
        in_path : str
            Path of the audio file (anything ``read`` can decode).
        channel : int
            Index of the channel to keep.
        low, high : float
            Band-pass corner frequencies in Hz (3rd-order Butterworth, applied
            forward and backward with ``sosfiltfilt``).

        Returns
        -------
        song : ndarray
            Filtered channel at the original rate, zero-padded at the end to at
            least 20 s (the width of the full-signal view).
        sr : int
            Original sampling rate.
        song_resample : ndarray
            ``song`` resampled to ``FSSR`` for the full-signal plot.
        """
        print(f'Loading and processing {in_path}')
        song, sr = read(in_path, always_2d=True)
        song = song[:, channel]
        sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
        song = sg.sosfiltfilt(sos, song)
        min_len = 20 * sr  # the full-signal view always displays a 20 s window
        if len(song) < min_len:
            # Pad only after the signal, and only up to 20 s: padding before it would
            # shift every click timestamp by one sample.
            song = np.pad(song, (0, min_len - len(song)), mode='constant')
        frac = Fraction(FSSR, sr)
        song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
        print('Done processing')
        return song, sr, song_resample
    
    
    def norm(x):
        """Scale *x* so its largest absolute value becomes (almost exactly) 1.

        The tiny epsilon in the denominator keeps all-zero input finite (it stays zero).
        """
        peak = np.max(np.abs(x))
        return x / (peak + 1e-10)
    
    
    def norm_std(x, alpha=1.5):
        """Scale *x* by ``alpha`` times its standard deviation.

        Parameters
        ----------
        x : array_like
            Values to scale.
        alpha : float, optional
            Multiple of the standard deviation used as the unit (default 1.5).

        Returns
        -------
        ndarray
            ``x / (alpha * std(x))``. A small epsilon keeps constant input finite.
        """
        return x / (alpha * np.std(x) + 1e-10)
    
    
    class MyRadioButtons(RadioButtons):
        """``RadioButtons`` drawn as legend scatter markers, with size, orientation and per-button colours.

        ``RadioButtons.__init__`` is intentionally not called. The attributes the
        parent's other methods read (``value_selected``, ``labels``, ``circles``,
        ``observers``, ``cnt``) are set up here instead.

        NOTE(review): this mirrors the internals of an older matplotlib
        ``RadioButtons``/``Legend`` (``observers``/``cnt`` callback bookkeeping,
        ``Legend.legendHandles``). Recent matplotlib releases changed these, so
        confirm against the pinned matplotlib version.
        """

        def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                     orientation="vertical", **kwargs):
            """
            Add radio buttons to an `~.axes.Axes`.
            Parameters
            ----------
            ax : `~matplotlib.axes.Axes`
                The axes to add the buttons to.
            labels : list of str
                The button labels.
            active : int
                The index of the initially selected button.
            activecolor : color or list of colors
                The color of the selected button. A list gives one color per
                button, looked up with the last character of the selected label
                (labels must then end with their index digit).
            size : float
                Size of the radio buttons
            orientation : str
                The orientation of the buttons: 'vertical' (default), or 'horizontal'.
            Further parameters are passed on to `Legend`.
            """
            # Deliberately skip RadioButtons.__init__: the buttons are built here as legend markers.
            AxesWidget.__init__(self, ax)
            self._activecolor = activecolor
            axcolor = ax.get_facecolor()
            self.value_selected = None

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_navigate(False)
            circles = []
            for i, label in enumerate(labels):
                if i == active:
                    # value_selected must be set before reading self.activecolor, which may index with it.
                    self.value_selected = label
                    facecolor = self.activecolor
                else:
                    facecolor = axcolor
                p = ax.scatter([],[], s=size, marker="o", edgecolor='black',
                               facecolor=facecolor)
                circles.append(p)
            if orientation == "horizontal":
                kwargs.update(ncol=len(labels), mode="expand")
            kwargs.setdefault("frameon", False)
            self.box = ax.legend(circles, labels, loc="center", **kwargs)
            self.labels = self.box.texts
            self.circles = self.box.legendHandles
            for c in self.circles:
                c.set_picker(5)
            # Callback bookkeeping in the form expected by the parent's on_clicked/set_active
            # (older matplotlib API, see the class docstring).
            self.cnt = 0
            self.observers = {}
            self.connect_event('pick_event', self._clicked)

        def _clicked(self, event):
            """Select the button whose marker was picked with the left mouse button."""
            if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
                return
            if event.artist in self.circles:
                self.set_active(self.circles.index(event.artist))

        @property
        def activecolor(self):
            """Colour used for the selected button.

            A single colour (name, hex string, RGB(A) tuple) is returned as-is. A
            sequence of colours is indexed with the last character of the
            selected label, e.g. 'Change 1' -> element 1.
            """
            try:
                # Strings and RGB tuples also have __getitem__, so test for "one colour"
                # explicitly: indexing 'blue' would yield a single letter, not a colour.
                to_rgba(self._activecolor)
            except (ValueError, TypeError):
                return self._activecolor[int(self.value_selected[-1])]
            return self._activecolor
    
    
    class MyMultiCursor(AxesWidget):
        """Preview cursor lines shared by every axis of one group of click plots.

        While the mouse moves over one of the group's axes, vertical lines show
        where the next annotation would land on all axes at once. Positions are
        in ms on the first axis and in s on the second (``pos / 1e3``). There are
        ``num_cur`` lines on each of those two axes, spaced by the IPI implied by
        the mouse position relative to ``p1``. The remaining axes get one line
        each, at that IPI.
        """

        def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False,
                     **lineprops):
            """
            Parameters
            ----------
            axes : sequence of AxesWithCursor
                Presumably an ``AxesGroup`` (signal, spectrogram, correlation,
                cepstrum). Every solid line in ``ax.cursors`` gets one preview
                line. ``axes[0]`` is handed to ``AxesWidget`` (it exposes
                ``figure``).
            p1 : artist with ``get_xdata``/``get_visible``
                Marker of the first pulse; its x position is the IPI reference.
            num_cur : int
                Number of preview lines on the first two axes.
            vertOn : bool
                Whether the vertical lines are shown.
            useblit : bool
                Use blitting when the canvas supports it.
            **lineprops
                Passed to ``axvline`` for every preview line.
            """
            AxesWidget.__init__(self, axes[0])
            self.axes = axes
            self.p1 = p1
            self.num_cur = num_cur
            self.connect_event('motion_notify_event', self.onmove)
            self.connect_event('draw_event', self.clear)

            self.visible = True
            self.vertOn = vertOn
            self.useblit = useblit and self.canvas.supports_blit

            if self.useblit:
                lineprops['animated'] = True
            # One preview line per solid cursor of each axis (dashed cursors are skipped).
            self.linev = []
            for ax in self.axes:
                for line in ax.cursors:
                    if line.get_linestyle() == '--':
                        continue
                    self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
            # First preview line of the first and of the second axis start dashed
            # (assuming the first axis contributes num_cur lines).
            self.linev[0].set_linestyle('--')
            self.linev[self.num_cur].set_linestyle('--')
            self.n_axes = len(self.linev)  # number of preview lines, not of axes

            self.background = len(self.axes)*[None]
            self.needclear = False

        def clear(self, event):
            """Internal event handler to clear the cursor.

            Runs on every draw: stores each axis' background for blitting and
            hides all preview lines.
            """
            if self.ignore(event):
                return
            if self.useblit:
                for i, ax in enumerate(self.axes):
                    self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
            # linev can hold more lines than there are axes (num_cur lines on the first
            # two axes), so hide every line rather than one per axis.
            for line in self.linev:
                line.set_visible(False)

        def onmove(self, event):
            """Internal event handler to draw the cursor when the mouse moves."""
            if self.ignore(event):
                return
            if not self.canvas.widgetlock.available(self):
                return
            if event.inaxes not in self.axes:
                for i in range(self.n_axes):
                    self.linev[i].set_visible(False)
                if self.needclear:
                    self.canvas.draw()
                    self.needclear = False
                return
            self.needclear = True
            if not self.visible:
                return
            for ax_idx, ax in enumerate(self.axes):
                if ax.axes == event.inaxes:
                    break
            if ax_idx < 2:
                # Mouse on the first two axes: convert to ms.
                if ax_idx == 1:
                    pos = event.xdata*1e3
                else:
                    pos = event.xdata
                if self.linev[0].get_linestyle() == '--':
                    # First pulse not placed yet: only preview its position.
                    self.linev[0].set_xdata((pos, pos))
                    self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3))
                    self.linev[0].set_visible(self.visible and self.vertOn)
                    self.linev[self.num_cur].set_visible(self.visible and self.vertOn)
                    self._update()
                    return
                else:
                    # Mouse gives the second pulse: IPI is its distance to p1.
                    ipi_pos = abs(pos - self.p1.get_xdata()[0])
                    for i in range(self.n_axes):
                        if i < self.num_cur:
                            self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                        elif i < 2*self.num_cur:
                            self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                        else:
                            self.linev[i].set_xdata((ipi_pos, ipi_pos))
                        self.linev[i].set_visible(self.visible and self.vertOn)
                    self._update()
                    return
            else:
                # Mouse on a lag axis: the x position is the IPI itself.
                ipi_pos = event.xdata
                pos = self.p1.get_xdata()[0] + ipi_pos
                restore = False
            if self.linev[0].get_linestyle() == '--':
                # Temporarily draw the first lines solid; their style is restored below.
                restore = True
                self.linev[0].set_linestyle('-')
                self.linev[self.num_cur].set_linestyle('-')
            for i in range(self.n_axes):
                if self.p1.get_visible():
                    if i < self.num_cur:
                        self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                    elif i < 2 * self.num_cur:
                        self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                if i >= 2 * self.num_cur:
                    self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
            self._update()

            if restore:
                self.linev[0].set_linestyle('--')
                self.linev[self.num_cur].set_linestyle('--')

        def _update(self):
            """Redraw the preview lines, blitting axis by axis when enabled."""
            if self.useblit:
                k = 0
                for bk, ax in zip(self.background, self.axes):
                    if bk is not None:
                        self.canvas.restore_region(bk)
                    # The first two axes own num_cur consecutive lines, the others one each.
                    for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1):
                        ax.axes.draw_artist(self.linev[k])
                    k += 1
                    self.canvas.blit(ax.axes.bbox)
            else:
                self.canvas.draw_idle()
            return False
    
    
    class Callback(object):
        """Interactive state of the IPI tool and the handlers wired to its widgets and events.

        Each click picked on the full-signal plot is keyed by its time in
        seconds (``mpos / FSSR``) in ``self.df``. The value is a copy of
        ``EMLN`` holding that click's measurements. ``self.offset`` holds the
        matching scatter-marker positions (time, amplitude), one row per click.

        Each of the ``num_view`` views (``self.view_data``, one ``AxesGroup``
        each) shows at most one click. ``self.curr`` is the view being edited,
        and ``self.curr_ind[v]`` is the row of ``self.offset`` shown in view
        ``v`` (``None`` while the view is empty).
        """
        def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length):
            self.p = 0  # first sample (in song_resample) of the 20 s full-signal window
            self.df = dict()  # click time (s) -> measurement row (copy of EMLN)
            self.line = line  # line artist of the full-signal plot
            self.song = song  # filtered signal at the original rate
            self.song_resample = song_resample  # same signal resampled to FSSR
            self.sr = sr
            self.fax = full_ax
            self.num_view = num_view
            self.after_length = after_length  # seconds of signal shown after each click
            self.view_data: Union[List['AxesGroup'], None] = None  # one AxesGroup per view, assigned after construction
            self.curr = 0  # current view selected
            self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
            self.offset = np.zeros((0, 2))  # (time in s, amplitude) of every picked click
            self.curr_ind: List[Union[int, None]] = num_view*[None]  # Indices of click for each plot
            self.curr_vert = num_view*[0]    # Current vertical line of sig/spec for each  plot
            # Widgets, assigned after construction.
            self.cursor = None
            self.f_cursor = None
            self.reset_b = None
            self.r_button = None
            self.ind_b = None
            self.onaxis_b = None
            self.spec_b = None
            self.nfft = 128  # FFT length of the click spectrograms
            self.ind_select = False  # True while keystrokes are typed into the individual number

        def shift_left(self, event):
            """Move the full-signal window 13 s back, clamped at the start."""
            self.p = max(0, self.p - FSSR*13)
            self._shit()

        def shift_right(self, event):
            """Move the full-signal window 13 s forward, clamped so a full 20 s stays visible."""
            self.p = min(len(self.song_resample) - FSSR*20, self.p + FSSR*13)
            self._shit()

        def _shit(self):
            """Redraw the full-signal plot for the 20 s window starting at sample ``self.p``."""
            self.line.set_ydata(self.song_resample[self.p:self.p+FSSR*20])
            lim = np.abs(self.song_resample[self.p:self.p+FSSR*20]).max()*1.2
            self.fax.set_ylim(-lim, lim)
            self.line.set_xdata(np.linspace(self.p/FSSR, self.p/FSSR+20, FSSR*20, False))
            self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
            plt.draw()

        def _on_clicked_fax(self, event):
            """Handle a click on the full-signal plot.

            Snaps to the largest sample within +/- FSPK s of the clicked time.
            A new click gets a marker and an empty data line; a known click has
            its stored annotations reloaded. The click, its spectrogram,
            autocorrelation and cepstrum are then shown in the current view.
            """
            # Click on full signal plot
            pos = int(FSSR * event.xdata)
            mpos = np.argmax(self.song_resample[max(pos - int(FSSR * FSPK), 0):
                                                pos + int(FSSR * FSPK)]) + max(pos - int(FSSR * FSPK), 0)
            ax_group = self.view_data[self.curr]
            if self.curr_ind[self.curr] is not None:
                # Restore the marker of the click previously shown in this view:
                # purple when nothing was measured on it, black otherwise.
                if np.all(pd.isnull(list(self.df[self.offset[self.curr_ind[self.curr], 0]].values()))):
                    self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1]
                else:
                    self.scat._facecolors[self.curr_ind[self.curr]] = [0, 0, 0, 1]
            if mpos / FSSR not in self.offset[:, 0]:
                # New click: add a marker tinted for the current view and an empty data line.
                self.offset = np.concatenate([self.offset, [[mpos / FSSR, self.song_resample[mpos]]]], axis=0)
                self.scat.set_offsets(self.offset)
                self.df[mpos / FSSR] = EMLN.copy()
                c = [[0, 0, 0, 1]]
                c[0][self.curr] = 1
                if len(self.offset) == 1:
                    self.scat.set_color(c)
                else:
                    self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
                self.curr_ind[self.curr] = len(self.offset) - 1
                self.curr_vert[self.curr] = 0
                self.cursor[self.curr].linev[0].set_linestyle('--')
                self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
                ax_group.set_visible(False)
            else:
                # Known click: its marker colour is the mean colour of every view showing it.
                self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0])
                c = np.array([0, 0, 0, 0.])
                k = 0
                for i, v in enumerate(self.curr_ind):
                    if v == self.curr_ind[self.curr]:
                        c += to_rgba('rgbkcmyrgbkcmy'[i])
                        k += 1
                c /= k  # average once; k >= 1 because the current view always matches itself
                self.scat._facecolors[self.curr_ind[self.curr]] = c
                self.scat.set_color(self.scat._facecolors)
                row = self.df[mpos / FSSR]
                if np.isnan(row['p1_pos']):
                    ax_group.signal.set_visible(False)
                    ax_group.spectrogram.set_visible(False)
                else:
                    ax_group.signal.cursors[0].set_xdata((row['p1_pos'], row['p1_pos']))
                    ax_group.spectrogram.cursors[0].set_xdata((row['p1_pos'] * 1e-3, row['p1_pos'] * 1e-3))
                    ax_group.signal.cursors[0].set_visible(True)
                    ax_group.spectrogram.cursors[0].set_visible(True)
                    if np.isnan(row['ipi_sig']):
                        ax_group.signal.cursors[1].set_visible(False)
                        ax_group.spectrogram.cursors[1].set_visible(False)
                    else:
                        p2 = row['p1_pos'] + row['ipi_sig']
                        ax_group.signal.cursors[1].set_xdata((p2, p2))
                        ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3))
                        ax_group.signal.cursors[1].set_visible(True)
                        ax_group.spectrogram.cursors[1].set_visible(True)
                if np.isnan(row['ipi_corr_man']):
                    ax_group.correlation.set_visible(False)
                else:
                    ax_group.correlation.cursors[0].set_xdata((row['ipi_corr_man'], row['ipi_corr_man']))
                    ax_group.correlation.set_visible(True)
                if np.isnan(row['ipi_ceps_man']):
                    ax_group.cepstrum.set_visible(False)
                else:
                    ax_group.cepstrum.cursors[0].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man']))
                    ax_group.cepstrum.set_visible(True)
            self.ind_select = False
            self.ind_b.color = '0.85'
            self.ind_b.hovercolor = '0.95'
            self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos / FSSR]["ind_number"]}')
            # Extract the click at the original rate: 10 ms before the peak, after_length s after it.
            n_click = int((10e-3 + self.after_length) * self.sr)  # length of the signal line
            click = self.song[max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):
                              int(mpos * self.sr / FSSR + self.after_length * self.sr)]
            if len(click) < n_click:
                # Near the file edges the slice is short; zero-pad it (np.pad returns a new array).
                click = np.pad(click, (0, n_click - len(click)), mode='constant')
            click = click[:n_click]  # rounding of the slice bounds can yield one extra sample
            ax_group.signal.im.set_ydata(norm(click))
            self._update_spectrogram(ax_group)
            # Autocorrelation for lags 0 .. 10 ms: in 'full' mode lag 0 sits at index len(click) - 1.
            n_lag = int(10e-3 * self.sr)
            corr = np.correlate(click, click, 'full')[len(click) - 1:len(click) - 1 + n_lag]
            ax_group.correlation.im.set_ydata(norm(corr))
            ax_group.cepstrum.im.set_ydata(
                norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:n_lag])))
            self._set_label()
            plt.draw()

        def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
            """Record a manual annotation from a click in view ``group_num``.

            For cell 0 (signal, ms) and cell 1 (spectrogram, s), successive clicks
            alternately place the first pulse (stored as ``p1_pos``) and the
            second pulse. Once the second is shown, the signal IPI is stored.

            For cell 2 (correlation) and cell 3 (cepstrum), the clicked lag is
            stored as the manual IPI. The largest value within +/- IPIPK ms of it
            is stored as the automatic IPI.
            """
            if self.curr_ind[group_num] is None:
                # No click is shown in this view yet: there is no data line to annotate.
                return
            if cell_num < 2:
                pos = event.xdata if cell_num == 0 else event.xdata * 1000
                cur_type = self.curr_vert[group_num]
                ax_group.signal.cursors[cur_type].set_xdata(2*(pos,))
                ax_group.spectrogram.cursors[cur_type].set_xdata(2*(pos*10**-3,))
                ax_group.signal.cursors[cur_type].set_visible(True)
                ax_group.spectrogram.cursors[cur_type].set_visible(True)
                if cur_type == 0:
                    self.df[self.offset[self.curr_ind[group_num], 0]]['p1_pos'] = pos
                self.curr_vert[group_num] ^= 1
                cur_type ^= 1
                # The preview cursor is dashed while it stands for the first pulse.
                self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type])
                self.cursor[group_num].linev[self.cursor[group_num].num_cur].set_linestyle(['--', '-'][cur_type])
                if ax_group.signal.cursors[1].get_visible():
                    ipi_man = ax_group.signal_ipi
                    self.df[self.offset[self.curr_ind[group_num], 0]]['ipi_sig'] = ipi_man
                    ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}')
            else:
                cell = ax_group[cell_num]
                cell.cursors[0].set_xdata((event.xdata, event.xdata))
                cell.cursors[0].set_visible(True)
                # Search window in samples: xdata and IPIPK are in ms.
                lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0)
                ipi_auto = np.argmax(cell.im.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min
                col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
                self.df[self.offset[self.curr_ind[group_num], 0]][col + '_man'] = event.xdata
                self.df[self.offset[self.curr_ind[group_num], 0]][col + '_auto'] = ipi_auto*1e3/self.sr
                cell.axes.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} '
                                f'auto:{ipi_auto*1e3/self.sr:.3f}')
            plt.draw()

        def on_clicked(self, event):
            """Dispatch a mouse release to the full-signal plot or to the clicked view panel."""
            if event.inaxes == self.fax:
                self._on_clicked_fax(event)
            for i, ax_group in enumerate(self.view_data):  # Look if a click plot was clicked and which one
                for j in range(len(ax_group)):
                    if event.inaxes == ax_group[j].axes:
                        self._on_clicked_ax_group(event, ax_group, i, j)
                        return

        def change_curr(self, label):
            """Make view ``int(label[-1])`` the current one and recolour the related widgets."""
            self.curr = int(label[-1])
            self.f_cursor.linev.set_color('rgbkcmyrgbkcmy'[self.curr])
            self.reset_b.label.set_c('rgbkcmyrgbkcmy'[self.curr])
            plt.draw()

        def toggle_ind(self, event):
            """Start (clearing the field) or stop typing the individual number of the current click."""
            if not len(self.offset) or self.curr_ind[self.curr] is None:
                return  # no click shown in the current view
            self.ind_select = not self.ind_select
            self.ind_b.color = 'limegreen' if self.ind_select else '0.85'
            self.ind_b.hovercolor = 'lime' if self.ind_select else '0.95'
            if self.ind_select:
                self.df[self.offset[self.curr_ind[self.curr], 0]]['ind_number'] = ''
            self.ind_b.label.set_text(f'Current individual:\n'
                                      f'{self.df[self.offset[self.curr_ind[self.curr], 0]]["ind_number"]}')
            plt.draw()

        def key_pressed(self, event):
            """Handle a key press.

            In individual-number mode, keys edit the number: backspace deletes,
            modifier keys are ignored, and anything else is appended. Otherwise a
            digit key selecting an existing view switches to that view.
            """
            if self.ind_select:
                if self.curr_ind[self.curr] is None:
                    return  # the current view shows no click to edit
                row = self.df[self.offset[self.curr_ind[self.curr], 0]]
                if event.key == 'backspace':
                    row['ind_number'] = row['ind_number'][:-1]
                elif event.key is None or event.key in ['shift', 'control', 'alt']:
                    pass
                else:
                    row['ind_number'] = row['ind_number'] + event.key
                self.ind_b.label.set_text(f'Current individual:\n{row["ind_number"]}')
                plt.draw()
            elif event.key in {str(i) for i in range(self.num_view)}:
                # event.key may be None; a set membership test handles that safely.
                self.change_curr(event.key)
                self.r_button.set_active(int(event.key))

        def play(self, event):
            """Play the 20 s segment shown in the full-signal plot (KeyboardInterrupt stops it)."""
            sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
            try:
                play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
            except KeyboardInterrupt:
                pass

        def resize(self, event):
            """Run constrained layout once to fit the current figure size, then freeze the layout again."""
            self.fax.get_figure().set_constrained_layout(True)
            plt.draw()
            plt.pause(0.2)
            self.fax.get_figure().set_constrained_layout(False)

        def onaxis(self, event):
            """Toggle the on/off-axis flag of the current click (unknown -> on -> off -> on ...)."""
            if self.curr_ind[self.curr] is None:
                return  # no click shown in the current view
            row = self.df[self.offset[self.curr_ind[self.curr], 0]]
            # EMLN has no 'onaxis' entry, so a missing flag is treated as unknown (-1).
            if row.get('onaxis', -1) == -1:
                row['onaxis'] = 1
            else:
                row['onaxis'] ^= 1
            self.onaxis_b.label.set_text('On-axis'
                                         if row['onaxis'] else 'Off-axis')
            plt.draw()

        def increase_freq(self, event):
            """Raise the upper frequency limit of every spectrogram by 1 kHz."""
            lim = self.view_data[0].spectrogram.axes.get_ylim()[1] + 1e3
            for ax_group in self.view_data:
                ax_group.spectrogram.axes.set_ylim(0, lim)
            plt.draw()

        def decrease_freq(self, event):
            """Lower the upper frequency limit of every spectrogram by 1 kHz (not below 1 kHz)."""
            lim = max(self.view_data[0].spectrogram.axes.get_ylim()[1] - 1e3, 1e3)
            for ax_group in self.view_data:
                ax_group.spectrogram.axes.set_ylim(0, lim)
            plt.draw()

        def _update_spectrogram(self, ax_group):
            """Recompute the dB spectrogram of the click plotted in *ax_group* (SPSC dB colour range)."""
            click = ax_group.signal.im.get_ydata()
            spec, _, t = plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1, pad_to=2*self.nfft)
            spec = np.flipud(20*np.log10(spec))
            ax_group.spectrogram.im.set_data(spec)
            ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max())
            ax_group.spectrogram.im.set_extent([t[0] - 1/self.sr/2, t[-1] + 1/self.sr/2, 0., self.sr/2])

        def increase_res(self, event):
            """Double the spectrogram FFT length (while it does not exceed ~10 ms) and refresh every view."""
            if self.nfft > int(10e-3*self.sr):
                return
            self.nfft *= 2
            for ax_group in self.view_data:
                self._update_spectrogram(ax_group)
            if self.nfft > int(10e-3*self.sr):
                self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
            else:
                self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
            self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
            plt.draw()

        def decrease_res(self, event):
            """Halve the spectrogram FFT length (not below 4) and refresh every view."""
            if self.nfft < 8:
                return
            self.nfft //= 2
            for ax_group in self.view_data:
                self._update_spectrogram(ax_group)
            self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
            if self.nfft < 8:
                self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
            else:
                self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
            plt.draw()

        def _set_label(self, ind=None, dic=None):
            """Write the measurements of *dic* (default: click of view *ind*) into the axis labels of view *ind*."""
            if ind is None:
                ind = self.curr
            if dic is None:
                dic = self.df[self.offset[self.curr_ind[ind], 0]]
            ax_group = self.view_data[ind]
            ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
            ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
            ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
            self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
            if self.onaxis_b is not None:
                # A missing flag means unknown (-1 -> '?').
                self.onaxis_b.label.set_text(['Off', 'On', '?'][dic.get('onaxis', -1)] + '-axis')

        def _set_visible(self, ind=None, state=False):
            """Show or hide the cursors of view *ind* (default: current view)."""
            if ind is None:
                ind = self.curr
            self.view_data[ind].set_visible(state)

        def reset_curr(self, event):
            """Erase all measurements of the click shown in the current view."""
            if self.curr_ind[self.curr] is None:
                return  # no click shown in the current view
            self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
            self._set_label()
            self._set_visible()  # hides every cursor of the current view
            self.curr_vert[self.curr] = 0
            self.cursor[self.curr].linev[0].set_linestyle('--')
            self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
            plt.draw()

        def reset(self, song, sr, song_resample):
            """Switch to a new signal: forget every pick and measurement and blank all views."""
            self.p = 0
            self.df = dict()
            self.song = song
            self.song_resample = song_resample
            self._shit()
            sr_update = False
            if self.sr != sr:
                self.sr = sr
                sr_update = True
            self.change_curr('0')  # reset current view to 0
            self.r_button.set_active(0)
            self.offset = np.zeros((0, 2))
            self.scat.set_offsets(self.offset)
            self.scat.set_color([[0, 0, 0, 1]])
            self.curr_ind = self.num_view * [None]  # Ind of click for each plot
            self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each  plot
            for ax_group in self.view_data:
                ax_group.signal.im.set_ydata(np.zeros(int((10e-3 + self.after_length) * sr)))
                ax_group.correlation.im.set_ydata(np.zeros(int(10e-3 * sr)))
                ax_group.cepstrum.im.set_ydata(np.zeros(int(10e-3 * sr)))
                if sr_update:
                    # Sample counts depend on sr, so the x axes must be rebuilt too.
                    ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length)*1e3,
                                                         int((10e-3 + self.after_length)*sr), False))
                    ax_group.correlation.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                    ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                ax_group.spectrogram.im.set_clim(2000, 2100)
            for i in range(self.num_view):
                self._set_label(i, EMLN)
                self._set_visible(i)
            self.ind_select = False
            self.ind_b.color = '0.85'
            self.ind_b.hovercolor = '0.95'
            self.ind_b.label.set_text(f'Current individual:\nnan')
            plt.draw()
    
    
    class AxesWithCursor:
        """A matplotlib axes bundled with its main data artist and its vertical cursor lines."""
        __slots__ = ['axes', 'im', 'cursors']

        def __init__(self, axes: plt.Axes, im=None, cursors=None):
            self.axes = axes  # wrapped matplotlib Axes
            self.im = im  # main data artist drawn on the axes (line or image)
            if cursors is None:
                cursors = list()
            self.cursors = cursors  # cursor lines, shown/hidden together by set_visible

        def set_visible(self, state):
            """Show or hide every cursor line; the data artist is left untouched."""
            for cursor_line in self.cursors:
                cursor_line.set_visible(state)

        @property
        def figure(self):
            """Figure of the wrapped axes (presumably so this object can stand in for an Axes)."""
            wrapped = self.axes
            return wrapped.figure
    
    
    class AxesGroup:
        """The four linked panels used to inspect one click.

        * ``signal``: waveform, 10 ms before the click plus ``after_length``, in ms.
        * ``spectrogram``: in s.
        * ``correlation`` and ``cepstrum``: lags 0-10 ms.

        Signal and spectrogram get one dashed cursor plus ``num_cur`` solid ones;
        correlation and cepstrum get one cursor each. All cursors start hidden.
        """
        __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']

        def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3):
            n_click = int((10e-3 + after_length)*sr)  # samples in the click window
            click_ms = 10 + after_length*1e3  # click window length in ms
            n_lag = int(10e-3 * sr)  # samples in 10 ms of lag

            # Waveform panel.
            sig_line, = ax_sig.plot(np.linspace(0, click_ms, n_click, False), np.zeros(n_click))
            sig_cursors = (ax_sig.axvline(10, c='k', linestyle='--'),)
            sig_cursors += tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur))
            self.signal = AxesWithCursor(ax_sig, sig_line, sig_cursors)
            ax_sig.set_xlim(0, click_ms)
            ax_sig.set_xlabel('IPI man:None')
            ax_sig.set_ylim(-1, 1)
            ax_sig.set_yticks([])

            # Spectrogram panel, seeded with faint noise and a clim that keeps it blank until a click is loaded.
            noise = np.random.normal(0, 1e-6, n_click)
            spec_im = ax_spec.specgram(noise, Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
            spec_cursors = (ax_spec.axvline(0.01, c='k', linestyle='--'),)
            spec_cursors += tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur))
            self.spectrogram = AxesWithCursor(ax_spec, spec_im, spec_cursors)
            spec_im.set_clim(2000, 2100)
            ax_spec.set_ylim(0, min(20e3, sr/2))
            ax_spec.set_yticks(ax_spec.get_yticks())  # Needed, otherwise updating ylim doesn't update ticks properly
            ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))

            # Lag panels.
            self.correlation = self._lag_panel(ax_corr, n_lag, (-1, 1))
            self.cepstrum = self._lag_panel(ax_cep, n_lag, (0, 0.33))
            self.set_visible(False)

        @staticmethod
        def _lag_panel(ax, n_lag, ylim):
            """Draw an empty 0-10 ms lag plot on *ax* with one cursor at 10 ms and return it wrapped."""
            lag_line, = ax.plot(np.linspace(0, 10, n_lag, False), np.zeros(n_lag))
            panel = AxesWithCursor(ax, lag_line, (ax.axvline(10, c='k'),))
            ax.set_xlim(0, 10)
            ax.set_xlabel('IPI man:None auto:None')
            ax.set_ylim(*ylim)
            ax.set_yticks([])
            return panel

        def __getitem__(self, item):
            panels = (self.signal, self.spectrogram, self.correlation, self.cepstrum)
            return panels[item]

        def __len__(self):
            return 4

        def __contains__(self, item):
            """True if *item* is one of the panels, or the Axes of one of them."""
            wrapped = isinstance(item, AxesWithCursor)
            return any(item == (panel if wrapped else panel.axes) for panel in self)

        def set_visible(self, state):
            """Show or hide the cursors of all four panels."""
            for panel in self:
                panel.set_visible(state)

        @property
        def signal_ipi(self):
            """Distance between the second and the first signal cursor (ms)."""
            first, second = self.signal.cursors[0], self.signal.cursors[1]
            return second.get_xdata()[0] - first.get_xdata()[0]
    
    
    def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3):
        song, sr, song_resample = load_file(in_path, channel, low, high)
        fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
        gs = fig.add_gridspec(12, 20)
    
        full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
        callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                            song, song_resample, sr, full_sig, num_graph, after_length)
        callback.fax = full_sig
        full_sig.set_xlim(0, 20)
        lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
        full_sig.set_ylim(-lim, lim)
        full_sig.set_yticks([])
        callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
        cid1 = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)
    
        b_left_ax = plt.subplot(gs[:2, 0])
        b_right_ax = plt.subplot(gs[:2, -1])
        b_left = Button(b_left_ax, '<|')
        b_right = Button(b_right_ax, '|>')
        b_left.on_clicked(callback.shift_left)
        b_right.on_clicked(callback.shift_right)
        r_button_ax = plt.subplot(gs[10, 1:-1])
        r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
                                  size=128, activecolor=list('rgbkcmyrgbkcmy'[:num_graph]))  # !Last char of labels is use as index in rest of code
        r_button_ax.axis('off')
        r_button.on_clicked(callback.change_curr)
        for i, c in enumerate('rgbkcmyrgbkcmy'[:num_graph]):
            r_button.labels[i].set_c(c)
        callback.r_button = r_button
        for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
            plt.rcParams[f'keymap.{v}'] = [] # disable default shortcut but fullsreen
        cid2 = fig.canvas.mpl_connect('key_press_event', callback.key_pressed)
        # c_button_ax = plt.subplot(gs[10,3:6])
        # c_button = CheckButtons(c_button_ax, [f'Save {i}' for i in range(3)], [False for i in range(3)])
    
        vfs = 4
        vs = 2
        hs = 1
        hfs = (20-hs)//(2*num_graph)
        ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                              plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                              plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                              plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                              sr, num_cur, after_length) for i in range(num_graph)]
    
        play_b_ax = plt.subplot(gs[-1, :2])
        play_b = Button(play_b_ax, 'Play\ncurrent segment')
        play_b.on_clicked(callback.play)
    
        resize_b_ax = plt.subplot(gs[-1, 2:4])
        resize_b = Button(resize_b_ax, 'Resize plot')
        resize_b.on_clicked(callback.resize)
    
        reset_b_ax = plt.subplot(gs[-1,4:6])
        reset_b = Button(reset_b_ax, 'Reset current')
        reset_b.label.set_c('r')
        reset_b.on_clicked(callback.reset_curr)
        callback.reset_b = reset_b
    
        ind_b_ax = plt.subplot(gs[-1,6:8])
        ind_b = Button(ind_b_ax, 'Current individual:\nnan')
        ind_b.on_clicked(callback.toggle_ind)
        callback.ind_b = ind_b
    
        freq_p_b_ax = plt.subplot(gs[2, 0])
        freq_m_b_ax = plt.subplot(gs[3, 0])
        freq_res_p_b_ax = plt.subplot(gs[4, 0])
        freq_res_m_b_ax = plt.subplot(gs[5, 0])
        freq_p_b = Button(freq_p_b_ax, '+\n1kHz')
        freq_p_b.on_clicked(callback.increase_freq)
        freq_m_b = Button(freq_m_b_ax, '-\n1kHz')
        freq_m_b.on_clicked(callback.decrease_freq)
        freq_res_p_b = Button(freq_res_p_b_ax, '256\nbins')
        freq_res_p_b.on_clicked(callback.increase_res)
        freq_res_m_b = Button(freq_res_m_b_ax, '64\nbins')
        freq_res_m_b.on_clicked(callback.decrease_res)
    
        spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
        callback.spec_b = spec_button
    
        m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
        callback.cursor = m_cursor
        for i in range(len(ax_group)):
            m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], num_cur, useblit=True, c='k')
        callback.view_data = ax_group
        return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
                {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
                 'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b, 'spec_button': spec_button}}
    
    
    def reset(callback, in_path, channel, low=2e3, high=20e3):
        song, sr, song_resample = load_file(in_path, channel, low, high)
        callback.reset(song, sr, song_resample)
    
    
    def main(args):
        if args.out == '':
            outpath = args.input.rsplit('.', 1)[0] + '.pred.h5'
        else:
            outpath = args.out
        if os.path.isfile(outpath):
            if not (args.erase or args.resume):
                print(f'Out file {outpath} already exist and erase or resume option isn\'t set.')
                return 1
        elif args.resume:
            print(f'Out file {outpath} does not already exist and resume option is set.')
            return 1
        if args.pulse < 1:
            print(f'{args.pulse} is an invalid number of pulses.')
            return 1
        elif args.click < 1:
            print(f'{args.click} is an invalid number of clicks.')
            return 1
    
        EMLN['onaxis'] = -1
        ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after*1e-3)
        if args.resume:
            df = pd.read_hdf(outpath)
            if 'onaxis' not in df.columns:
                df['onaxis'] = -1
            ref_dict['callback'].df = df.to_dict(orient='index')
            ref_dict['callback'].offset = np.tile(np.array(list(ref_dict['callback'].df.keys()))[:, np.newaxis], (1, 2))
            ref_dict['callback'].offset[:, 1] = ref_dict['callback'].song_resample[(ref_dict['callback'].offset[:, 1] * FSSR).astype(int)]
            ref_dict['callback'].scat.set_offsets(ref_dict['callback'].offset)
            colors = np.array(len(ref_dict['callback'].offset)*[[0.7, 0.2, 0.5, 1]])
            for i, p in enumerate(ref_dict['callback'].offset[:, 0]):
                if not np.all(pd.isnull(list(ref_dict['callback'].df[p].values()))):
                    colors[i] = [0, 0, 0, 1]
            ref_dict['callback'].scat.set_color(colors)
    
        onaxis_b_ax = plt.subplot(ref_dict['gridspec'][-1, 8:10])
        onaxis_b = Button(onaxis_b_ax, '?-axis')
        onaxis_b.on_clicked(ref_dict['callback'].onaxis)
        ref_dict['callback'].onaxis_b = onaxis_b
        plt.draw()
        plt.pause(0.2)
        ref_dict['fig'].set_constrained_layout(False)
        plt.show()
        df = pd.DataFrame.from_dict(ref_dict['callback'].df, orient='index')
        df.to_hdf(outpath, 'df', format='table')
        return 0
    
    
    if __name__ == '__main__':
    
        class ArgumentParser(argparse.ArgumentParser):
            def error(self, message):
                if message.startswith('the following arguments are required:'):
                    raise ValueError(message)
                super(ArgumentParser, self).error(message)
    
        parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument("input", type=str, help="Input file")
        parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
        parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
        parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
        parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
        parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.")
        parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
        parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
        group = parser.add_mutually_exclusive_group()
        group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                                 " the computation will be halted")
        group.add_argument("--resume", action='store_true', help="If out file exist and this option is given,"
                                                                 " the previous annotation file will be loaded")
        try:
            args = parser.parse_args()
        except ValueError as e:
            print(f'Error while parsing the command line arguments: {e}')
    
            def ask(string):
                y = {'y', 'yes', 'o', 'oui'}
                a = {'y', 'yes', 'o', 'oui', 'n', 'no', 'non'}
                while True:
                    ans = input(string + ' [y/n] ').lower()
                    if ans in a:
                        return ans in y
    
            if not ask('Do you want to manually specify them?'):
                sys.exit(2)  # exit code of invalid argparse
    
            class VirtualArgParse(object):
                def __init__(self):
                    self.input = input("What is the input file path? ")
                    self.out = input("What is the out file path? (Leave empty for default)")
                    self.channel = int(input("Which channel do you want to use starting from 0? "))
                    self.pulse = int(input("How many pulses do you want to display (> 1) "))
                    self.click = int(input("How many clicks do you want to display (> 1) "))
                    self.after = float(input("What duration do you want after P1?"))
                    self.low = float(input("What is the lower frequency of the bandpass? "))
                    self.up = float(input("What is the upper frequency of the bandpass? "))
                    self.erase = ask("Do you want to erase the out file if it already exist?")
                    if not self.erase:
                        self.resume = ask("Do you want to resume the out file if it already exist?")
                    else:
                        self.resume = False
    
            args = VirtualArgParse()
    
        sys.exit(main(args))