import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import scipy.signal as sg
import soundfile as sf
import os
import sys
from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget
from fractions import Fraction
from pydub import AudioSegment
from pydub.playback import play
import pandas as pd
from typing import List, Union

# Sampling rate (Hz) of the resampled signal displayed in the full-signal plot.
FSSR = 48_000
# Search radius, in seconds, around a click on the full-signal plot when snapping to the peak.
FSPK = 0.1
# Search radius, in milliseconds, around a manual IPI pick when looking for the automatic peak.
IPIPK = 0.15
# Dynamic range (dB) of the spectrogram colour scale.
SPSC = 80
# Template of an empty annotation row: every field starts out as NaN.
EMLN = dict.fromkeys(('p1_pos', 'ipi_sig',
                      'ipi_corr_man', 'ipi_corr_auto',
                      'ipi_ceps_man', 'ipi_ceps_auto',
                      'ind_number'), np.nan)


def read(file_path, always_2d=True):
    """Read an audio file, falling back to pydub when soundfile cannot decode it.

    Parameters
    ----------
    file_path : str
        Path of the audio file.
    always_2d : bool
        If True (default), return a ``(frames, channels)`` array. If False,
        single-channel audio is returned as a 1-D array. This is honoured on the
        pydub fallback path too, so both paths agree.

    Returns
    -------
    (numpy.ndarray, int)
        The samples and the sampling rate.
    """
    try:
        return sf.read(file_path, always_2d=always_2d)
    except Exception:  # deliberate best-effort: any soundfile failure triggers the pydub fallback
        data, sr = load_anysound(file_path)
        if not always_2d and data.shape[1] == 1:
            data = data[:, 0]
        return data, sr


def load_anysound(file_path):
    """Decode ``file_path`` with pydub and return ``(samples, frame_rate)``.

    ``samples`` is reshaped to ``(frames, channels)``; values are exactly those
    returned by ``get_array_of_samples`` (no rescaling is applied).
    """
    segment = AudioSegment.from_file(file_path)
    samples = np.array(segment.get_array_of_samples())
    return samples.reshape(-1, segment.channels), segment.frame_rate


def load_file(in_path, channel, low, high):
    """Load one channel of an audio file, band-pass it and build the plotting copy.

    Parameters
    ----------
    in_path : str
        Audio file path.
    channel : int
        Column index of the channel to keep.
    low, high : float
        Band-pass corner frequencies in Hz. ``high`` is lowered just below the
        Nyquist frequency when needed, because ``scipy.signal.butter`` rejects
        critical frequencies >= fs/2 (e.g. 20 kHz on a 32 kHz recording).

    Returns
    -------
    song : numpy.ndarray
        Filtered signal at the native rate, zero-padded at the end when shorter
        than 20 s (the full-signal view always shows a 20 s window).
    sr : int
        Native sampling rate.
    song_resample : numpy.ndarray
        ``song`` resampled to ``FSSR`` for the full-signal plot.
    """
    print(f'Loading and processing {in_path}')
    song, sr = read(in_path, always_2d=True)
    song = song[:, channel]
    nyquist = sr / 2
    if high >= nyquist:
        high = 0.99 * nyquist
        print(f'Upper cut-off lowered to {high:.0f} Hz (must be below Nyquist for sr={sr})')
    sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
    song = sg.sosfiltfilt(sos, song)
    if len(song) < 20*sr:
        # Pad only at the end so sample indices (hence click times) stay aligned with the file.
        song = np.pad(song, (0, 20*sr), mode='constant')
    frac = Fraction(FSSR, sr)
    song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
    print('Done processing')
    return song, sr, song_resample


def norm(x):
    """Scale ``x`` so that its largest absolute value becomes (almost) 1.

    A tiny epsilon in the denominator keeps an all-zero input from dividing by zero.
    """
    peak = np.abs(x).max()
    return x / (peak + 1e-10)


def norm_std(x, alpha=1.5):
    """Divide ``x`` by ``alpha`` times its standard deviation.

    Parameters
    ----------
    x : numpy.ndarray
        Values to scale.
    alpha : float
        Multiple of the standard deviation used as the scale (default 1.5).

    A tiny epsilon in the denominator keeps a constant input from dividing by zero.
    """
    return x/(alpha*np.std(x)+1e-10)


