Skip to content
Snippets Groups Projects
Select Git revision
  • 08f03ccf184df98b9a8d52ac2974b13083891024
  • master default protected
2 results

ipi_extract.py

Blame
  • user avatar
    ferrari authored
    08f03ccf
    History
    ipi_extract.py 27.48 KiB
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.signal as sg
    import soundfile as sf
    import os
    import sys
    from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget
    from fractions import Fraction
    from pydub import AudioSegment
    from pydub.playback import play
    import pandas as pd
    
    FSSR = 48_000  # Sampling rate (Hz) of the resampled signal shown in the full-signal plot
    FSPK = 0.1     # Half-width (s) of the window searched for the click peak around a full-plot click
    IPIPK = 0.15   # Half-width (ms) of the window searched for the automatic IPI peak around a manual pick
    SPSC = 80      # Dynamic range (dB) of the click spectrogram colour scale
    # Empty annotation row: one copy is stored per registered click, every field unset (NaN).
    EMLN = dict.fromkeys(
        ('p1_pos', 'ipi_sig',
         'ipi_corr_man', 'ipi_corr_auto',
         'ipi_ceps_man', 'ipi_ceps_auto',
         'ind_number'),
        np.nan,
    )
    
    
    def read(file_path, always_2d=True):
        """Read an audio file and return ``(samples, sampling_rate)``.

        soundfile is tried first. On any failure (e.g. a format libsndfile
        cannot decode) the file is decoded with pydub via `load_anysound`.

        Parameters
        ----------
        file_path : str
            Path of the audio file.
        always_2d : bool
            If True, samples are always shaped (n_frames, n_channels). If False,
            mono files come back 1-D, as with ``soundfile.read``. This holds
            for both backends.
        """
        try:
            return sf.read(file_path, always_2d=always_2d)
        except Exception:  # best-effort fallback: soundfile's error types vary across versions
            data, sr = load_anysound(file_path)
        if not always_2d and data.shape[1] == 1:
            data = data[:, 0]  # mirror soundfile's 1-D output for mono files
        return data, sr
    
    
    def load_anysound(file_path):
        """Decode an audio file with pydub.

        Returns ``(samples, frame_rate)``. `samples` is a float64 array shaped
        (n_frames, n_channels) and scaled to [-1, 1) by the sample width. That
        matches the default output of ``soundfile.read``, so `read` returns the
        same kind of data whichever backend decoded the file.
        """
        segment = AudioSegment.from_file(file_path)
        # pydub exposes interleaved integer samples of `sample_width` bytes each
        full_scale = float(1 << (8 * segment.sample_width - 1))
        samples = np.array(segment.get_array_of_samples(), dtype=np.float64) / full_scale
        return samples.reshape(-1, segment.channels), segment.frame_rate
    
    
    def load_file(in_path, channel, low, high):
        """Load one channel of an audio file, band-pass it and resample it for display.

        Parameters
        ----------
        in_path : str
            Audio file, read through `read`.
        channel : int
            Channel index (numpy indexing: negative values count from the end).
        low, high : float
            Band-pass cut-off frequencies in Hz. If `high` is not below the
            Nyquist frequency of the file, it is lowered to 0.99 * Nyquist.

        Returns
        -------
        song : ndarray
            Filtered signal at the native sampling rate.
        sr : int
            Native sampling rate.
        song_resample : ndarray
            Filtered signal resampled to FSSR. It is zero-padded to at least
            20 s, the width of the full-signal view.

        Raises
        ------
        IndexError
            If `channel` does not exist in the file.
        ValueError
            If `low` is not below the (possibly lowered) `high`.
        """
        print(f'Loading and processing {in_path}')
        song, sr = read(in_path, always_2d=True)
        n_channels = song.shape[1]
        if not -n_channels <= channel < n_channels:
            raise IndexError(f'Channel {channel} does not exist: {in_path} has {n_channels} channel(s)')
        song = song[:, channel]
        nyquist = sr / 2
        if high >= nyquist:
            # butter() rejects critical frequencies at or above Nyquist (e.g. 20 kHz at sr=32 kHz)
            high = 0.99 * nyquist
            print(f'High cut-off lowered to {high:.0f} Hz (sampling rate is {sr} Hz)')
        if low >= high:
            raise ValueError(f'Low cut-off ({low} Hz) must be below high cut-off ({high} Hz)')
        sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
        song = sg.sosfiltfilt(sos, song)
        frac = Fraction(FSSR, sr)
        song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
        # The full-signal view always spans 20 s: pad short recordings with silence
        if len(song_resample) < FSSR * 20:
            song_resample = np.pad(song_resample, (0, FSSR * 20 - len(song_resample)))
        print('Done processing')
        return song, sr, song_resample
    
    
    def norm(x):
        """Scale `x` so its largest magnitude is (almost) 1; the epsilon avoids 0/0."""
        peak = np.max(np.abs(x))
        return x / (peak + 1e-10)
    
    
    def norm_std(x, alpha=1.5):
        """Divide `x` by `alpha` times its standard deviation.

        The epsilon guards against a constant input. `alpha` was previously
        ignored (1.5 was hard-coded); the default keeps that behaviour.
        """
        return x / (alpha * np.std(x) + 1e-10)
    
    
    class MyRadioButtons(RadioButtons):
        """Radio buttons drawn as legend entries, with optional per-button colours.

        The buttons are the handles of a legend built from empty scatter plots.
        This lets ``orientation='horizontal'`` lay them out on one row (legend
        ``ncol``/``mode='expand'``). ``activecolor`` may be a single colour, or a
        sequence of colours indexed by the last character of the selected label
        (which must then be a digit).

        ``RadioButtons.__init__`` is deliberately not called. This class keeps
        its own click-observer registry and implements ``on_clicked``,
        ``disconnect`` and ``set_active`` itself, so it does not rely on the
        private RadioButtons attributes that changed across matplotlib versions.
        """

        def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                     orientation="vertical", **kwargs):
            """
            Add radio buttons to an `~.axes.Axes`.

            Parameters
            ----------
            ax : `~matplotlib.axes.Axes`
                The axes to add the buttons to.
            labels : list of str
                The button labels.
            active : int
                The index of the initially selected button.
            activecolor : color or sequence of colors
                The color of the selected button. A sequence is indexed by the
                last character of the selected label.
            size : float
                Marker size (scatter ``s``, points**2) of the radio buttons.
            orientation : str
                The orientation of the buttons: 'vertical' (default), or 'horizontal'.
            Further parameters are passed on to `Legend`.
            """
            AxesWidget.__init__(self, ax)
            self._activecolor = activecolor
            axcolor = ax.get_facecolor()  # inactive buttons are filled with the axes background
            self.value_selected = None

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_navigate(False)
            circles = []
            for i, label in enumerate(labels):
                if i == active:
                    self.value_selected = label
                    facecolor = self.activecolor  # reads value_selected, set just above
                else:
                    facecolor = axcolor
                # Empty scatter: only used as the legend handle of this button
                p = ax.scatter([], [], s=size, marker="o", edgecolor='black',
                               facecolor=facecolor)
                circles.append(p)
            if orientation == "horizontal":
                kwargs.update(ncol=len(labels), mode="expand")
            kwargs.setdefault("frameon", False)
            self.box = ax.legend(circles, labels, loc="center", **kwargs)
            self.labels = self.box.texts
            # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 (old name removed in 3.9)
            handles = getattr(self.box, 'legend_handles', None)
            self.circles = handles if handles is not None else self.box.legendHandles
            for c in self.circles:
                c.set_picker(5)  # pick tolerance, in points
            self._click_observers = {}  # connection id -> callback(label text)
            self._click_cnt = 0
            self.connect_event('pick_event', self._clicked)

        def _clicked(self, event):
            """Activate the button whose legend handle was left-clicked inside this axes."""
            if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
                return
            if event.artist in self.circles:
                self.set_active(self.circles.index(event.artist))

        def on_clicked(self, func):
            """Register `func(label_text)` to be called on selection; return a connection id."""
            cid = self._click_cnt
            self._click_observers[cid] = func
            self._click_cnt += 1
            return cid

        def disconnect(self, cid):
            """Remove the observer registered under `cid` (no-op if unknown)."""
            self._click_observers.pop(cid, None)

        def set_active(self, index):
            """Select button `index`: recolour the buttons, redraw and notify observers.

            Raises
            ------
            ValueError
                If `index` is not a valid button index.
            """
            if index not in range(len(self.labels)):
                raise ValueError(f'Invalid RadioButton index: {index}')
            self.value_selected = self.labels[index].get_text()
            axcolor = self.ax.get_facecolor()
            for i, circle in enumerate(self.circles):
                circle.set_facecolor(self.activecolor if i == index else axcolor)
            if self.drawon:
                self.ax.figure.canvas.draw()
            if self.eventson:
                for func in list(self._click_observers.values()):
                    func(self.labels[index].get_text())

        @property
        def activecolor(self):
            """Colour of the selected button.

            A single colour (name, hex string, RGB(A) tuple) is returned as is.
            A sequence of colours is indexed by the last character (a digit) of
            the selected label.
            """
            from matplotlib.colors import is_color_like
            if is_color_like(self._activecolor) or not hasattr(self._activecolor, '__getitem__'):
                return self._activecolor
            return self._activecolor[int(self.value_selected[-1])]
    
    
    class MyMultiCursor(AxesWidget):
        """Vertical cursor lines kept in sync across the four axes of one click view.

        `axes` holds the view's axes in the order built by `init`:
        [signal (ms), spectrogram (s), correlation (IPI, ms), cepstrum (IPI, ms)].
        `p1` is the view's first-pulse (P1) marker line on the signal axis.

        The dash style of the signal/spectrogram cursors encodes the picking
        state set by `Callback`:
        - Dashed: the next pick places P1, so the cursor only moves on the
          signal and spectrogram.
        - Solid: P1 is placed, so hovering there also shows the resulting IPI
          (distance to P1) on the correlation and cepstrum axes.

        Hovering the correlation/cepstrum axes shows the IPI there and, when P1
        is visible, the matching position P1 + IPI on the signal/spectrogram.
        """
        def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False,
                     **lineprops):
            """Create one hidden vertical cursor line per axis of `axes`.

            Parameters
            ----------
            axes : sequence of Axes
                The view's axes; `onmove` looks events up with ``axes.index``.
            p1 : Line2D
                Marker whose first x value is the P1 position on the signal axis (ms).
            horizOn : bool
                Stored only: horizontal cursor lines are currently disabled.
            vertOn : bool
                Whether the vertical cursor lines are shown.
            useblit : bool
                Use blitting when the canvas supports it.
            **lineprops
                Passed to ``axvline`` for every cursor line.
            """
            AxesWidget.__init__(self, axes[0])
            self.axes = axes
            self.n_axes = len(axes)
            self.p1 = p1
            # Canvas-wide connections; onmove filters on self.axes itself
            self.connect_event('motion_notify_event', self.onmove)
            self.connect_event('draw_event', self.clear)

            self.visible = True
            self.horizOn = horizOn
            self.vertOn = vertOn
            self.useblit = useblit and self.canvas.supports_blit

            if self.useblit:
                lineprops['animated'] = True
            #self.lineh = []
            self.linev = []
            for i in range(self.n_axes):
                #self.lineh.append(axes[i].axhline(axes[i].get_ybound()[0], visible=False, **lineprops))
                self.linev.append(axes[i].axvline(axes[i].get_xbound()[0], visible=False, **lineprops))

            self.background = self.n_axes*[None]  # per-axes saved background for blitting
            self.needclear = False  # a full redraw is needed once the mouse leaves the view

        def clear(self, event):
            """Internal event handler to clear the cursor.

            Runs on every full redraw: hides the lines and, when blitting, saves
            each axes' clean background for `_update`.
            """
            if self.ignore(event):
                return
            for i in range(self.n_axes):
                if self.useblit:
                    self.background[i] = self.canvas.copy_from_bbox(self.axes[i].bbox)
                self.linev[i].set_visible(False)
                #self.lineh[i].set_visible(False)

        def onmove(self, event):
            """Internal event handler to draw the cursor when the mouse moves."""
            if self.ignore(event):
                return
            if not self.canvas.widgetlock.available(self):
                return
            if event.inaxes not in self.axes:
                # Mouse outside this view: hide every line (redraw once)
                for i in range(self.n_axes):
                    self.linev[i].set_visible(False)
                    #self.lineh[i].set_visible(False)

                if self.needclear:
                    self.canvas.draw()
                    self.needclear = False
                return
            self.needclear = True
            if not self.visible:
                return
            ax_idx = self.axes.index(event.inaxes)
            if ax_idx < 2:
                # Hovering the signal (ms) or the spectrogram (s): work in ms
                if ax_idx == 1:
                    pos = event.xdata*1e3
                else:
                    pos = event.xdata
                if self.linev[0].get_linestyle() == '--':
                    # Dashed: next pick is P1, only the signal/spectrogram cursors move
                    self.linev[0].set_xdata((pos, pos))
                    self.linev[1].set_xdata((pos / 1e3, pos / 1e3))
                    self.linev[0].set_visible(self.visible and self.vertOn)
                    self.linev[1].set_visible(self.visible and self.vertOn)
                    self._update()
                    return
                else:
                    # Solid: P1 placed, also show the IPI this position would give
                    ipi_pos = abs(pos - self.p1.get_xdata()[0])
                    for i in range(self.n_axes):
                        if i == 0:
                            self.linev[i].set_xdata((pos, pos))
                        elif i == 1:
                            self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
                        else:
                            self.linev[i].set_xdata((ipi_pos, ipi_pos))
                        # self.lineh[i].set_ydata((event.ydata, event.ydata))
                        self.linev[i].set_visible(self.visible and self.vertOn)
                        # self.lineh[i].set_visible(self.visible and self.horizOn)
                    self._update()
                    return
            else:
                # Hovering correlation/cepstrum: xdata is an IPI (ms), mapped back to P1 + IPI
                ipi_pos = event.xdata
                pos = self.p1.get_xdata()[0] + ipi_pos
                restore = False
            for i in range(self.n_axes):
                if self.p1.get_visible():
                    if i == 0:
                        if self.linev[0].get_linestyle() == '--':
                            # Draw the signal/spectrogram cursors solid for now; dashes restored below
                            restore = True
                            self.linev[0].set_linestyle('-')
                            self.linev[1].set_linestyle('-')
                        self.linev[i].set_xdata((pos, pos))
                    elif i == 1:
                        self.linev[i].set_xdata((pos/1e3, pos/1e3))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                if i > 1:
                    self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    #self.lineh[i].set_ydata((event.ydata, event.ydata))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                    #self.lineh[i].set_visible(self.visible and self.horizOn)

            self._update()

            if restore:
                self.linev[0].set_linestyle('--')
                self.linev[1].set_linestyle('--')

        def _update(self):
            """Redraw the cursor lines (blit each axes if enabled, else request an idle redraw)."""
            if self.useblit:
                for i in range(self.n_axes):
                    if self.background[i] is not None:
                        self.canvas.restore_region(self.background[i])
                    self.axes[i].draw_artist(self.linev[i])
                    #self.axes[i].draw_artist(self.lineh[i])
                    self.canvas.blit(self.axes[i].bbox)
            else:
                self.canvas.draw_idle()
            return False
    
    
    class Callback(object):
        """State and event handlers of the IPI annotation GUI built by `init`.

        The figure has a full-signal plot showing a 20 s window of the resampled
        signal. It also has three click "views" of four axes each: waveform (ms),
        spectrogram (s), autocorrelation (IPI, ms) and cepstrum (IPI, ms).

        Clicking the full-signal plot snaps to the maximum within +/-FSPK s.
        That click is registered, keyed by its time in seconds in `df`, and
        shown in the currently selected view. In a view:
        - Clicks on the waveform/spectrogram alternately place the P1 and
          second-pulse markers.
        - Clicks on the correlation/cepstrum record a manual IPI plus an
          automatic one: the local maximum within +/-IPIPK ms.

        Attributes
        ----------
        p : int
            Start (in FSSR samples) of the 20 s window of the full-signal plot.
        df : dict
            Click time (s) -> annotation row (a copy of `EMLN`).
        offset : ndarray, shape (n, 2)
            (time s, amplitude) of every registered click, as plotted by `scat`.
        curr : int
            Index (0-2) of the view receiving new clicks.
        curr_ind : list
            For each view, the row of `offset` it displays (None when empty).
        curr_vert : list
            For each view: 0 if the next waveform/spectrogram pick sets P1,
            1 if it sets the second pulse.
        """

        def __init__(self, line, song, song_resample, sr, full_ax):
            self.p = 0
            self.df = dict()
            self.line = line  # Line2D of the full-signal plot
            self.song = song  # filtered signal at the native rate `sr`
            self.song_resample = song_resample  # filtered signal at FSSR
            self.sr = sr
            self.fax = full_ax
            self.view_ax = None    # set by init: 3 views x [sig, spec, corr, ceps] axes
            self.view_data = None  # set by init: per view/axis [data artist, marker line(s)]
            self.curr = 0  # current view selected
            self.scat = self.fax.scatter([],[], marker='x', c=[[0.7, 0.2, 0.5, 1]])  # registered clicks
            self.offset = np.zeros((0,2))
            self.curr_ind = 3*[None]  # Ind of click for each plot
            self.curr_vert = 3*[0]    # Current vertical line of sig/spec for each  plot
            self.cursor = None    # set by init: one MyMultiCursor per view
            self.f_cursor = None  # set by init: cursor of the full-signal plot
            self.reset_b = None   # set by init: "Reset current" button

        def shift_left(self, event):
            """Move the full-signal window 13 s back, stopping at the start."""
            self.p = max(0, self.p - FSSR*13)
            self._shift()

        def shift_right(self, event):
            """Move the full-signal window 13 s forward, stopping at the end (never below 0)."""
            self.p = max(0, min(len(self.song_resample) - FSSR*20, self.p + FSSR*13))
            self._shift()

        def _shift(self):
            """Redraw the full-signal plot for the 20 s window starting at sample `self.p`."""
            window = self.song_resample[self.p:self.p+FSSR*20]
            self.line.set_ydata(window)
            lim = np.abs(window).max()*1.2
            self.fax.set_ylim(-lim, lim)
            self.line.set_xdata(np.linspace(self.p/FSSR, self.p/FSSR+20, FSSR*20, False))
            self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
            plt.draw()

        def _extract_click(self, mpos):
            """Return the window of `self.song` centred on full-plot sample `mpos`.

            The window has ``int(20e-3*self.sr)`` samples, the length of the
            waveform axis built by `init`. Parts falling outside the recording are
            zero padded on the matching side so the click stays centred. Returns
            None when the window lies entirely past the end of the recording.
            """
            n_click = int(20e-3*self.sr)
            start = int(mpos*self.sr/FSSR - 10e-3*self.sr)
            if start >= len(self.song):
                return None
            click = self.song[max(start, 0):max(start + n_click, 0)]
            pad_before = max(-start, 0)
            pad_after = n_click - len(click) - pad_before
            if pad_before or pad_after:
                click = np.pad(click, (pad_before, pad_after), mode='constant')
            return click

        def _reset_picking(self, ind):
            """Make the next waveform/spectrogram pick of view `ind` place P1 (dashed cursors)."""
            self.curr_vert[ind] = 0
            self.cursor[ind].linev[0].set_linestyle('--')
            self.cursor[ind].linev[1].set_linestyle('--')

        def on_clicked(self, event):
            """Handle a mouse-button release anywhere on the figure (see class docstring)."""
            if event.inaxes == self.fax:  # Click on full signal plot
                pos = int(FSSR*event.xdata)
                win_start = max(pos-int(FSSR*FSPK), 0)
                mpos = np.argmax(self.song_resample[win_start:pos+int(FSSR*FSPK)]) + win_start
                click = self._extract_click(mpos)
                if click is None:  # Only zero padding past the end of the recording
                    return
                if self.curr_ind[self.curr] is not None:
                    # Click previously shown in this view goes back to the default colour
                    self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1]
                if mpos/FSSR not in self.offset[:,0]:
                    # New click: register it with an empty annotation row
                    self.offset = np.concatenate([self.offset, [[mpos/FSSR, self.song_resample[mpos]]]], axis=0)
                    self.scat.set_offsets(self.offset)
                    self.df[mpos/FSSR] = EMLN.copy()
                    c = [[0, 0, 0, 1]]
                    c[0][self.curr] = 1  # marker coloured like the view (r, g or b)
                    if len(self.offset) == 1:
                        self.scat.set_color(c)
                    else:
                        self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
                    self.curr_ind[self.curr] = len(self.offset) - 1
                    self._reset_picking(self.curr)
                    for i in range(4):
                        if i < 2:
                            self.view_data[self.curr][i][1][0].set_visible(False)
                            self.view_data[self.curr][i][1][1].set_visible(False)
                        else:
                            self.view_data[self.curr][i][1].set_visible(False)
                else:
                    # Known click: select it and restore its saved markers
                    self.curr_ind[self.curr] = np.argmax(mpos/FSSR == self.offset[:,0])
                    c = [0, 0, 0, 1]
                    k = 0
                    for i, v in enumerate(self.curr_ind):
                        if v == self.curr_ind[self.curr]:
                            c[i] = 1
                            k += 1
                    for i in range(3):
                        c[i] /=k  # click shown in several views: average their colours
                    self.scat._facecolors[self.curr_ind[self.curr]] = c
                    self.scat.set_color(self.scat._facecolors)
                    row = self.df[mpos/FSSR]
                    if np.isnan(row['p1_pos']):
                        self.view_data[self.curr][0][1][0].set_visible(False)
                        self.view_data[self.curr][0][1][1].set_visible(False)
                        self.view_data[self.curr][1][1][0].set_visible(False)
                        self.view_data[self.curr][1][1][1].set_visible(False)
                        # No P1 yet: the next pick must place P1, not a second pulse
                        self._reset_picking(self.curr)
                    else:
                        self.view_data[self.curr][0][1][0].set_xdata((row['p1_pos'], row['p1_pos']))
                        self.view_data[self.curr][1][1][0].set_xdata((row['p1_pos']*1e-3, row['p1_pos']*1e-3))
                        self.view_data[self.curr][0][1][0].set_visible(True)
                        self.view_data[self.curr][1][1][0].set_visible(True)
                        if np.isnan(row['ipi_sig']):
                            self.view_data[self.curr][0][1][1].set_visible(False)
                            self.view_data[self.curr][1][1][1].set_visible(False)
                        else:
                            p2 = row['p1_pos'] + row['ipi_sig']
                            self.view_data[self.curr][0][1][1].set_xdata((p2, p2))
                            self.view_data[self.curr][1][1][1].set_xdata((p2*1e-3, p2*1e-3))
                            self.view_data[self.curr][0][1][1].set_visible(True)
                            self.view_data[self.curr][1][1][1].set_visible(True)
                    if np.isnan(row['ipi_corr_man']):
                        self.view_data[self.curr][2][1].set_visible(False)
                    else:
                        self.view_data[self.curr][2][1].set_xdata((row['ipi_corr_man'], row['ipi_corr_man']))
                        self.view_data[self.curr][2][1].set_visible(True)
                    if np.isnan(row['ipi_ceps_man']):
                        self.view_data[self.curr][3][1].set_visible(False)
                    else:
                        self.view_data[self.curr][3][1].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man']))
                        self.view_data[self.curr][3][1].set_visible(True)
                # x data rebuilt from the data length so it always matches (e.g. after `reset` with another sr)
                self.view_data[self.curr][0][0].set_data(np.linspace(0, 20, len(click), False), norm(click))
                spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=128, noverlap=127)[0]))
                self.view_data[self.curr][1][0].set_data(spec)
                self.view_data[self.curr][1][0].set_clim(spec.max()-SPSC, spec.max())
                corr = norm(np.correlate(click, click, 'same')[-int(10e-3*self.sr):])  # positive lags
                self.view_data[self.curr][2][0].set_data(np.linspace(0, 10, len(corr), False), corr)
                ceps = norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3*self.sr)]))
                self.view_data[self.curr][3][0].set_data(np.linspace(0, 10, len(ceps), False), ceps)
                self._set_label()
                plt.draw()
                return
            for i in range(3):  # Look if a click plot was clicked and which one
                for j in range(4):
                    if event.inaxes == self.view_ax[i][j]:
                        break
                else:
                    continue
                break
            else:
                return
            if self.curr_ind[i] is None:  # No click displayed in this view yet: nothing to annotate
                return
            key = self.offset[self.curr_ind[i], 0]  # time (s) of the click shown in view i
            if j < 2:
                pos = event.xdata * 10**(3*j)  # spectrogram x is in s, waveform in ms
                self.view_data[i][0][1][self.curr_vert[i]].set_xdata((pos, pos))
                self.view_data[i][1][1][self.curr_vert[i]].set_xdata((pos*10**-3, pos*10**-3))
                self.view_data[i][0][1][self.curr_vert[i]].set_visible(True)
                self.view_data[i][1][1][self.curr_vert[i]].set_visible(True)
                if self.curr_vert[i] == 0:
                    self.df[key]['p1_pos'] = pos
                self.curr_vert[i] ^= 1  # alternate between P1 and second pulse
                self.cursor[i].linev[0].set_linestyle(['--','-'][self.curr_vert[i]])
                self.cursor[i].linev[1].set_linestyle(['--','-'][self.curr_vert[i]])
                if self.view_data[i][0][1][1].get_visible():
                    ipi_man = self.view_data[i][0][1][1].get_xdata()[0] - self.view_data[i][0][1][0].get_xdata()[0]
                    self.df[key]['ipi_sig'] = ipi_man
                    self.view_ax[i][0].set_xlabel(f'Sig man:{ipi_man:.5f}')
            else:
                self.view_data[i][j][1].set_xdata((event.xdata, event.xdata))
                self.view_data[i][j][1].set_visible(True)
                # Automatic IPI: argmax of the plotted curve within +/-IPIPK ms of the manual pick
                ipi_auto = np.argmax(self.view_data[i][j][0].get_data()[1][
                           max(int(self.sr/1e3*(event.xdata-IPIPK)),0):int(self.sr/1e3*(event.xdata+IPIPK))])\
                           + max(int(self.sr/1e3*(event.xdata-IPIPK)),0)
                col = 'ipi_' + ('corr' if j == 2 else 'ceps')
                self.df[key][col + '_man'] = event.xdata
                self.df[key][col + '_auto'] = ipi_auto*1e3/self.sr
                self.view_ax[i][j].set_xlabel(f'{"Corr" if j == 2 else "Ceps"} man:{event.xdata:.3f} auto:{ipi_auto*1e3/self.sr:.3f}')
            plt.draw()

        def change_curr(self, label):
            """Select the view named by the last character of the radio-button `label`."""
            self.curr = int(label[-1])
            self.f_cursor.linev.set_color('rgb'[self.curr])
            self.reset_b.label.set_c('rgb'[self.curr])
            plt.draw()

        def play(self, event):
            """Play the 20 s window currently shown in the full-signal plot."""
            sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
            try:
                # song_resample is sampled at FSSR, not at the native rate self.sr
                play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
            except KeyboardInterrupt:
                pass

        def resize(self, event):
            """Recompute the layout once with constrained_layout, then freeze it again."""
            self.fax.get_figure().set_constrained_layout(True)
            plt.draw()
            plt.pause(0.2)
            self.fax.get_figure().set_constrained_layout(False)

        def _set_label(self, ind=None, dic=None):
            """Write the IPI values of `dic` (default: view `ind`'s click row) into the axis labels."""
            if ind is None:
                ind = self.curr
            if dic is None:
                dic = self.df[self.offset[self.curr_ind[ind], 0]]
            self.view_ax[ind][0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
            self.view_ax[ind][2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
            self.view_ax[ind][3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')

        def _set_visible(self, ind=None, state=False):
            """Show or hide every marker line of view `ind` (default: current view)."""
            if ind is None:
                ind = self.curr
            self.view_data[ind][0][1][0].set_visible(state)
            self.view_data[ind][0][1][1].set_visible(state)
            self.view_data[ind][1][1][0].set_visible(state)
            self.view_data[ind][1][1][1].set_visible(state)
            self.view_data[ind][2][1].set_visible(state)
            self.view_data[ind][3][1].set_visible(state)

        def reset_curr(self, event):
            """Clear every annotation of the click shown in the current view."""
            if self.curr_ind[self.curr] is None:  # Nothing displayed in this view
                return
            self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
            self._reset_picking(self.curr)  # P1 was cleared: next pick places P1 again
            self._set_label()
            self._set_visible()
            plt.draw()

        def reset(self, song, sr, song_resample):
            """Swap in a new recording and discard every annotation."""
            self.df = dict()
            self.song = song
            self.song_resample = song_resample
            self.sr = sr
            self.p = 0
            self._shift()  # full plot back to the first 20 s: x/y data and limits
            self.curr = 0  # current view selected
            self.offset = np.zeros((0, 2))
            self.scat.set_offsets(self.offset)
            self.curr_ind = 3 * [None]  # Ind of click for each plot
            self.curr_vert = 3 * [0]  # Current vertical line of sig/spec for each  plot
            for i in range(3):
                self.view_data[i][1][0].set_clim(2000,2100)
            for i in range(3):
                self._set_label(i, EMLN)
                self._set_visible(i)
                if self.cursor is not None:
                    self._reset_picking(i)
            plt.draw()
    
    
    def init(in_path, channel, low=2e3, high=20e3):
        """Load `in_path` and build the interactive IPI annotation figure.

        Layout on a 12x20 grid:
        - Rows 0-1: full-signal plot with the shift buttons.
        - Rows 2-9: three click views of four axes each.
        - Row 10: view selector.
        - Row 11: Play / Resize / Reset buttons.

        Parameters
        ----------
        in_path : str
            Audio file to annotate.
        channel : int
            Channel to analyse.
        low, high : float
            Band-pass cut-off frequencies in Hz, forwarded to `load_file`.

        Returns
        -------
        dict
            ``callback`` (the `Callback` holding the annotations), ``fig``,
            ``gridspec`` and ``buttons``. Keep a reference to it: the widgets
            must stay alive for their callbacks to keep working.
        """
        from matplotlib.ticker import FuncFormatter  # only used for the spectrogram kHz labels

        song, sr, song_resample = load_file(in_path, channel, low, high)
        # os.path.basename also handles Windows separators, unlike splitting on '/'
        fig = plt.figure('IPI of ' + os.path.basename(in_path), figsize=[16, 9], constrained_layout=True)
        gs = fig.add_gridspec(12, 20)

        # Full-signal plot: 20 s window of the resampled signal (Callback was given the axes already)
        full_sig = plt.subplot(gs[:2, 1:-1])
        callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                            song, song_resample, sr, full_sig)
        full_sig.set_xlim(0, 20)
        lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
        full_sig.set_ylim(-lim, lim)
        full_sig.set_yticks([])
        callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
        cid = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)

        b_left_ax = plt.subplot(gs[:2, 0])
        b_right_ax = plt.subplot(gs[:2, -1])
        b_left = Button(b_left_ax, '<|')
        b_right = Button(b_right_ax, '|>')
        b_left.on_clicked(callback.shift_left)
        b_right.on_clicked(callback.shift_right)

        # View selector. The last char of each label is used as the view index in the rest of the code.
        r_button_ax = plt.subplot(gs[10, 1:-1])
        r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(3)], orientation='horizontal',
                                  size=128, activecolor=list('rgb'))
        r_button_ax.axis('off')
        r_button.on_clicked(callback.change_curr)
        for i, c in enumerate('rgb'):
            r_button.labels[i].set_c(c)

        # Three click views, each a 2x2 block: [signal, spectrogram] above [correlation, cepstrum]
        vfs = 4  # grid rows per view axis
        vs = 2   # first grid row of the views
        hfs = 3  # grid columns per view axis
        hs = 1   # first grid column of the views
        ax_view = [[plt.subplot(gs[vs + row * vfs:vs + (row + 1) * vfs,
                                   hs + (2 * view + col) * hfs:hs + (2 * view + col + 1) * hfs])
                    for row, col in ((0, 0), (0, 1), (1, 0), (1, 1))]
                   for view in range(3)]
        callback.view_ax = ax_view

        play_b_ax = plt.subplot(gs[-1, :2])
        play_b = Button(play_b_ax, 'Play\ncurrent segment')
        play_b.on_clicked(callback.play)

        resize_b_ax = plt.subplot(gs[-1, 2:4])
        resize_b = Button(resize_b_ax, 'Resize plot')
        resize_b.on_clicked(callback.resize)

        reset_b_ax = plt.subplot(gs[-1,4:6])
        reset_b = Button(reset_b_ax, 'Reset current')
        reset_b.label.set_c('r')
        reset_b.on_clicked(callback.reset_curr)
        callback.reset_b = reset_b

        # data_view[i][j] = [data artist, marker line(s)] of view i, axis j
        data_view = [[2 * [None] for _ in range(4)] for _ in range(3)]
        m_cursor = [None for _ in range(3)]
        callback.cursor = m_cursor
        for i in range(3):
            # Waveform (ms): dashed P1 marker and solid second-pulse marker
            data_view[i][0][1] = (ax_view[i][0].axvline(10, c='k', linestyle='--'), ax_view[i][0].axvline(10, c='k'))
            data_view[i][0][0] = ax_view[i][0].plot(np.linspace(0, 20, int(20e-3 * sr), False), np.zeros(int(20e-3 * sr)))[0]
            ax_view[i][0].set_xlim(0, 20)
            ax_view[i][0].set_xlabel('IPI man:None')
            ax_view[i][0].set_ylim(-1, 1)
            data_view[i][0][1][0].set_visible(False)
            data_view[i][0][1][1].set_visible(False)

            # Spectrogram (s) of placeholder noise; colour limits far above its level
            # (presumably so it renders blank until a click is selected)
            data_view[i][1][0] = ax_view[i][1].specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
                                                        Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
            data_view[i][1][0].set_clim(2000,2100)
            data_view[i][1][1] = (ax_view[i][1].axvline(0.01, c='k', linestyle='--'), ax_view[i][1].axvline(0.01, c='k'))
            data_view[i][1][1][0].set_visible(False)
            data_view[i][1][1][1].set_visible(False)

            # Autocorrelation, positive lags (ms)
            data_view[i][2][1] = ax_view[i][2].axvline(10, c='k')
            data_view[i][2][0] = ax_view[i][2].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[0]
            ax_view[i][2].set_xlim(0, 10)
            ax_view[i][2].set_xlabel('IPI man:None auto:None')
            ax_view[i][2].set_ylim(-1, 1)
            data_view[i][2][1].set_visible(False)

            # Cepstrum, first 10 ms of quefrency
            data_view[i][3][1] = ax_view[i][3].axvline(10, c='k')
            data_view[i][3][0] = ax_view[i][3].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[0]
            ax_view[i][3].set_xlim(0, 10)
            ax_view[i][3].set_xlabel('IPI man:None auto:None')
            ax_view[i][3].set_ylim(0, 1)
            data_view[i][3][1].set_visible(False)

            for j in range(4):
                if j != 1:
                    ax_view[i][j].set_yticks([])
                else:
                    # Spectrogram frequencies are in Hz: label them in kHz. A formatter keeps
                    # the labels right on zoom, unlike fixed tick labels.
                    ax_view[i][j].yaxis.set_major_formatter(FuncFormatter(lambda freq, _: str(int(freq / 1e3))))
            m_cursor[i] = MyMultiCursor(ax_view[i], data_view[i][0][1][0], horizOn=False, useblit=True, c='k')
            m_cursor[i].linev[0].set_linestyle('--')  # dashed: the next pick places P1
            m_cursor[i].linev[1].set_linestyle('--')
        callback.view_data = data_view
        return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
            {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
             'fs_click': cid, 'reset_b': reset_b}}  # Needed to keep the callbacks alive
    
    
    def reset(callback, in_path, channel, low=2e3, high=20e3):
        """Reload `in_path` and hand the new data to an existing `Callback`.

        `low`/`high` are the band-pass cut-offs forwarded to `load_file`.
        `Callback.reset` then discards every annotation held by `callback`.
        """
        loaded = load_file(in_path, channel, low, high)  # (song, sr, song_resample)
        callback.reset(*loaded)
    
    
    def main(args):
        """Run the annotation GUI on `args.input` and save the annotations to HDF5.

        Parameters
        ----------
        args : namespace
            Needs ``input``, ``out`` ('' for the default path, i.e. the input path
            with its extension replaced by ``.pred.h5``), ``channel`` and ``erase``.

        Returns
        -------
        int
            0 on success. 1 if the output file exists and ``args.erase`` is not
            set; nothing is loaded or written in that case.
        """
        if args.out == '':
            # splitext only strips a real extension (a dot in a directory name is left alone)
            outpath = os.path.splitext(args.input)[0] + '.pred.h5'
        else:
            outpath = args.out
        if os.path.isfile(outpath) and not args.erase:
            print(f'Out file {outpath} already exist and erase option isn\'t set.')
            return 1
        ref_dict = init(args.input, args.channel)
        plt.draw()
        plt.pause(0.2)
        # Layout is computed once with constrained_layout, then frozen ("Resize plot" recomputes it)
        ref_dict['fig'].set_constrained_layout(False)
        plt.show()
        # One row per registered click, indexed by click time (s)
        df = pd.DataFrame.from_dict(ref_dict['callback'].df, orient='index')
        df.to_hdf(outpath, key='df')
        return 0
    
    
    if __name__ == '__main__':

        class ArgumentParser(argparse.ArgumentParser):
            """ArgumentParser that raises ValueError for missing required arguments.

            Other errors keep argparse's default behaviour. The ValueError is
            caught below to offer entering the arguments interactively.
            """
            def error(self, message):
                if message.startswith('the following arguments are required:'):
                    raise ValueError(message)
                super(ArgumentParser, self).error(message)

        parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument("input", type=str, help="Input file")
        parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
        parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
        parser.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                                 " the computation will be halted")
        try:
            args = parser.parse_args()
        except ValueError as e:
            print(f'Error while parsing the command line arguments: {e}')

            def ask(string):
                """Ask a yes/no question (English or French answers accepted) until answered."""
                y = {'y', 'yes', 'o', 'oui'}
                a = {'y', 'yes', 'o', 'oui', 'n', 'no', 'non'}
                while True:
                    ans = input(string + ' [y/n] ').strip().lower()
                    if ans in a:
                        return ans in y

            def ask_int(string):
                """Prompt until the answer parses as an integer."""
                while True:
                    try:
                        return int(input(string))
                    except ValueError:
                        print('Please enter an integer.')

            if not ask('Do you want to manually specify them?'):
                sys.exit(2)  # exit code of invalid argparse

            class VirtualArgParse(object):
                """Stand-in for the argparse namespace, filled from interactive prompts."""
                def __init__(self):
                    self.input = input("What is the input file path? ")
                    self.out = input("What is the out file path? (Leave empty for default)")
                    self.channel = ask_int("Which channel do you want to use starting from 0? ")
                    self.erase = ask("Do you want to erase the out file if it already exist?")

            args = VirtualArgParse()

        sys.exit(main(args))