import argparse
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sg
import soundfile as sf
import os
import sys
from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget
from fractions import Fraction
from pydub import AudioSegment
from pydub.playback import play
import pandas as pd

FSSR = 48_000  # Sampling rate (Hz) the signal is resampled to for the full-signal plot
FSPK = 0.1     # Half-width (s) of the window used to snap a full-signal pick to the signal maximum
IPIPK= 0.15    # Half-width (ms) of the window used to refine a manual IPI pick on the corr/ceps views
SPSC = 80      # Dynamic range (dB) of the spectrogram colour scale, counted down from its maximum
# Template annotation row for one click; every new click gets EMLN.copy().
#   p1_pos       : first-pulse position inside the 20 ms click window (ms)
#   ipi_sig      : P2 - P1 picked on the waveform/spectrogram (ms)
#   ipi_corr_*   : IPI picked manually / refined automatically on the autocorrelation (ms)
#   ipi_ceps_*   : same, on the cepstrum (ms)
#   ind_number   : never written in this file (NOTE(review): presumably filled in downstream)
EMLN = {'p1_pos':np.nan, 'ipi_sig': np.nan,
        'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan,
        'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan,
        'ind_number': np.nan}  # Empty dataline


def read(file_path, always_2d=True):
    """Read a sound file and return ``(samples, sampling_rate)``.

    soundfile is tried first. If it raises for any reason, the file is
    decoded through pydub instead (``load_anysound``), whose result is always
    2-D regardless of ``always_2d``.
    """
    try:
        result = sf.read(file_path, always_2d=always_2d)
    except Exception:
        result = load_anysound(file_path)
    return result


def load_anysound(file_path):
    """Decode *file_path* with pydub and return ``(samples, frame_rate)``.

    ``samples`` has one column per channel, i.e. shape
    ``(n_frames, channels)``, and holds pydub's sample values unscaled.
    """
    segment = AudioSegment.from_file(file_path)
    flat = np.array(segment.get_array_of_samples())
    n_channels = segment.channels
    return flat.reshape(-1, n_channels), segment.frame_rate


def load_file(in_path, channel, low, high):
    """Load one channel of a sound file, band-pass it and resample it to FSSR.

    Parameters
    ----------
    in_path : str
        Path of the sound file (anything ``read`` can open).
    channel : int
        Index of the channel to keep.
    low, high : float
        Band-pass cut-off frequencies in Hz. ``high`` is lowered to just below
        the Nyquist frequency when the file's sampling rate is too low for it,
        because ``scipy.signal.butter`` rejects cut-offs >= ``sr / 2``.

    Returns
    -------
    song : numpy.ndarray
        The filtered channel at the original sampling rate.
    sr : int
        The original sampling rate in Hz.
    song_resample : numpy.ndarray
        ``song`` resampled to ``FSSR`` Hz.
    """
    print(f'Loading and processing {in_path}')
    song, sr = read(in_path, always_2d=True)
    song = song[:, channel]
    nyquist = sr / 2
    if high >= nyquist:
        # e.g. the 20 kHz default with a 22.05 kHz or 32 kHz recording.
        high = 0.99 * nyquist
    sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
    # Zero-phase filtering so click/pulse positions are not shifted.
    song = sg.sosfiltfilt(sos, song)
    frac = Fraction(FSSR, sr)
    song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
    print('Done processing')
    return song, sr, song_resample


def norm(x):
    """Scale *x* so that its largest absolute value becomes (almost) 1.

    A tiny constant in the denominator keeps an all-zero input finite.
    """
    peak = np.max(np.abs(x))
    return x / (peak + 1e-10)


def norm_std(x, alpha=1.5):
    """Scale *x* by ``alpha`` times its standard deviation.

    Parameters
    ----------
    x : array_like
        Values to scale.
    alpha : float, optional
        Multiplier of the standard deviation used as the scale (default 1.5).
        It was previously ignored in favour of a hard-coded 1.5.

    Returns
    -------
    numpy.ndarray
        ``x / (alpha * std(x))``. A tiny constant keeps a constant input
        finite.
    """
    return x / (alpha * np.std(x) + 1e-10)