class MyRadioButtons(RadioButtons):
    """Radio buttons drawn as legend entries, with an optional horizontal layout.

    ``activecolor`` may be a single colour or a sequence of colours. With a
    sequence, the selected button uses the entry whose index is the last
    character of the selected label (so labels must then end with a digit).

    NOTE(review): ``__init__`` bypasses ``RadioButtons.__init__`` and sets the
    parent's attributes by hand (``circles``, ``labels``, ``cnt``,
    ``observers``) and reads ``Legend.legendHandles``; this presumably targets
    older matplotlib releases (the file mentions 3.1) — confirm before upgrading.
    """

    def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                 orientation="vertical", **kwargs):
        """
        Add radio buttons to an `~.axes.Axes`.
        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to add the buttons to.
        labels : list of str
            The button labels.
        active : int
            The index of the initially selected button.
        activecolor : color or sequence of colors
            The color of the selected button, or one color per button.
        size : float
            Size of the radio buttons
        orientation : str
            The orientation of the buttons: 'vertical' (default), or 'horizontal'.
        Further parameters are passed on to `Legend`.
        """
        AxesWidget.__init__(self, ax)
        self._activecolor = activecolor
        axcolor = ax.get_facecolor()
        self.value_selected = None

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_navigate(False)
        # One empty scatter per label; the legend turns them into the clickable circles.
        circles = []
        for i, label in enumerate(labels):
            if i == active:
                self.value_selected = label
                facecolor = self.activecolor
            else:
                facecolor = axcolor
            p = ax.scatter([],[], s=size, marker="o", edgecolor='black',
                           facecolor=facecolor)
            circles.append(p)
        if orientation == "horizontal":
            # One legend column per label, expanded across the axes, lays the buttons out in a row.
            kwargs.update(ncol=len(labels), mode="expand")
        kwargs.setdefault("frameon", False)
        self.box = ax.legend(circles, labels, loc="center", **kwargs)
        self.labels = self.box.texts
        self.circles = self.box.legendHandles
        for c in self.circles:
            c.set_picker(5)
        self.cnt = 0
        self.observers = {}
        self.connect_event('pick_event', self._clicked)

    def _clicked(self, event):
        """Select the button whose circle was picked with the left mouse button in this axes."""
        if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
            return
        if event.artist in self.circles:
            self.set_active(self.circles.index(event.artist))

    @property
    def activecolor(self):
        """Colour of the selected button.

        A single colour (a name such as the default ``'blue'`` or an RGB(A)
        tuple) is returned unchanged; previously a colour string was indexed
        like a sequence. Otherwise ``_activecolor`` is treated as one colour
        per button, indexed by the last character of the selected label.
        """
        from matplotlib.colors import is_color_like
        if is_color_like(self._activecolor) or not hasattr(self._activecolor, '__getitem__'):
            return self._activecolor
        return self._activecolor[int(self.value_selected[-1])]


class MyMultiCursor(AxesWidget):
    """Vertical cursors following the mouse across the four plots of one view.

    ``axes`` is an ``AxesGroup`` (waveform, spectrogram, correlation,
    cepstrum) and ``p1`` the P1 cursor line of its waveform. ``self.linev``
    holds, in order, ``num_cur`` lines on the waveform (ms), ``num_cur`` on the
    spectrogram (s), then one line each on the correlation and cepstrum plots.
    While ``linev[0]`` is dashed (P1 not yet placed) only the first waveform and
    spectrogram lines follow the mouse. Otherwise the waveform/spectrogram
    lines are spaced by the IPI measured from ``p1``, and the correlation and
    cepstrum lines show that IPI.
    """

    def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False,
                 **lineprops):
        AxesWidget.__init__(self, axes[0])
        self.axes = axes
        self.p1 = p1
        self.num_cur = num_cur
        self.connect_event('motion_notify_event', self.onmove)
        self.connect_event('draw_event', self.clear)

        self.visible = True
        self.vertOn = vertOn
        self.useblit = useblit and self.canvas.supports_blit

        if self.useblit:
            lineprops['animated'] = True
        self.linev = []
        # One mouse-following line per solid cursor of each plot (dashed P1 markers are skipped).
        for ax in self.axes:
            for line in ax.cursors:
                if line.get_linestyle() == '--':
                    continue
                self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
        self.linev[0].set_linestyle('--')
        self.linev[self.num_cur].set_linestyle('--')
        self.n_axes = len(self.linev)  # number of cursor lines (2*num_cur + 2), not of axes

        self.background = len(self.axes)*[None]
        self.needclear = False

    def clear(self, event):
        """Internal event handler to clear the cursor."""
        if self.ignore(event):
            return
        if self.useblit:
            for i, ax in enumerate(self.axes):
                self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
        # Hide every cursor line: there are more lines (2*num_cur + 2) than axes (4),
        # so hiding one line per axes left stale cursors on screen when num_cur > 1.
        for line in self.linev:
            line.set_visible(False)

    def onmove(self, event):
        """Internal event handler to draw the cursor when the mouse moves."""
        if self.ignore(event):
            return
        if not self.canvas.widgetlock.available(self):
            return
        if event.inaxes not in self.axes:
            for i in range(self.n_axes):
                self.linev[i].set_visible(False)
            if self.needclear:
                self.canvas.draw()
                self.needclear = False
            return
        self.needclear = True
        if not self.visible:
            return
        for ax_idx, ax in enumerate(self.axes):
            if ax.axes == event.inaxes:
                break
        if ax_idx < 2:
            # Waveform (x in ms) or spectrogram (x in s): work in ms.
            if ax_idx == 1:
                pos = event.xdata*1e3
            else:
                pos = event.xdata
            if self.linev[0].get_linestyle() == '--':
                # Waiting for P1: only the first waveform and spectrogram lines move.
                self.linev[0].set_xdata((pos, pos))
                self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3))
                self.linev[0].set_visible(self.visible and self.vertOn)
                self.linev[self.num_cur].set_visible(self.visible and self.vertOn)
                self._update()
                return
            else:
                # P1 placed: the mouse gives P2, lines are spaced by the implied IPI.
                ipi_pos = abs(pos - self.p1.get_xdata()[0])
                for i in range(self.n_axes):
                    if i < self.num_cur:
                        self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                    elif i < 2*self.num_cur:
                        self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                    else:
                        self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                self._update()
                return
        else:
            # Correlation/cepstrum: the mouse gives the IPI directly (ms).
            ipi_pos = event.xdata
            pos = self.p1.get_xdata()[0] + ipi_pos
            restore = False
        if self.linev[0].get_linestyle() == '--':
            # Temporarily draw the first lines solid; the dashed style is restored after drawing.
            restore = True
            self.linev[0].set_linestyle('-')
            self.linev[self.num_cur].set_linestyle('-')
        for i in range(self.n_axes):
            if self.p1.get_visible():
                if i < self.num_cur:
                    self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                elif i < 2 * self.num_cur:
                    self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                self.linev[i].set_visible(self.visible and self.vertOn)
            if i >= 2 * self.num_cur:
                self.linev[i].set_xdata((ipi_pos, ipi_pos))
                self.linev[i].set_visible(self.visible and self.vertOn)
        self._update()

        if restore:
            self.linev[0].set_linestyle('--')
            self.linev[self.num_cur].set_linestyle('--')

    def _update(self):
        """Redraw the cursor lines, blitting each axes' lines when blitting is enabled."""
        if self.useblit:
            k = 0
            for bk, ax in zip(self.background, self.axes):
                if bk is not None:
                    self.canvas.restore_region(bk)
                # num_cur lines on the first two axes, a single line on the other two.
                for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1):
                    ax.axes.draw_artist(self.linev[k])
                k += 1
                self.canvas.blit(ax.axes.bbox)
        else:
            self.canvas.draw_idle()
        return False


class Callback(object):
    """State and event handlers of the IPI annotation GUI.

    Annotations live in ``self.df``: a dict mapping a click time (seconds in
    the resampled signal) to a row dict created from ``EMLN``. ``self.offset``
    holds one ``[time, amplitude]`` pair per annotated click, in the same order
    as the markers of ``self.scat`` on the full-signal axes. ``self.curr_ind[v]``
    is the row of ``self.offset`` shown in view ``v`` (``None`` while that
    view is empty).
    """

    def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length):
        self.p = 0  # first sample (in song_resample) of the 20 s full-signal window
        self.df = dict()
        self.line = line
        self.song = song
        self.song_resample = song_resample
        self.sr = sr
        self.fax = full_ax
        self.num_view = num_view
        self.after_length = after_length
        self.view_data: Union[List['AxesGroup'], None] = None
        self.curr = 0  # current view selected
        self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
        self.offset = np.zeros((0, 2))
        self.curr_ind: List[Union[int, None]] = num_view*[None]  # Indices of click for each plot
        self.curr_vert = num_view*[0]    # Next sig/spec line to place for each plot (0: P1, 1: P2)
        self.cursor = None
        self.f_cursor = None
        self.reset_b = None
        self.r_button = None
        self.ind_b = None
        self.onaxis_b = None
        self.spec_b = None
        self.nfft = 128
        self.ind_select = False  # True while the individual number is being typed

    def shift_left(self, event):
        """Move the 20 s full-signal window 13 s back (clamped at the start)."""
        self.p = max(0, self.p - FSSR*13)
        self._shift()

    def shift_right(self, event):
        """Move the 20 s full-signal window 13 s forward (clamped at the end)."""
        self.p = min(len(self.song_resample) - FSSR*20, self.p + FSSR*13)
        self._shift()

    def _shift(self):
        """Redraw the full-signal plot for the 20 s window starting at ``self.p``."""
        window = self.song_resample[self.p:self.p+FSSR*20]
        self.line.set_ydata(window)
        lim = np.abs(window).max()*1.2
        self.fax.set_ylim(-lim, lim)
        self.line.set_xdata(np.linspace(self.p/FSSR, self.p/FSSR+20, FSSR*20, False))
        self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
        plt.draw()

    _shit = _shift  # backward-compatible alias for the former method name

    def _current_row(self, view=None):
        """Return the annotation row shown in ``view`` (default: current view), or None if it is empty."""
        if view is None:
            view = self.curr
        idx = self.curr_ind[view]
        if idx is None:
            return None
        return self.df[self.offset[idx, 0]]

    @staticmethod
    def _click_window(song, sr, center, after_length):
        """Return the samples from 10 ms before ``center`` to ``after_length`` s after it.

        ``center`` is a (possibly fractional) sample index at rate ``sr``. The
        result always has ``int((10e-3 + after_length) * sr)`` samples, the length
        of the waveform trace. It is zero-padded at the end when the window is
        clipped by either end of ``song``, and trimmed if rounding yields an
        extra sample.
        """
        n = int((10e-3 + after_length) * sr)
        click = song[max(int(center - 10e-3 * sr), 0):int(center + after_length * sr)][:n]
        return np.pad(click, (0, n - len(click)), mode='constant')

    def _on_clicked_fax(self, event):
        """Handle a click on the full-signal plot.

        Snaps to the largest sample within ``FSPK`` s of the click. Adds that
        click to the annotations if new (or selects it if already annotated) and
        shows it in the current view. Then recomputes the view's waveform,
        spectrogram, autocorrelation and cepstrum.
        """
        pos = int(FSSR * event.xdata)
        start = max(pos - int(FSSR * FSPK), 0)
        mpos = np.argmax(self.song_resample[start:pos + int(FSSR * FSPK)]) + start
        click_time = mpos / FSSR
        ax_group = self.view_data[self.curr]
        prev = self.curr_ind[self.curr]
        if prev is not None:
            # Recolour the click this view is leaving: purple if it has no annotation, black otherwise.
            if np.all(pd.isnull(list(self.df[self.offset[prev, 0]].values()))):
                self.scat._facecolors[prev] = [0.7, 0.2, 0.5, 1]
            else:
                self.scat._facecolors[prev] = [0, 0, 0, 1]
        if click_time not in self.offset[:, 0]:
            # New click: store it, give it the current view's colour and start with P1 placement.
            self.offset = np.concatenate([self.offset, [[click_time, self.song_resample[mpos]]]], axis=0)
            self.scat.set_offsets(self.offset)
            self.df[click_time] = EMLN.copy()
            c = [[0, 0, 0, 1]]
            c[0][self.curr] = 1
            if len(self.offset) == 1:
                self.scat.set_color(c)
            else:
                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
            self.curr_ind[self.curr] = len(self.offset) - 1
            self.curr_vert[self.curr] = 0
            self.cursor[self.curr].linev[0].set_linestyle('--')
            self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
            ax_group.set_visible(False)
        else:
            # Existing click: select it and restore its saved cursors.
            self.curr_ind[self.curr] = np.argmax(click_time == self.offset[:, 0])
            # Marker colour = mean colour of every view currently showing this click.
            c = np.array([0, 0, 0, 0.])
            k = 0
            for i, v in enumerate(self.curr_ind):
                if v == self.curr_ind[self.curr]:
                    c += to_rgba('rgbkcmyrgbkcmy'[i])
                    k += 1
            c /= k
            self.scat._facecolors[self.curr_ind[self.curr]] = c
            self.scat.set_color(self.scat._facecolors)
            row = self.df[click_time]
            if np.isnan(row['p1_pos']):
                ax_group.signal.set_visible(False)
                ax_group.spectrogram.set_visible(False)
            else:
                ax_group.signal.cursors[0].set_xdata((row['p1_pos'], row['p1_pos']))
                ax_group.spectrogram.cursors[0].set_xdata((row['p1_pos'] * 1e-3, row['p1_pos'] * 1e-3))
                ax_group.signal.cursors[0].set_visible(True)
                ax_group.spectrogram.cursors[0].set_visible(True)
                if np.isnan(row['ipi_sig']):
                    ax_group.signal.cursors[1].set_visible(False)
                    ax_group.spectrogram.cursors[1].set_visible(False)
                else:
                    p2 = row['p1_pos'] + row['ipi_sig']
                    ax_group.signal.cursors[1].set_xdata((p2, p2))
                    ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3))
                    ax_group.signal.cursors[1].set_visible(True)
                    ax_group.spectrogram.cursors[1].set_visible(True)
            if np.isnan(row['ipi_corr_man']):
                ax_group.correlation.set_visible(False)
            else:
                ax_group.correlation.cursors[0].set_xdata((row['ipi_corr_man'], row['ipi_corr_man']))
                ax_group.correlation.set_visible(True)
            if np.isnan(row['ipi_ceps_man']):
                ax_group.cepstrum.set_visible(False)
            else:
                ax_group.cepstrum.cursors[0].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man']))
                ax_group.cepstrum.set_visible(True)
        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text(f'Current individual:\n{self.df[click_time]["ind_number"]}')
        # Fixed-length window at the native rate, so it always matches the waveform's x data.
        click = self._click_window(self.song, self.sr, mpos * self.sr / FSSR, self.after_length)
        ax_group.signal.im.set_ydata(norm(click))
        self._update_spectrogram(ax_group)
        ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
        ax_group.cepstrum.im.set_ydata(
            norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)])))
        self._set_label()
        plt.draw()

    def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
        """Handle a click inside one of the four plots of view ``group_num``.

        On the waveform (ms) or spectrogram (s), alternately places P1 and P2
        and records ``p1_pos`` / ``ipi_sig``. On the autocorrelation or
        cepstrum, records the manual IPI and the automatic one (the peak within
        ``IPIPK`` ms of the click). Does nothing if the view has no click selected.
        """
        if self.curr_ind[group_num] is None:
            return
        row = self.df[self.offset[self.curr_ind[group_num], 0]]
        if cell_num < 2:
            pos = event.xdata if cell_num == 0 else event.xdata * 1000
            cur_type = self.curr_vert[group_num]
            ax_group.signal.cursors[cur_type].set_xdata(2*(pos,))
            ax_group.spectrogram.cursors[cur_type].set_xdata(2*(pos*10**-3,))
            ax_group.signal.cursors[cur_type].set_visible(True)
            ax_group.spectrogram.cursors[cur_type].set_visible(True)
            if cur_type == 0:
                row['p1_pos'] = pos
            # Alternate between placing P1 (dashed mouse cursor) and P2 (solid).
            self.curr_vert[group_num] ^= 1
            cur_type ^= 1
            self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type])
            self.cursor[group_num].linev[self.cursor[group_num].num_cur].set_linestyle(['--', '-'][cur_type])
            if ax_group.signal.cursors[1].get_visible():
                ipi_man = ax_group.signal_ipi
                row['ipi_sig'] = ipi_man
                ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}')
        else:
            cell = ax_group[cell_num]
            cell.cursors[0].set_xdata((event.xdata, event.xdata))
            cell.cursors[0].set_visible(True)
            lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0)
            ipi_auto = np.argmax(cell.im.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min
            col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
            row[col + '_man'] = event.xdata
            row[col + '_auto'] = ipi_auto*1e3/self.sr
            cell.axes.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} '
                                 f'auto:{ipi_auto*1e3/self.sr:.3f}')
        plt.draw()

    def on_clicked(self, event):
        """Dispatch a mouse-button release to the full-signal plot or to the clicked view plot."""
        if event.inaxes == self.fax:
            self._on_clicked_fax(event)
        for i, ax_group in enumerate(self.view_data):  # Look if a click plot was clicked and which one
            for j in range(len(ax_group)):
                if event.inaxes == ax_group[j].axes:
                    self._on_clicked_ax_group(event, ax_group, i, j)
                    return

    def change_curr(self, label):
        """Make the view whose index is the last character of ``label`` current."""
        self.curr = int(label[-1])
        self.f_cursor.linev.set_color('rgbkcmyrgbkcmy'[self.curr])
        self.reset_b.label.set_c('rgbkcmyrgbkcmy'[self.curr])
        plt.draw()

    def toggle_ind(self, event):
        """Toggle typing mode for the individual number of the current click (starting from '')."""
        row = self._current_row()
        if row is None:
            return
        self.ind_select = not self.ind_select
        self.ind_b.color = 'limegreen' if self.ind_select else '0.85'
        self.ind_b.hovercolor = 'lime' if self.ind_select else '0.95'
        if self.ind_select:
            row['ind_number'] = ''
        self.ind_b.label.set_text(f'Current individual:\n'
                                  f'{row["ind_number"]}')
        plt.draw()

    def key_pressed(self, event):
        """Type the individual number while in typing mode; otherwise digit keys switch the view."""
        key = event.key
        if self.ind_select:
            row = self._current_row()
            if row is None or key is None:
                return
            if not isinstance(row['ind_number'], str):
                row['ind_number'] = ''
            if key == 'backspace':
                row['ind_number'] = row['ind_number'][:-1]
            elif len(key) == 1:
                row['ind_number'] = row['ind_number'] + key
            else:
                return  # modifiers and named keys ('shift', 'enter', 'ctrl+a', ...) are not text
            self.ind_b.label.set_text(f'Current individual:\n{row["ind_number"]}')
            plt.draw()
        elif key in [str(i) for i in range(self.num_view)]:
            self.change_curr(key)
            self.r_button.set_active(int(key))

    def play(self, event):
        """Play the 20 s window currently shown in the full-signal plot (Ctrl+C stops it)."""
        sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
        try:
            play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
        except KeyboardInterrupt:
            pass

    def resize(self, event):
        """Run constrained layout once to re-fit the plots, then turn it off again."""
        self.fax.get_figure().set_constrained_layout(True)
        plt.draw()
        plt.pause(0.2)
        self.fax.get_figure().set_constrained_layout(False)

    def onaxis(self, event):
        """Cycle the on-axis flag of the current click: unknown (-1) -> on (1) <-> off (0)."""
        row = self._current_row()
        if row is None:
            return
        state = int(row['onaxis'])
        row['onaxis'] = 1 if state == -1 else state ^ 1
        self.onaxis_b.label.set_text('On-axis'
                                     if row['onaxis'] else 'Off-axis')
        plt.draw()

    def increase_freq(self, event):
        """Raise the upper frequency limit of every spectrogram by 1 kHz."""
        lim = self.view_data[0].spectrogram.axes.get_ylim()[1] + 1e3
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def decrease_freq(self, event):
        """Lower the upper frequency limit of every spectrogram by 1 kHz (not below 1 kHz)."""
        lim = max(self.view_data[0].spectrogram.axes.get_ylim()[1] - 1e3, 1e3)
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def _update_spectrogram(self, ax_group):
        """Recompute a view's spectrogram (dB, ``SPSC`` dB range) from its waveform with ``self.nfft``."""
        click = ax_group.signal.im.get_ydata()
        spec, _, t = plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1, pad_to=2*self.nfft)
        spec = np.flipud(20*np.log10(spec))
        ax_group.spectrogram.im.set_data(spec)
        ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max())
        ax_group.spectrogram.im.set_extent([t[0] - 1/self.sr/2, t[-1] + 1/self.sr/2, 0., self.sr/2])

    def increase_res(self, event):
        """Double the spectrogram FFT size (bounded by the 10 ms window) and redraw all views."""
        if self.nfft > int(10e-3*self.sr):
            return
        self.nfft *= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        if self.nfft > int(10e-3*self.sr):
            self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
        else:
            self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def decrease_res(self, event):
        """Halve the spectrogram FFT size (down to 4) and redraw all views."""
        if self.nfft < 8:
            return
        self.nfft //= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        if self.nfft < 8:
            self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
        else:
            self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def _set_label(self, ind=None, dic=None):
        """Write the IPI values of ``dic`` (default: row of view ``ind``) into that view's labels."""
        if ind is None:
            ind = self.curr
        if dic is None:
            dic = self.df[self.offset[self.curr_ind[ind], 0]]
        ax_group = self.view_data[ind]
        ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
        ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
        ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
        self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
        if self.onaxis_b is not None:
            # -1 (unknown) indexes the last entry, '?'.
            self.onaxis_b.label.set_text(['Off', 'On', '?'][int(dic['onaxis'])] + '-axis')

    def _set_visible(self, ind=None, state=False):
        """Show or hide every cursor of view ``ind`` (default: current view)."""
        if ind is None:
            ind = self.curr
        self.view_data[ind].set_visible(state)

    def reset_curr(self, event):
        """Erase the annotation of the click shown in the current view and hide its cursors."""
        if self.curr_ind[self.curr] is None:
            return
        self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
        self._set_label()
        self._set_visible()
        self.curr_vert[self.curr] = 0
        self.cursor[self.curr].linev[0].set_linestyle('--')
        self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
        plt.draw()

    def reset(self, song, sr, song_resample):
        """Load a new recording into the GUI and clear every annotation and view."""
        self.p = 0
        self.df = dict()
        self.song = song
        self.song_resample = song_resample
        self._shift()
        sr_update = self.sr != sr
        self.sr = sr
        self.change_curr('0')  # reset current view to 0
        self.r_button.set_active(0)
        self.offset = np.zeros((0, 2))
        self.scat.set_offsets(self.offset)
        self.scat.set_color([[0, 0, 0, 1]])
        self.curr_ind = self.num_view * [None]  # Ind of click for each plot
        self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each plot
        for ax_group in self.view_data:
            ax_group.signal.im.set_ydata(np.zeros(int((10e-3 + self.after_length) * sr)))
            ax_group.correlation.im.set_ydata(np.zeros(int(10e-3 * sr)))
            ax_group.cepstrum.im.set_ydata(np.zeros(int(10e-3 * sr)))
            if sr_update:
                ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length)*1e3,
                                                         int((10e-3 + self.after_length)*sr), False))
                ax_group.correlation.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
            ax_group.spectrogram.im.set_clim(2000, 2100)
        for i in range(self.num_view):
            self._set_label(i, EMLN)
            self._set_visible(i)
        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text('Current individual:\nnan')
        plt.draw()