class MyRadioButtons(RadioButtons):
    """Radio buttons drawn as legend entries, optionally laid out horizontally.

    Compared to the stock ``RadioButtons`` this supports a horizontal layout,
    a configurable marker size, and one active colour per button
    (``activecolor`` may be a sequence indexed by the digit ending the
    selected label).

    Only ``AxesWidget.__init__`` is called, not ``RadioButtons.__init__``.
    The selection and callback machinery (``set_active``, ``on_clicked``,
    ``disconnect``) is therefore implemented here rather than inherited,
    since the base-class internals it relied on changed across matplotlib
    releases.
    """

    # Some matplotlib releases define these names on RadioButtons as read-only
    # (deprecated) properties, which makes the plain instance assignments in
    # __init__ raise AttributeError. Ordinary class attributes in this
    # subclass take precedence over those descriptors.
    circles = None
    cnt = 0
    observers = None

    def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                 orientation="vertical", **kwargs):
        """
        Add radio buttons to an `~.axes.Axes`.
        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to add the buttons to.
        labels : list of str
            The button labels.
        active : int
            The index of the initially selected button.
        activecolor : color or sequence of colors
            The color of the selected button. If a sequence, item ``k`` is
            used when the selected label ends with the digit ``k``.
        size : float
            Size of the radio buttons
        orientation : str
            The orientation of the buttons: 'vertical' (default), or 'horizontal'.
        Further parameters are passed on to `Legend`.
        """
        AxesWidget.__init__(self, ax)
        self._activecolor = activecolor
        axcolor = ax.get_facecolor()
        self.value_selected = None

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_navigate(False)
        circles = []
        for i, label in enumerate(labels):
            if i == active:
                # value_selected must be set first: activecolor depends on it.
                self.value_selected = label
                facecolor = self.activecolor
            else:
                facecolor = axcolor
            # Empty scatters only serve as legend handles (the visible circles).
            p = ax.scatter([], [], s=size, marker="o", edgecolor='black',
                           facecolor=facecolor)
            circles.append(p)
        if orientation == "horizontal":
            kwargs.update(ncol=len(labels), mode="expand")
        kwargs.setdefault("frameon", False)
        self.box = ax.legend(circles, labels, loc="center", **kwargs)
        self.labels = self.box.texts
        # `legend_handles` replaced `legendHandles` in matplotlib 3.7; the old
        # name was removed in 3.9.
        handles = getattr(self.box, 'legend_handles', None)
        if handles is None:
            handles = self.box.legendHandles
        self.circles = handles
        for c in self.circles:
            c.set_picker(5)
        self.cnt = 0
        self.observers = {}
        self.connect_event('pick_event', self._clicked)

    def _clicked(self, event):
        """Select the button whose circle was picked with the left mouse button."""
        if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
            return
        if event.artist in self.circles:
            self.set_active(self.circles.index(event.artist))

    def set_active(self, index):
        """Select button *index*, recolour the circles and notify observers.

        Raises
        ------
        ValueError
            If *index* is not a valid button index.
        """
        if index not in range(len(self.labels)):
            raise ValueError(f'Invalid RadioButton index: {index}')
        self.value_selected = self.labels[index].get_text()
        # Read after updating value_selected: activecolor depends on it.
        active_color = self.activecolor
        inactive_color = self.ax.get_facecolor()
        for i, circle in enumerate(self.circles):
            circle.set_facecolor(active_color if i == index else inactive_color)
        if self.drawon:
            self.ax.figure.canvas.draw()
        if not self.eventson:
            return
        for func in list(self.observers.values()):
            func(self.value_selected)

    def on_clicked(self, func):
        """Register *func* (called with the selected label); return its id."""
        cid = self.cnt
        self.observers[cid] = func
        self.cnt += 1
        return cid

    def disconnect(self, cid):
        """Remove the observer registered under *cid* (no-op if unknown)."""
        self.observers.pop(cid, None)

    @property
    def activecolor(self):
        """Colour of the selected button.

        A single colour (including a colour-name string such as ``'blue'``)
        is returned as is. Otherwise ``activecolor`` is a sequence of colours
        indexed by the digit that ends the selected label
        (e.g. ``'Change 2'`` selects item 2).
        """
        from matplotlib.colors import is_color_like
        if is_color_like(self._activecolor):
            return self._activecolor
        return self._activecolor[int(self.value_selected[-1])]


class MyMultiCursor(AxesWidget):
    """Vertical cursor lines kept in sync across the views of one click column.

    As used in ``init``, ``axes`` is ``[waveform, spectrogram, autocorrelation,
    cepstrum]``. The waveform x axis is in ms and the spectrogram x axis is in
    s, hence the ``1e3`` conversions between ``linev[0]`` and ``linev[1]``.
    The last two axes show a lag in ms. ``p1`` is the line marking the first
    pulse on the waveform; its x position converts a waveform position into
    an IPI (lag) on axes 2 and 3, and back.

    The linestyle of ``linev[0]`` doubles as a mode flag toggled by
    ``Callback``:
    - Dashed: the next waveform/spectrogram click places P1, so only the
      first two cursors follow the mouse on those axes.
    - Solid: the next click places P2, so the matching IPI is also shown on
      the lag axes.

    ``horizOn`` is stored, but horizontal lines are currently disabled (the
    ``lineh`` code is commented out).
    """
    def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False,
                 **lineprops):
        # Events are connected through axes[0]'s canvas; the axes are assumed
        # to share that canvas (in init() they all belong to one figure).
        AxesWidget.__init__(self, axes[0])
        self.axes = axes
        self.n_axes = len(axes)
        self.p1 = p1
        self.connect_event('motion_notify_event', self.onmove)
        # On every full redraw: refresh the blit backgrounds and hide the lines.
        self.connect_event('draw_event', self.clear)

        self.visible = True
        self.horizOn = horizOn
        self.vertOn = vertOn
        self.useblit = useblit and self.canvas.supports_blit

        if self.useblit:
            lineprops['animated'] = True
        #self.lineh = []
        # One (initially hidden) vertical line per managed axis.
        self.linev = []
        for i in range(self.n_axes):
            #self.lineh.append(axes[i].axhline(axes[i].get_ybound()[0], visible=False, **lineprops))
            self.linev.append(axes[i].axvline(axes[i].get_xbound()[0], visible=False, **lineprops))

        # Per-axis saved background used by _update() when blitting.
        self.background = self.n_axes*[None]
        # True once something was drawn that a full redraw must erase.
        self.needclear = False

    def clear(self, event):
        """Internal event handler to clear the cursor."""
        if self.ignore(event):
            return
        for i in range(self.n_axes):
            if self.useblit:
                self.background[i] = self.canvas.copy_from_bbox(self.axes[i].bbox)
            self.linev[i].set_visible(False)
            #self.lineh[i].set_visible(False)

    def onmove(self, event):
        """Internal event handler to draw the cursor when the mouse moves."""
        if self.ignore(event):
            return
        if not self.canvas.widgetlock.available(self):
            return
        if event.inaxes not in self.axes:
            # Mouse outside every managed axis: hide the lines and, if they
            # were drawn since the last redraw, redraw once to erase them.
            for i in range(self.n_axes):
                self.linev[i].set_visible(False)
                #self.lineh[i].set_visible(False)

            if self.needclear:
                self.canvas.draw()
                self.needclear = False
            return
        self.needclear = True
        if not self.visible:
            return
        ax_idx = self.axes.index(event.inaxes)
        if ax_idx < 2:
            # Waveform or spectrogram: express the position in ms.
            if ax_idx == 1:
                pos = event.xdata*1e3
            else:
                pos = event.xdata
            if self.linev[0].get_linestyle() == '--':
                # P1 mode: only the waveform/spectrogram cursors move.
                self.linev[0].set_xdata((pos, pos))
                self.linev[1].set_xdata((pos / 1e3, pos / 1e3))
                self.linev[0].set_visible(self.visible and self.vertOn)
                self.linev[1].set_visible(self.visible and self.vertOn)
                self._update()
                return
            else:
                # P2 mode: also show the resulting IPI (distance to P1) on
                # the lag axes.
                ipi_pos = abs(pos - self.p1.get_xdata()[0])
                for i in range(self.n_axes):
                    if i == 0:
                        self.linev[i].set_xdata((pos, pos))
                    elif i == 1:
                        self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
                    else:
                        self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    # self.lineh[i].set_ydata((event.ydata, event.ydata))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                    # self.lineh[i].set_visible(self.visible and self.horizOn)
                self._update()
                return
        else:
            # Lag axis: x is an IPI; project it onto the waveform as P1 + IPI.
            ipi_pos = event.xdata
            pos = self.p1.get_xdata()[0] + ipi_pos
            restore = False
        for i in range(self.n_axes):
            # The projected waveform/spectrogram cursors only make sense once
            # P1 has been placed (p1 visible).
            if self.p1.get_visible():
                if i == 0:
                    if self.linev[0].get_linestyle() == '--':
                        # Draw the projected cursor solid for this frame only;
                        # the dashed style (the mode flag) is restored below.
                        restore = True
                        self.linev[0].set_linestyle('-')
                        self.linev[1].set_linestyle('-')
                    self.linev[i].set_xdata((pos, pos))
                elif i == 1:
                    self.linev[i].set_xdata((pos/1e3, pos/1e3))
                self.linev[i].set_visible(self.visible and self.vertOn)
            if i > 1:
                self.linev[i].set_xdata((ipi_pos, ipi_pos))
                #self.lineh[i].set_ydata((event.ydata, event.ydata))
                self.linev[i].set_visible(self.visible and self.vertOn)
                #self.lineh[i].set_visible(self.visible and self.horizOn)

        self._update()

        if restore:
            self.linev[0].set_linestyle('--')
            self.linev[1].set_linestyle('--')

    def _update(self):
        """Draw the cursor lines: blit each axis if enabled, else request a redraw.

        Returns False (the value is not used in this file).
        """
        if self.useblit:
            for i in range(self.n_axes):
                if self.background[i] is not None:
                    self.canvas.restore_region(self.background[i])
                self.axes[i].draw_artist(self.linev[i])
                #self.axes[i].draw_artist(self.lineh[i])
                self.canvas.blit(self.axes[i].bbox)
        else:
            self.canvas.draw_idle()
        return False