class AxesWithCursor:
    """A matplotlib Axes bundled with its main artist and its vertical cursor lines."""
    __slots__ = ['axes', 'im', 'cursors']

    def __init__(self, axes: plt.Axes, im=None, cursors=None):
        self.axes = axes  # wrapped matplotlib Axes
        self.im = im      # main artist of the axes (a line or an image)
        self.cursors = cursors if cursors is not None else []

    def set_visible(self, state):
        """Show or hide every cursor line; the main artist is left untouched."""
        for line in self.cursors:
            line.set_visible(state)

    @property
    def figure(self):
        """Figure owning the wrapped axes."""
        return self.axes.figure


class AxesGroup:
    """The four linked plots of one click view: waveform, spectrogram, autocorrelation, cepstrum.

    Indexing (``group[0]`` .. ``group[3]``), ``len`` and ``in`` follow that order.
    """
    __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']

    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3):
        n_click = int((10e-3 + after_length)*sr)  # samples in the click window (10 ms + after_length)
        n_lag = int(10e-3 * sr)                   # samples on the 10 ms lag axis
        sig_end = 10 + after_length*1e3           # click window length in ms

        # Waveform (x in ms): one dashed cursor followed by num_cur solid ones.
        sig_trace = ax_sig.plot(np.linspace(0, sig_end, n_click, False), np.zeros(n_click))[0]
        sig_cursors = (ax_sig.axvline(10, c='k', linestyle='--'),) + \
            tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur))
        self.signal = AxesWithCursor(ax_sig, sig_trace, sig_cursors)
        ax_sig.set_xlim(0, sig_end)
        ax_sig.set_xlabel('IPI man:None')
        ax_sig.set_ylim(-1, 1)
        ax_sig.set_yticks([])

        # Spectrogram (x in s) of low-level noise; the very high clim renders the placeholder in one colour.
        noise = np.random.normal(0, 1e-6, n_click)
        spec_image = ax_spec.specgram(noise, Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
        spec_cursors = (ax_spec.axvline(0.01, c='k', linestyle='--'),) + \
            tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur))
        self.spectrogram = AxesWithCursor(ax_spec, spec_image, spec_cursors)
        self.spectrogram.im.set_clim(2000, 2100)
        ax_spec.set_ylim(0, min(20e3, sr/2))
        ax_spec.set_yticks(ax_spec.get_yticks())  # Needed, otherwise updating ylim doesn't update ticks properly
        ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))

        # Autocorrelation and cepstrum share the same 0-10 ms layout with one cursor each.
        lag_views = []
        for ax, y_range in ((ax_corr, (-1, 1)), (ax_cep, (0, 0.33))):
            trace = ax.plot(np.linspace(0, 10, n_lag, False), np.zeros(n_lag))[0]
            lag_views.append(AxesWithCursor(ax, trace, (ax.axvline(10, c='k'),)))
            ax.set_xlim(0, 10)
            ax.set_xlabel('IPI man:None auto:None')
            ax.set_ylim(*y_range)
            ax.set_yticks([])
        self.correlation, self.cepstrum = lag_views
        self.set_visible(False)

    def __getitem__(self, item):
        return (self.signal, self.spectrogram, self.correlation, self.cepstrum)[item]

    def __len__(self):
        return len(self.__slots__)

    def __contains__(self, item):
        """Accept either an AxesWithCursor or a raw matplotlib Axes."""
        if isinstance(item, AxesWithCursor):
            return any(item == member for member in self)
        return any(item == member.axes for member in self)

    def set_visible(self, state):
        """Show or hide the cursors of all four plots."""
        for member in self:
            member.set_visible(state)

    @property
    def signal_ipi(self):
        """Distance (ms) between the second and first waveform cursors."""
        first, second = self.signal.cursors[0], self.signal.cursors[1]
        return second.get_xdata()[0] - first.get_xdata()[0]