class Callback(object):
    """GUI state and event handlers of the IPI annotation window.

    Conventions (set up by ``init``):

    * ``view_ax[i][j]`` is view ``j`` of click column ``i`` (3 columns):
      ``j == 0`` waveform (x in ms), ``j == 1`` spectrogram (x in s),
      ``j == 2`` autocorrelation (x in ms), ``j == 3`` cepstrum (x in ms).
    * ``view_data[i][j][0]`` is the artist showing that view's data.
      ``view_data[i][j][1]`` is its marker: a ``(P1, P2)`` pair of vertical
      lines for ``j < 2``, and a single vertical line otherwise.
    * ``df`` maps a click time (s, as picked on the full-signal plot) to a
      copy of ``EMLN`` holding the annotations made for that click.
    * ``offset`` holds one ``(time, amplitude)`` row per picked click and is
      the data of the scatter ``scat`` drawn on the full-signal axes.
    """

    def __init__(self, line, song, song_resample, sr, full_ax):
        self.p = 0  # first sample (in song_resample) of the 20 s full-signal view
        self.df = dict()
        self.line = line
        self.song = song
        self.song_resample = song_resample
        self.sr = sr
        self.fax = full_ax
        self.view_ax = None
        self.view_data = None
        self.curr = 0  # current view selected
        self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
        self.offset = np.zeros((0, 2))
        self.curr_ind = 3*[None]  # Ind of click for each plot
        self.curr_vert = 3*[0]    # Current vertical line of sig/spec for each  plot
        self.cursor = None
        self.f_cursor = None
        self.reset_b = None
        self.spec_b = None
        self.nfft = 128

    # ----------------------------------------------------------- full signal

    def shift_left(self, event):
        """Move the 20 s full-signal view 13 s to the left (clamped at 0)."""
        self.p = max(0, self.p - FSSR*13)
        self._shift()

    def shift_right(self, event):
        """Move the 20 s full-signal view 13 s to the right.

        The view stops when its right edge reaches the end of the signal, and
        never starts before sample 0 (relevant for signals shorter than 20 s).
        """
        self.p = max(0, min(len(self.song_resample) - FSSR*20, self.p + FSSR*13))
        self._shift()

    def _shift(self):
        """Redraw the full-signal plot for the 20 s window starting at ``self.p``."""
        segment = self.song_resample[self.p:self.p+FSSR*20]
        self.line.set_ydata(segment)
        lim = np.abs(segment).max()*1.2
        self.fax.set_ylim(-lim, lim)
        self.line.set_xdata(np.linspace(self.p/FSSR, self.p/FSSR+20, FSSR*20, False))
        self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
        plt.draw()

    _shit = _shift  # former (misspelt) name, kept for backward compatibility

    # --------------------------------------------------------------- clicks

    def on_clicked(self, event):
        """Dispatch a mouse release to the full-signal or click-view handler."""
        if event.inaxes == self.fax:  # Click on full signal plot
            self._select_click(event)
        else:
            self._annotate_view(event)

    def _select_click(self, event):
        """Load the click nearest to a full-signal pick into the current column.

        The pick snaps to the maximum of the resampled signal within ``FSPK``
        seconds. A new click gets a fresh annotation row; an already known
        click gets its saved markers back. The four views are then refreshed.
        """
        pos = int(FSSR*event.xdata)
        lo = max(pos-int(FSSR*FSPK), 0)
        mpos = np.argmax(self.song_resample[lo:pos+int(FSSR*FSPK)]) + lo
        key = mpos/FSSR  # click time in seconds, used as the df key
        data = self.view_data[self.curr]
        if self.curr_ind[self.curr] is not None:
            # Un-highlight the click previously shown in this column.
            self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1]
        if key not in self.offset[:, 0]:
            # New click: add it to the scatter and start an empty annotation row.
            self.offset = np.concatenate([self.offset, [[key, self.song_resample[mpos]]]], axis=0)
            self.scat.set_offsets(self.offset)
            self.df[key] = EMLN.copy()
            c = [[0, 0, 0, 1]]
            c[0][self.curr] = 1  # column colour: red, green or blue
            if len(self.offset) == 1:
                self.scat.set_color(c)
            else:
                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
            self.curr_ind[self.curr] = len(self.offset) - 1
            self.curr_vert[self.curr] = 0
            self.cursor[self.curr].linev[0].set_linestyle('--')
            self.cursor[self.curr].linev[1].set_linestyle('--')
            self._set_visible(self.curr)
        else:
            # Known click: select it, colour it by mixing the colours of every
            # column showing it, and restore its saved markers.
            self.curr_ind[self.curr] = np.argmax(key == self.offset[:, 0])
            c = [0, 0, 0, 1]
            k = 0
            for i, v in enumerate(self.curr_ind):
                if v == self.curr_ind[self.curr]:
                    c[i] = 1
                    k += 1
            for i in range(3):
                c[i] /= k
            self.scat._facecolors[self.curr_ind[self.curr]] = c
            self.scat.set_color(self.scat._facecolors)
            self._restore_markers(self.curr, self.df[key])
        click = self._click_window(self.song, mpos*self.sr/FSSR, self.sr)
        half = int(10e-3*self.sr)  # number of samples of the 10 ms lag axes
        data[0][0].set_ydata(norm(click))
        self._update_spec(self.curr, click)
        # Positive lags of the autocorrelation.
        data[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-half:]))
        # Magnitude of the real cepstrum (computed from the log10 spectrum).
        data[3][0].set_ydata(norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:half])))
        self._set_label()
        plt.draw()

    @staticmethod
    def _click_window(song, center, sr):
        """Return the 20 ms of *song* starting 10 ms before sample *center*.

        *center* may be fractional. The result always has ``int(20e-3*sr)``
        samples, the length of the waveform view's x axis. Parts falling
        outside *song* are zero-filled: on the left near the start, so the
        click stays at 10 ms, and on the right near the end.
        """
        win = int(20e-3 * sr)
        start = int(center - 10e-3 * sr)
        seg = song[max(start, 0):max(start + win, 0)]
        left = max(0, -start)
        return np.pad(seg, (left, win - left - len(seg)), mode='constant')

    def _update_spec(self, ind, click):
        """Recompute column *ind*'s spectrogram of *click* with ``self.nfft`` bins.

        The colour scale spans the top ``SPSC`` dB of the spectrogram.
        """
        spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft,
                                                        noverlap=self.nfft-1)[0]))
        self.view_data[ind][1][0].set_data(spec)
        self.view_data[ind][1][0].set_clim(spec.max()-SPSC, spec.max())

    def _restore_markers(self, ind, row):
        """Show column *ind*'s P1/P2 and IPI marker lines as saved in *row*."""
        data = self.view_data[ind]
        sig_lines, spec_lines = data[0][1], data[1][1]
        if np.isnan(row['p1_pos']):
            sig_lines[0].set_visible(False)
            sig_lines[1].set_visible(False)
            spec_lines[0].set_visible(False)
            spec_lines[1].set_visible(False)
        else:
            sig_lines[0].set_xdata((row['p1_pos'], row['p1_pos']))
            spec_lines[0].set_xdata((row['p1_pos']*1e-3, row['p1_pos']*1e-3))
            sig_lines[0].set_visible(True)
            spec_lines[0].set_visible(True)
            if np.isnan(row['ipi_sig']):
                sig_lines[1].set_visible(False)
                spec_lines[1].set_visible(False)
            else:
                p2 = row['p1_pos'] + row['ipi_sig']
                sig_lines[1].set_xdata((p2, p2))
                spec_lines[1].set_xdata((p2*1e-3, p2*1e-3))
                sig_lines[1].set_visible(True)
                spec_lines[1].set_visible(True)
        for j, col in ((2, 'ipi_corr_man'), (3, 'ipi_ceps_man')):
            if np.isnan(row[col]):
                data[j][1].set_visible(False)
            else:
                data[j][1].set_xdata((row[col], row[col]))
                data[j][1].set_visible(True)

    def _find_view(self, ax):
        """Return ``(column, view)`` indices of *ax* in ``view_ax``, or None."""
        for i, views in enumerate(self.view_ax or []):
            for j, view in enumerate(views):
                if ax == view:
                    return i, j
        return None

    def _annotate_view(self, event):
        """Record a manual pick made on one of the four views of a click column.

        On the waveform/spectrogram the pick alternately places P1 and P2;
        the IPI P2 - P1 is stored once both exist. On the
        autocorrelation/cepstrum the pick is stored and refined to the
        highest value within ``IPIPK`` ms.
        """
        found = self._find_view(event.inaxes)
        if found is None:
            return
        i, j = found
        if self.curr_ind[i] is None:
            # No click loaded in this column yet: there is nothing to annotate.
            return
        row = self.df[self.offset[self.curr_ind[i], 0]]
        data = self.view_data[i]
        if j < 2:
            pos = event.xdata * 10**(3*j)  # spectrogram x is in s, convert to ms
            vert = self.curr_vert[i]
            data[0][1][vert].set_xdata((pos, pos))
            data[1][1][vert].set_xdata((pos*10**-3, pos*10**-3))
            data[0][1][vert].set_visible(True)
            data[1][1][vert].set_visible(True)
            if vert == 0:
                row['p1_pos'] = pos
            self.curr_vert[i] ^= 1  # next pick places the other pulse
            style = ['--', '-'][self.curr_vert[i]]  # dashed: placing P1, solid: P2
            self.cursor[i].linev[0].set_linestyle(style)
            self.cursor[i].linev[1].set_linestyle(style)
            if data[0][1][1].get_visible():
                ipi_man = data[0][1][1].get_xdata()[0] - data[0][1][0].get_xdata()[0]
                row['ipi_sig'] = ipi_man
                self.view_ax[i][0].set_xlabel(f'Sig man:{ipi_man:.3f}')
        else:
            x = event.xdata
            data[j][1].set_xdata((x, x))
            data[j][1].set_visible(True)
            # Refine the manual pick: index of the highest value within ±IPIPK ms.
            lo = max(int(self.sr/1e3*(x-IPIPK)), 0)
            hi = max(int(self.sr/1e3*(x+IPIPK)), 0)
            curve = data[j][0].get_data()[1][lo:hi]
            ipi_auto = (np.argmax(curve) + lo)*1e3/self.sr if len(curve) else np.nan
            col = 'ipi_' + ('corr' if j == 2 else 'ceps')
            row[col + '_man'] = x
            row[col + '_auto'] = ipi_auto
            self.view_ax[i][j].set_xlabel(f'{"Corr" if j == 2 else "Ceps"} man:{x:.3f} auto:{ipi_auto:.3f}')
        plt.draw()

    # -------------------------------------------------------------- buttons

    def change_curr(self, label):
        """Make the column named by *label*'s last character current."""
        self.curr = int(label[-1])
        self.f_cursor.linev.set_color('rgb'[self.curr])
        self.reset_b.label.set_c('rgb'[self.curr])
        plt.draw()

    def play(self, event):
        """Play the 20 s segment shown in the full-signal plot (Ctrl-C stops)."""
        sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
        try:
            # Module-level pydub `play`, not this method.
            play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
        except KeyboardInterrupt:
            pass

    def resize(self, event):
        """Re-run constrained layout once, then freeze the layout again."""
        self.fax.get_figure().set_constrained_layout(True)
        plt.draw()
        plt.pause(0.2)
        self.fax.get_figure().set_constrained_layout(False)

    def increase_freq(self, event):
        """Raise the current spectrogram's upper frequency limit by 1 kHz."""
        self.view_ax[self.curr][1].set_ylim(0, self.view_ax[self.curr][1].get_ylim()[1]+1e3)
        plt.draw()

    def decrease_freq(self, event):
        """Lower the current spectrogram's upper frequency limit by 1 kHz (min 1 kHz)."""
        self.view_ax[self.curr][1].set_ylim(0, max(self.view_ax[self.curr][1].get_ylim()[1]-1e3, 1e3))
        plt.draw()

    def increase_res(self, event):
        """Double the FFT size of the spectrograms and redraw the current one.

        Stops once ``nfft`` exceeds 10 ms of samples.
        """
        if self.nfft > int(10e-3*self.sr):
            return
        click = self.view_data[self.curr][0][0].get_ydata()
        self.nfft *= 2
        self._update_spec(self.curr, click)
        if self.nfft > int(10e-3*self.sr):
            self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
        else:
            self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def decrease_res(self, event):
        """Halve the FFT size of the spectrograms and redraw the current one.

        Stops once ``nfft`` is below 8.
        """
        if self.nfft < 8:
            return
        click = self.view_data[self.curr][0][0].get_ydata()
        self.nfft //= 2
        self._update_spec(self.curr, click)
        self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        if self.nfft < 8:
            self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
        else:
            self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    # -------------------------------------------------------------- helpers

    def _set_label(self, ind=None, dic=None):
        """Write the IPI values of *dic* (default: column *ind*'s click) as axis labels."""
        if ind is None:
            ind = self.curr
        if dic is None:
            dic = self.df[self.offset[self.curr_ind[ind], 0]]
        self.view_ax[ind][0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
        self.view_ax[ind][2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
        self.view_ax[ind][3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')

    def _set_visible(self, ind=None, state=False):
        """Show or hide every marker line of column *ind* (default: current)."""
        if ind is None:
            ind = self.curr
        data = self.view_data[ind]
        for marker in (*data[0][1], *data[1][1], data[2][1], data[3][1]):
            marker.set_visible(state)

    def reset_curr(self, event):
        """Clear every annotation of the click shown in the current column."""
        if self.curr_ind[self.curr] is None:
            return  # no click selected in this column yet
        self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
        self._set_label()
        self._set_visible()
        plt.draw()

    def reset(self, song, sr, song_resample):
        """Start over with a new signal: drop all clicks and blank every view."""
        self.p = 0
        self.df = dict()
        self.song = song
        self.song_resample = song_resample
        self._shift()
        sr_update = False
        if self.sr != sr:
            self.sr = sr
            sr_update = True
        self.curr = 0  # current view selected
        self.offset = np.zeros((0, 2))
        self.scat.set_offsets(self.offset)
        self.scat.set_color([[0, 0, 0, 1]])
        self.curr_ind = 3 * [None]  # Ind of click for each plot
        self.curr_vert = 3 * [0]  # Current vertical line of sig/spec for each  plot
        for i in range(3):
            self.view_data[i][0][0].set_ydata(np.zeros(int(20e-3 * sr)))
            self.view_data[i][2][0].set_ydata(np.zeros(int(10e-3 * sr)))
            self.view_data[i][3][0].set_ydata(np.zeros(int(10e-3 * sr)))
            if sr_update:
                # x axes are sample-based, so they must follow the new rate.
                self.view_data[i][0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False))
                self.view_data[i][2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                self.view_data[i][3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
            self.view_data[i][1][0].set_clim(2000, 2100)
        for i in range(3):
            self._set_label(i, EMLN)
            self._set_visible(i)
        plt.draw()


def init(in_path, channel, low=2e3, high=20e3):
    """Load *in_path* and build the interactive IPI annotation figure.

    Parameters
    ----------
    in_path : str
        Path of the sound file to annotate.
    channel : int
        Index of the channel to analyse.
    low, high : float
        Band-pass cut-off frequencies in Hz passed to ``load_file``.

    Returns
    -------
    dict
        ``callback`` (the ``Callback`` instance), ``fig``, ``gridspec`` and
        ``buttons``. The caller must keep this dict alive so the widgets and
        their callbacks are not garbage collected.
    """
    song, sr, song_resample = load_file(in_path, channel, low, high)
    # os.path.basename also understands Windows separators, unlike rsplit('/').
    fig = plt.figure('IPI of ' + os.path.basename(in_path), figsize=[16, 9], constrained_layout=True)
    gs = fig.add_gridspec(12, 20)

    # Top row: the first 20 s of the resampled, filtered signal.
    full_sig = plt.subplot(gs[:2, 1:-1])
    callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                        song, song_resample, sr, full_sig)
    full_sig.set_xlim(0, 20)
    lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
    full_sig.set_ylim(-lim, lim)
    full_sig.set_yticks([])
    callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
    # Every mouse release in the figure is handled by Callback.on_clicked.
    cid = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)

    # Buttons scrolling the full-signal view.
    b_left_ax = plt.subplot(gs[:2, 0])
    b_right_ax = plt.subplot(gs[:2, -1])
    b_left = Button(b_left_ax, '<|')
    b_right = Button(b_right_ax, '|>')
    b_left.on_clicked(callback.shift_left)
    b_right.on_clicked(callback.shift_right)

    # Selector of the click column (0, 1 or 2) filled by the next full-signal pick.
    r_button_ax = plt.subplot(gs[10, 1:-1])
    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(3)], orientation='horizontal',
                              size=128, activecolor=list('rgb'))  # !Last char of labels is used as index in rest of code
    r_button_ax.axis('off')
    r_button.on_clicked(callback.change_curr)
    for i, c in enumerate('rgb'):
        r_button.labels[i].set_c(c)

    # Three click columns, each a 2x2 block of views:
    # [waveform, spectrogram, autocorrelation, cepstrum].
    vfs = 4  # grid rows per view
    vs = 2   # first grid row of the views
    hfs = 3  # grid columns per view
    hs = 1   # first grid column of the views
    ax_view = []
    for k in range(3):
        left = hs + 2 * k * hfs
        mid = left + hfs
        right = mid + hfs
        ax_view.append([plt.subplot(gs[vs:vs + vfs, left:mid]),
                        plt.subplot(gs[vs:vs + vfs, mid:right]),
                        plt.subplot(gs[vs + vfs:vs + 2 * vfs, left:mid]),
                        plt.subplot(gs[vs + vfs:vs + 2 * vfs, mid:right])])
    callback.view_ax = ax_view

    # Bottom row buttons.
    play_b_ax = plt.subplot(gs[-1, :2])
    play_b = Button(play_b_ax, 'Play\ncurrent segment')
    play_b.on_clicked(callback.play)

    resize_b_ax = plt.subplot(gs[-1, 2:4])
    resize_b = Button(resize_b_ax, 'Resize plot')
    resize_b.on_clicked(callback.resize)

    reset_b_ax = plt.subplot(gs[-1, 4:6])
    reset_b = Button(reset_b_ax, 'Reset current')
    reset_b.label.set_c('r')
    reset_b.on_clicked(callback.reset_curr)
    callback.reset_b = reset_b

    # Spectrogram range and resolution buttons (left column).
    freq_p_b_ax = plt.subplot(gs[2, 0])
    freq_m_b_ax = plt.subplot(gs[3, 0])
    freq_res_p_b_ax = plt.subplot(gs[4, 0])
    freq_res_m_b_ax = plt.subplot(gs[5, 0])
    freq_p_b = Button(freq_p_b_ax, '+\n1kHz')
    freq_p_b.on_clicked(callback.increase_freq)
    freq_m_b = Button(freq_m_b_ax, '-\n1kHz')
    freq_m_b.on_clicked(callback.decrease_freq)
    freq_res_p_b = Button(freq_res_p_b_ax, f'{callback.nfft * 2}\nbins')
    freq_res_p_b.on_clicked(callback.increase_res)
    freq_res_m_b = Button(freq_res_m_b_ax, f'{callback.nfft // 2}\nbins')
    freq_res_m_b.on_clicked(callback.decrease_res)

    spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
    callback.spec_b = spec_button

    # data_view[i][j] = [data artist, marker line(s)], see Callback's docstring.
    data_view = [[2 * [None] for _ in range(4)] for _ in range(3)]
    m_cursor = [None for _ in range(3)]
    callback.cursor = m_cursor
    for i in range(3):
        # Waveform (x in ms): dashed P1 and solid P2 markers.
        data_view[i][0][1] = (ax_view[i][0].axvline(10, c='k', linestyle='--'), ax_view[i][0].axvline(10, c='k'))
        data_view[i][0][0] = ax_view[i][0].plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0]
        ax_view[i][0].set_xlim(0, 20)
        ax_view[i][0].set_xlabel('IPI man:None')
        ax_view[i][0].set_ylim(-1, 1)
        data_view[i][0][1][0].set_visible(False)
        data_view[i][0][1][1].set_visible(False)

        # Spectrogram (x in s) of near-silent placeholder noise.
        data_view[i][1][0] = ax_view[i][1].specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
                                                    Fs=sr, NFFT=callback.nfft, noverlap=callback.nfft - 1,
                                                    cmap='jet')[-1]
        data_view[i][1][0].set_clim(2000, 2100)
        data_view[i][1][1] = (ax_view[i][1].axvline(0.01, c='k', linestyle='--'), ax_view[i][1].axvline(0.01, c='k'))
        data_view[i][1][1][0].set_visible(False)
        data_view[i][1][1][1].set_visible(False)

        # Autocorrelation (lag in ms).
        data_view[i][2][1] = ax_view[i][2].axvline(10, c='k')
        data_view[i][2][0] = ax_view[i][2].plot(np.linspace(0, 10, int(10e-3 * sr), False),
                                                np.zeros(int(10e-3 * sr)))[0]
        ax_view[i][2].set_xlim(0, 10)
        ax_view[i][2].set_xlabel('IPI man:None auto:None')
        ax_view[i][2].set_ylim(-1, 1)
        data_view[i][2][1].set_visible(False)

        # Cepstrum (quefrency in ms).
        data_view[i][3][1] = ax_view[i][3].axvline(10, c='k')
        data_view[i][3][0] = ax_view[i][3].plot(np.linspace(0, 10, int(10e-3 * sr), False),
                                                np.zeros(int(10e-3 * sr)))[0]
        ax_view[i][3].set_xlim(0, 10)
        ax_view[i][3].set_xlabel('IPI man:None auto:None')
        ax_view[i][3].set_ylim(0, 1)
        data_view[i][3][1].set_visible(False)

        for j in range(4):
            if j != 1:
                ax_view[i][j].set_yticks([])
            else:
                # Spectrogram: frequency axis in kHz, capped at Nyquist.
                ax_view[i][j].set_ylim(0, min(20e3, sr/2))
                ax_view[i][j].set_yticks(ax_view[i][j].get_yticks())
                ax_view[i][j].set_yticklabels((ax_view[i][j].get_yticks() / 1e3).astype(int))
        m_cursor[i] = MyMultiCursor(ax_view[i], data_view[i][0][1][0], horizOn=False, useblit=True, c='k')
        # Dashed cursor: the next waveform/spectrogram pick places P1.
        m_cursor[i].linev[0].set_linestyle('--')
        m_cursor[i].linev[1].set_linestyle('--')
    callback.view_data = data_view
    return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
        {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
         'fs_click': cid, 'reset_b': reset_b,
         'spec_button': spec_button}}  # Needed to keep the callbacks alive


def reset(callback, in_path, channel, low=2e3, high=20e3):
    """Reload *in_path* (one band-passed channel) and restart *callback* on it."""
    loaded = load_file(in_path, channel, low, high)  # (song, sr, song_resample)
    callback.reset(*loaded)


def main(args):
    """Run the interactive annotation session and save the results to HDF5.

    Parameters
    ----------
    args : argparse.Namespace-like
        Needs ``input`` (sound file), ``out`` (output path, empty for the
        default ``<input without extension>.pred.h5``), ``channel`` and
        ``erase`` (overwrite an existing output file).

    Returns
    -------
    int
        0 on success, 1 if the output file exists and ``erase`` is not set.
    """
    if args.out:
        outpath = args.out
    else:
        # splitext only strips an extension of the file name itself; the
        # former rsplit('.') also cut into dotted directory names.
        outpath = os.path.splitext(args.input)[0] + '.pred.h5'
    if os.path.isfile(outpath) and not args.erase:
        print(f'Out file {outpath} already exist and erase option isn\'t set.')
        return 1
    ref_dict = init(args.input, args.channel)
    plt.draw()
    plt.pause(0.2)
    # Freeze the layout after the first draw; the 'Resize plot' button can re-run it.
    ref_dict['fig'].set_constrained_layout(False)
    plt.show()  # blocks until the window is closed
    df = pd.DataFrame.from_dict(ref_dict['callback'].df, orient='index')
    df.to_hdf(outpath, 'df')
    return 0