def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3):
    """Load ``in_path`` and build the whole annotation GUI.

    Parameters
    ----------
    in_path : str
        Audio file to annotate.
    channel : int
        Channel index to use.
    low, high : float
        Band-pass corner frequencies (Hz) passed to ``load_file``.
    num_cur : int
        Number of IPI-spaced cursors on each waveform/spectrogram plot.
    num_graph : int
        Number of click views shown side by side. View selection relies on the
        last character of the radio-button labels, so at most 10 views are supported.
    after_length : float
        Seconds of signal shown after each click (10 ms before it are always shown).

    Returns
    -------
    dict
        ``callback``, ``fig``, ``gridspec`` and ``buttons`` (widgets and
        connection ids), returned so the caller can keep references to them.
    """
    song, sr, song_resample = load_file(in_path, channel, low, high)
    fig = plt.figure('IPI of ' + os.path.basename(in_path), figsize=[16, 9], constrained_layout=True)
    gs = fig.add_gridspec(12, 20)

    # Full-signal overview on the top rows: a 20 s window of the resampled signal.
    full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
    callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                        song, song_resample, sr, full_sig, num_graph, after_length)
    full_sig.set_xlim(0, 20)
    lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
    full_sig.set_ylim(-lim, lim)
    full_sig.set_yticks([])
    callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
    cid1 = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)

    b_left_ax = plt.subplot(gs[:2, 0])
    b_right_ax = plt.subplot(gs[:2, -1])
    b_left = Button(b_left_ax, '<|')
    b_right = Button(b_right_ax, '|>')
    b_left.on_clicked(callback.shift_left)
    b_right.on_clicked(callback.shift_right)
    r_button_ax = plt.subplot(gs[10, 1:-1])
    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
                              size=128, activecolor=list('rgbkcmyrgbkcmy'[:num_graph]))  # !Last char of labels is use as index in rest of code
    r_button_ax.axis('off')
    r_button.on_clicked(callback.change_curr)
    for i, c in enumerate('rgbkcmyrgbkcmy'[:num_graph]):
        r_button.labels[i].set_c(c)
    callback.r_button = r_button
    # Disable matplotlib's default keyboard shortcuts (fullscreen included) so key presses only
    # reach our handlers; keys unknown to the installed matplotlib version are skipped.
    for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
        key = f'keymap.{v}'
        if key in plt.rcParams:
            plt.rcParams[key] = []
    cid2 = fig.canvas.mpl_connect('key_press_event', callback.key_pressed)

    # Click views: waveform | spectrogram on top, correlation | cepstrum below, one pair of columns per view.
    vfs = 4
    vs = 2
    hs = 1
    hfs = (20-hs)//(2*num_graph)
    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                          plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                          sr, num_cur, after_length) for i in range(num_graph)]

    play_b_ax = plt.subplot(gs[-1, :2])
    play_b = Button(play_b_ax, 'Play\ncurrent segment')
    play_b.on_clicked(callback.play)

    resize_b_ax = plt.subplot(gs[-1, 2:4])
    resize_b = Button(resize_b_ax, 'Resize plot')
    resize_b.on_clicked(callback.resize)

    reset_b_ax = plt.subplot(gs[-1, 4:6])
    reset_b = Button(reset_b_ax, 'Reset current')
    reset_b.label.set_c('r')
    reset_b.on_clicked(callback.reset_curr)
    callback.reset_b = reset_b

    ind_b_ax = plt.subplot(gs[-1, 6:8])
    ind_b = Button(ind_b_ax, 'Current individual:\nnan')
    ind_b.on_clicked(callback.toggle_ind)
    callback.ind_b = ind_b

    # Spectrogram frequency-range and resolution buttons (left column).
    freq_p_b_ax = plt.subplot(gs[2, 0])
    freq_m_b_ax = plt.subplot(gs[3, 0])
    freq_res_p_b_ax = plt.subplot(gs[4, 0])
    freq_res_m_b_ax = plt.subplot(gs[5, 0])
    freq_p_b = Button(freq_p_b_ax, '+\n1kHz')
    freq_p_b.on_clicked(callback.increase_freq)
    freq_m_b = Button(freq_m_b_ax, '-\n1kHz')
    freq_m_b.on_clicked(callback.decrease_freq)
    freq_res_p_b = Button(freq_res_p_b_ax, '256\nbins')
    freq_res_p_b.on_clicked(callback.increase_res)
    freq_res_m_b = Button(freq_res_m_b_ax, '64\nbins')
    freq_res_m_b.on_clicked(callback.decrease_res)

    spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
    callback.spec_b = spec_button

    m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
    callback.cursor = m_cursor
    for i in range(len(ax_group)):
        m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], num_cur, useblit=True, c='k')
    callback.view_data = ax_group
    return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
            {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
             'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b, 'spec_button': spec_button}}


def reset(callback, in_path, channel, low=2e3, high=20e3):
    """Reload an audio file and push the fresh signals into ``callback``.

    The file is read and band-pass filtered by ``load_file`` (``low`` and
    ``high`` are the filter edges handed straight to it). Its three results
    (filtered song, sampling rate, resampled song) are forwarded unchanged,
    in that order, to ``callback.reset``.
    """
    loaded = load_file(in_path, channel, low, high)
    # load_file returns (song, sr, song_resample), exactly the argument
    # order callback.reset expects.
    callback.reset(*loaded)


def _restore_annotations(callback, outpath):
    """Load a previous annotation file into ``callback`` and redraw its markers.

    One scatter marker is placed per stored row. Rows that hold at least one
    non-null value are drawn in black; the others keep the default purple.
    """
    df = pd.read_hdf(outpath)
    if 'onaxis' not in df.columns:
        # Files written before the on/off-axis flag existed: use the same
        # default (-1) that main stores in EMLN['onaxis'].
        df['onaxis'] = -1
    callback.df = df.to_dict(orient='index')
    # x of each marker is the row key (a time in seconds, since it is scaled
    # by FSSR to index the resampled signal); y is the signal amplitude there.
    callback.offset = np.tile(np.array(list(callback.df.keys()))[:, np.newaxis], (1, 2))
    callback.offset[:, 1] = callback.song_resample[(callback.offset[:, 1] * FSSR).astype(int)]
    callback.scat.set_offsets(callback.offset)
    colors = np.array(len(callback.offset) * [[0.7, 0.2, 0.5, 1]])
    for i, p in enumerate(callback.offset[:, 0]):
        if not np.all(pd.isnull(list(callback.df[p].values()))):
            colors[i] = [0, 0, 0, 1]
    callback.scat.set_color(colors)


def main(args):
    """Run one interactive annotation session and save the result.

    Validates the options, builds the figure through ``init``, optionally
    restores a previous annotation file (``resume``), shows the window, and
    writes the annotations to an HDF5 file once ``plt.show()`` returns.

    Parameters
    ----------
    args : argparse.Namespace or compatible object
        Must provide ``input``, ``out``, ``channel``, ``pulse``, ``click``,
        ``after`` (milliseconds), ``low``, ``up``, ``erase`` and ``resume``.

    Returns
    -------
    int
        0 on success, 1 when the options are inconsistent or invalid. The
        value is meant to be used as the process exit code.
    """
    if args.out == '':
        # os.path.splitext only strips an extension from the final path
        # component, so dots in directory names ('../data/song',
        # 'data.v2/song') are preserved.
        outpath = os.path.splitext(args.input)[0] + '.pred.h5'
    else:
        outpath = args.out
    if os.path.isfile(outpath):
        if not (args.erase or args.resume):
            print(f'Out file {outpath} already exist and erase or resume option isn\'t set.')
            return 1
    elif args.resume:
        print(f'Out file {outpath} does not already exist and resume option is set.')
        return 1
    if args.pulse < 1:
        print(f'{args.pulse} is an invalid number of pulses.')
        return 1
    elif args.click < 1:
        print(f'{args.click} is an invalid number of clicks.')
        return 1

    EMLN['onaxis'] = -1
    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after*1e-3)
    callback = ref_dict['callback']
    if args.resume:
        _restore_annotations(callback, outpath)

    onaxis_b_ax = plt.subplot(ref_dict['gridspec'][-1, 8:10])
    onaxis_b = Button(onaxis_b_ax, '?-axis')
    onaxis_b.on_clicked(callback.onaxis)
    callback.onaxis_b = onaxis_b
    plt.draw()
    plt.pause(0.2)
    # NOTE(review): presumably lets constrained layout run once, then freezes
    # it so later redraws don't move the axes; confirm.
    ref_dict['fig'].set_constrained_layout(False)
    plt.show()
    df = pd.DataFrame.from_dict(callback.df, orient='index')
    df.to_hdf(outpath, key='df', format='table')
    return 0