if __name__ == '__main__':

    class ArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that raises ValueError when required arguments are missing.

        This lets the script fall back to asking for them interactively; any
        other parsing error keeps the standard argparse behaviour.
        """
        def error(self, message):
            if message.startswith('the following arguments are required:'):
                raise ValueError(message)
            super(ArgumentParser, self).error(message)

    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("input", type=str, help="Input file")
    parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
    parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
    parser.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                             " the computation will be halted")
    try:
        args = parser.parse_args()
    except ValueError as e:
        print(f'Error while parsing the command line arguments: {e}')

        def ask(string):
            """Repeat a yes/no question until an English or French answer is given."""
            y = {'y', 'yes', 'o', 'oui'}
            a = y | {'n', 'no', 'non'}
            while True:
                ans = input(string + ' [y/n] ').lower()
                if ans in a:
                    return ans in y

        def ask_int(string):
            """Repeat a question until the answer parses as an integer."""
            while True:
                ans = input(string)
                try:
                    return int(ans)
                except ValueError:
                    print(f'{ans!r} is not an integer.')

        if not ask('Do you want to manually specify them?'):
            sys.exit(2)  # exit code of invalid argparse

        class VirtualArgParse(object):
            """Interactively collected stand-in for the parsed arguments."""
            def __init__(self):
                self.input = input("What is the input file path? ")
                self.out = input("What is the out file path? (Leave empty for default)")
                self.channel = ask_int("Which channel do you want to use starting from 0? ")
                self.erase = ask("Do you want to erase the out file if it already exist?")

        args = VirtualArgParse()

    sys.exit(main(args))