if __name__ == '__main__':

    class ArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that raises ``ValueError`` when required arguments are missing.

        Any other parsing error keeps the standard argparse behaviour. Raising
        lets the script fall back to the interactive prompts below instead
        of exiting immediately.
        """

        def error(self, message):
            if message.startswith('the following arguments are required:'):
                raise ValueError(message)
            super().error(message)

    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("input", type=str, help="Input file")
    parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
    parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
    parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
    parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
    parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.")
    parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
    parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                             " the computation will be halted")
    group.add_argument("--resume", action='store_true', help="If out file exist and this option is given,"
                                                              " the previous annotation file will be loaded")
    try:
        args = parser.parse_args()
    except ValueError as e:
        print(f'Error while parsing the command line arguments: {e}')

        def ask(string):
            """Repeat a yes/no question until an English or French yes/no answer is given."""
            y = {'y', 'yes', 'o', 'oui'}
            a = {'y', 'yes', 'o', 'oui', 'n', 'no', 'non'}
            while True:
                ans = input(string + ' [y/n] ').lower()
                if ans in a:
                    return ans in y

        def ask_number(string, cast):
            """Repeat ``string`` until the answer can be converted with ``cast``.

            A typo no longer aborts the whole interactive session with a
            traceback; the user is simply asked again.
            """
            while True:
                ans = input(string)
                try:
                    return cast(ans)
                except ValueError:
                    print(f'{ans!r} is not a valid value, please try again.')

        if not ask('Do you want to manually specify them?'):
            sys.exit(2)  # exit code of invalid argparse

        class VirtualArgParse(object):
            """Stand-in for the argparse namespace, filled from interactive answers."""

            def __init__(self):
                self.input = input("What is the input file path? ")
                self.out = input("What is the out file path? (Leave empty for default)")
                self.channel = ask_number("Which channel do you want to use starting from 0? ", int)
                self.pulse = ask_number("How many pulses do you want to display (> 1) ", int)
                self.click = ask_number("How many clicks do you want to display (> 1) ", int)
                self.after = ask_number("What duration do you want after P1?", float)
                self.low = ask_number("What is the lower frequency of the bandpass? ", float)
                self.up = ask_number("What is the upper frequency of the bandpass? ", float)
                self.erase = ask("Do you want to erase the out file if it already exist?")
                # erase and resume are mutually exclusive, as on the command line.
                if not self.erase:
                    self.resume = ask("Do you want to resume the out file if it already exist?")
                else:
                    self.resume = False

        args = VirtualArgParse()

    sys.exit(main(args))