##### IMPORTS #####
import os 
import numpy as np
from tkinter import *
from tkinter import filedialog as fd
from tkinter import ttk
from matplotlib.backends.backend_tkagg import (NavigationToolbar2Tk, 
    FigureCanvasTkAgg)
from matplotlib.figure import Figure
import matplotlib.colors as mc
from line_clicker.line_clicker import clicker

# Import external functions
from functions import load_waveform, wave_to_spectrogram, save_dict


##### CLASSES #####
class FileExplorer(object):
    """
    A class that opens a file explorer when it runs in an active Tkinter loop.

    ...

    Parameters
    ----------
    path : str
        Path to a folder, or to a file, from which the file explorer
        will be opened.

    Attributes
    ----------
    path : str
        The path given at construction.
    file : str
        Path to the file selected by the user in the file explorer window
        (empty string if the dialog was cancelled).

    Methods
    -------
    explorer_window():
        Calls the tkinter function that opens a file explorer window.
    """

    def __init__(self, path):
        """
        Stores `path` and immediately opens the file explorer window.

        Parameters
        ----------
        path : str
            Path to a folder, or to a file, from which the file explorer
            will be opened.
        """
        self.path = path        # folder or file to be opened
        self.explorer_window()  # open the dialog right away

    @staticmethod
    def _dialog_options(path):
        """
        Builds the keyword arguments passed to `askopenfilename`.

        `path` is treated as a folder when it is an existing directory or
        has no extension (so dotted directory names are handled). Otherwise
        the dialog opens in its parent folder with the file pre-selected.

        Parameters
        ----------
        path : str
            Folder or file from which the dialog should start.

        Returns
        -------
        dict : keyword arguments for `tkinter.filedialog.askopenfilename`.
        """
        options = {
            'title': 'Open a file',
            'filetypes': (
                ('Audio Files', '*.wav'),
                ('All files', '*.*')),
        }
        if os.path.isdir(path) or os.path.splitext(path)[1] == "":
            options['initialdir'] = path
        else:
            options['initialdir'] = os.path.dirname(path)
            options['initialfile'] = os.path.basename(path)
        return options

    def explorer_window(self):
        """
        Calls the tkinter function that opens a file explorer window,
        starting from `self.path`, and stores the selected path in `file`.

        ...

        Returns
        -------
        None : Sets the 'file' attribute ('' if the dialog was cancelled).
        """
        selected = fd.askopenfilename(**self._dialog_options(self.path))
        # Depending on the Tk build, a cancelled dialog may return an empty
        # tuple instead of ''; normalise so 'file' is always a str.
        self.file = selected or ""

class App(object):
    """
    A class to construct a contour annotation tool for audio data.

    ...

    Parameters
    ----------
    DIR : str
        Path to a folder in which the file explorer will be opened.
    DIR_OUT : str
        Path to a folder where the contours will be saved.
    MAX_C : int
        Maximum number of contours that can be drawn at once.
    NEW_SR : int
        Resampling rate of the audio recording.
    WAVEFILE : str
        Path to the audio recording to be opened. Should be a '.wav' file.
    coords_to_modify : dict, optional
        Coordinates of points (from a previous annotation) that can be used
        as input to add modifications. Defaults to an empty dict.

    Attributes
    ----------
    _default_bounds : list of float
        Boundaries for the matplotlib canvas.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_clipping : int
        Default clipping value for spectrogram, in dB.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_cmap : str
        Name of a matplotlib.pyplot color map.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_height : int
        Default height of the window in which the app will run, in pixels.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_hop_length : int
        Default hop length for spectrogram, in samples.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_left_panel_width : int
        Default width for the left panel of the window of the app.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_nfft : int
        Default fft size for spectrogram, in samples.
        (Default value is loaded from 'PARAMETERS.py' file).
    _default_width: int
        (Default) width of the window in which the app will run, in pixels.
        (Default value is loaded from 'PARAMETERS.py' file).

    audio_duration : float
        Duration of the recording (x-axis extent, in seconds), as returned
        by wave_to_spectrogram.
    canvas : matplotlib object
        Interface to include matplotlib plot in tkinter canvas.
    CHECK_bspline : tkinter int variable
        User checkbox about plotting curves or not.
    CLIP_IN : tkinter float variable (Double)
        User input for clipping value.
    FFT_IN : tkinter int variable
        User input for fft.
    figure, axis, data_showed : matplotlib objects
        Objects used to show matplotlib.pyplot plot.
    HOP_IN : tkinter int variable
        User input for hop length.
    klicker : line_clicker.clicker instance
        Adds widgets to matplotlib plot that allows to draw contours.
    NAME0 : int
        Number used to name the first contour available for annotation.
    NAME1 : int
        Number used to name the last contour available for annotation.
    NFFT, HOP_LENGTH, CLIPPING : numbers
        Spectrogram settings currently displayed.
    OPTIONS : tkinter list variable
        User listbox to select category item.
    root : tkinter Tk instance
        Initialises tkinter interpreter and creates root window.
    spectrogram : numpy array
        Spectrogram of the waveform.
    waveform :  numpy array
        Waveform of the audio recording.

    Other attributes, buttons and labels have self-explanatory names.

    Methods
    -------
    bspline_activation():
        Activates/deactivates the visualisation of lines as curves.
    create_canvas():
        Creates matplotlib figure to show spectrogram in tkinter canvas.
    entry_setup():
        Creates variables that save inputs from the user in the entry fields.
    get_key_pressed(event):
        Updates plot and tkinter interface
        when a key is pressed to add a new category.
    layout():
        Lays the main structure of the tkinter window.
    link_select(event):
        Changes the focus to be on a new category, corresponding to
        the selected item in listbox widget.
    load_audio():
        Loads audio data. Waveform and spectrogram.
    select_file():
        Opens a file explorer window to select a new wavefile.
        Saves contours if a new file is selected.
        Updates the canvas to show the new spectrogram.
    setup():
        Loads default variables to local variables.
    submit():
        Loads user inputs to local variables.
    switch(self)
        Updates spectrogram displayed to PCEN (and conversely).
    _create_klicker(**kwargs):
        Creates the clicker attached to the current axis.
    _frame_listbox_scroll():
        Just a callable part of layout()
    _quit():
        Saves contours and closes the app.
    _refresh_image():
        Redraws data_showed from the current spectrogram.
    _save_contours():
        Saves the current contours to a json file in DIR_OUT.
    _showing_pcen():
        Tells whether the PCEN view is currently displayed.
    _style_listbox_item(idx):
        Colours one listbox item with the colour of its line.
    """

    from parameters import (_default_width, _default_height, _default_hop_length, 
        _default_nfft, _default_clipping, _default_cmap, _default_bounds,
        _default_left_panel_width)

    def __init__(
        self,
        DIR,
        MAX_C,
        NEW_SR,
        DIR_OUT,
        WAVEFILE,
        coords_to_modify=None):
        """
        Loads the audio, builds the interface and runs the Tk main loop
        (this call blocks until the window is closed).

        See the class docstring for the parameters.
        """
        # Fresh dict per instance: a shared `{}` default could carry
        # contours over between App instances if the clicker stores and
        # mutates the dict it is given.
        if coords_to_modify is None:
            coords_to_modify = {}

        # init variables
        self.DIR = DIR
        self.MAX_C = MAX_C
        self.NEW_SR = NEW_SR
        self.DIR_OUT = DIR_OUT
        self.WAVEFILE = WAVEFILE
        self.NAME0 = 0
        self.NAME1 = MAX_C
        self.setup()

        # load audio data
        self.load_audio()

        # init interface
        self.root = Tk()
        self.root.style = ttk.Style()
        self.root.style.theme_use('clam')
        self.create_canvas()

        # addon
        self.klicker = self._create_klicker(coords=coords_to_modify)

        # main loop
        self.entry_setup()
        self.layout()
        self.axis.set_position(self._default_bounds)

        # To avoid problems, disconnect matplotlib keypress
        # and replace it with tkinter keypress.
        self.figure.canvas.mpl_disconnect(self.klicker.key_press)
        self.root.bind('<Key>', self.get_key_pressed)

        self.root.mainloop()

    def bspline_activation(self):
        """
        Activates/deactivates the visualisation of lines as curves.

        ...

        Returns
        -------
        None : Updates klicker.
        (It uses the "wait" parameter to force straight lines).
        """
        if self.CHECK_bspline.get():
            self.klicker.wait = 2
        else:
            self.klicker.wait = np.inf
        self.klicker.update_lines()
        self.klicker.figure.canvas.draw()

    def create_canvas(self):
        """
        Creates a figure based on imported spectrogram.

        ...

        Returns
        -------
        None : Creates figure, axis, data_showed and canvas variables.
        """
        self.figure = Figure(figsize=(16, 9))
        self.axis = self.figure.add_subplot()
        # Rows are flipped so that low frequencies end up at the bottom
        # of the (0, NEW_SR/2) frequency extent.
        self.data_showed = self.axis.imshow(
            self.spectrogram[::-1],
            cmap=self._default_cmap,
            interpolation='nearest', aspect='auto',
            extent=(0, self.audio_duration, 0, self.NEW_SR/2))
        self.data_showed.set_clim(
            vmin=np.nanmin(self.spectrogram),
            vmax=np.nanmax(self.spectrogram))
        self.axis.set_xlabel("Time (in sec)")
        self.axis.set_ylabel("Frequencies (in Hz)")
        self.axis.set_title(f"Spectrogram of {os.path.basename(self.WAVEFILE)}")
        self.axis.set_position(self._default_bounds)
        self.figure.set_facecolor("gainsboro")
        self.canvas = FigureCanvasTkAgg(self.figure, master=self.root)
        self.canvas.draw()

    def entry_setup(self):
        """
        Creates tkinter variables that will be used in entry fields.
        (Objects that are specific to tkinter)

        ...

        Returns
        -------
        None : Creates FFT_IN, HOP_IN, CHECK_bspline, CLIP_IN and OPTIONS,
        variables that are tkinter variables (3 integers, 1 float, 1 list).
        """
        self.FFT_IN = IntVar(value=self._default_nfft)
        self.HOP_IN = IntVar(value=self._default_hop_length)
        self.CHECK_bspline = IntVar(value=1)
        self.CLIP_IN = DoubleVar(value=self._default_clipping)
        self.OPTIONS = Variable(value=self.klicker.legend_labels)

    def get_key_pressed(self, event):
        """
        Updates plot and tkinter interface
        when a key is pressed to add a new category.

        ...

        Parameters
        ----------
        event : tkinter event
            tkinter object containing the name of the key pressed.

        Returns
        -------
        None : updates klicker, listbox, axis and figure.
        """
        from types import SimpleNamespace

        # klicker.get_key_event expects a matplotlib-like key event; only
        # its `key` attribute is provided here.
        self.klicker.get_key_event(SimpleNamespace(key=event.char), show=False)

        # if a category is added, mirror it in the listbox and select it.
        if len(self.klicker.legend_labels) > self.listbox.size():
            new_idx = self.listbox.size()
            self.listbox.insert(new_idx, self.klicker.legend_labels[-1])
            self._style_listbox_item(new_idx)
            self.listbox.select_clear(0, END)
            self.listbox.select_set(new_idx)
            self.listbox.see(new_idx)
            self.listbox.activate(new_idx)
            self.listbox.selection_anchor(new_idx)
        self.axis.set_position(self._default_bounds)
        self.figure.canvas.draw()

    def layout(self):
        """
        This *long* function lays the structure of the tkinter interface

        ...

        Returns
        -------
        None :
            Updates root
            Creates list_label, frame_list, listbox, scrollbar, empty_frame,
            activate_bspline, fft_label, fft_entry, win_label, win_entry,
            clip_label, clip_entry, submit_button, switch_view_button,
            quit_button, explore_button, toolbarFrame, toolbar, loading_screen
        """
        # configure main window
        self.root.wm_title("PyAVA interface")
        self.root.geometry(
            f"{str(self._default_width)}x{str(self._default_height)}")
        self.root.rowconfigure(1, weight=1)
        self.root.rowconfigure(14, weight=1)
        self.root.configure(bg='gainsboro')

        # Add Panel for line selection on Left side
        self.list_label = Label(
            self.root,
            width=self._default_left_panel_width,
            text='Pick a line to draw.\n(Shift+a adds a new line)',
            font=('calibre',10,'bold'))
        self.list_label.grid(row=2, column=0)

        self._frame_listbox_scroll()

        self.activate_bspline = Checkbutton(
            self.root,
            text='Activate interpolation',
            variable=self.CHECK_bspline,
            command=self.bspline_activation)
        self.activate_bspline.grid(row=4, column=0)

        # Add space between panels
        self.empty_frame = Label(
            self.root,
            width=self._default_left_panel_width,
            height=30)
        self.empty_frame.grid(row=5, column=0)

        # Add panel for spectrogram personalisation on Left side.
        self.fft_label = Label(
            self.root,
            width=self._default_left_panel_width,
            text='FFT window size:',
            font=('calibre',10, 'bold'))
        self.fft_label.grid(row=6, column=0)

        self.fft_entry = Entry(
            self.root,
            width=self._default_left_panel_width,
            textvariable=self.FFT_IN,
            font=('calibre',10,'normal'))
        self.fft_entry.grid(row=7, column=0)

        self.win_label = Label(
            self.root,
            width=self._default_left_panel_width,
            text='Hop length:',
            font=('calibre',10,'bold'))
        self.win_label.grid(row=8, column=0)

        self.win_entry = Entry(
            self.root,
            width=self._default_left_panel_width,
            textvariable=self.HOP_IN,
            font=('calibre',10,'normal'))
        self.win_entry.grid(row=9, column=0)

        self.clip_label = Label(
            self.root,
            width=self._default_left_panel_width,
            text='Clipping (dB):',
            font=('calibre',10,'bold'))
        self.clip_label.grid(row=10, column=0)

        self.clip_entry = Entry(
            self.root,
            width=self._default_left_panel_width,
            textvariable=self.CLIP_IN,
            font=('calibre',10,'normal'))
        self.clip_entry.grid(row=11, column=0)

        self.submit_button = Button(
            self.root,
            text='Update display',
            width=self._default_left_panel_width,
            command=self.submit)
        self.submit_button.grid(row=12, column=0)

        self.switch_view_button = Button(
            self.root,
            text="Switch to PCEN",
            width=self._default_left_panel_width,
            command=self.switch)
        self.switch_view_button.grid(row=13, column=0)

        # Add buttons at the bottom of the interface
        self.quit_button = Button(
            self.root,
            text="Save & Quit",
            command=self._quit)
        self.quit_button.grid(row=15, column=0)

        self.explore_button = Button(
            self.root,
            text="Open file explorer",
            command=self.select_file)
        self.explore_button.grid(row=15, column=1)

        # Add matplotlib tools at the top of the interface
        self.toolbarFrame = Frame(self.root)
        self.toolbar = NavigationToolbar2Tk(self.canvas, self.toolbarFrame)
        self.toolbar.update()
        self.toolbarFrame.grid(row=0, column=1, sticky='W')

        # Add main panel : canvas.
        self.canvas.get_tk_widget().grid(row=1, column=1, rowspan=14)
        # Created here but only gridded while a new file is loading.
        self.loading_screen = Label(
            self.root,
            text="LOADING SPECTROGRAM... \nThis can take a few seconds.",
            font=("gothic", 30),
            justify=LEFT)

    def link_select(self, event):
        """
        Changes the focus to be on a new category, corresponding to
        the selected item in listbox widget.

        ...

        Parameters
        ----------
        event : tkinter object
            event containing the item clicked in listbox widget.

        Returns
        -------
        None : Updates klicker.
        """
        selection = event.widget.curselection()
        if selection:
            self.klicker.current_line = selection[0]

            # Manually update display: dim every legend entry but the current one.
            legend_lines = self.klicker.legend.get_lines()
            for legend_line in legend_lines:
                legend_line.set_alpha(0.2)
            legend_lines[self.klicker.current_line].set_alpha(1)
            self.klicker.figure.canvas.draw()

    def load_audio(self):
        """
        Loads the waveform of WAVEFILE (resampled to NEW_SR) and computes
        its spectrogram with the current NFFT, HOP_LENGTH and CLIPPING.

        ...

        Returns
        -------
        None : Creates waveform, spectrogram and audio_duration.
        """
        self.waveform = load_waveform(self.WAVEFILE, self.NEW_SR)
        self.spectrogram, self.audio_duration = wave_to_spectrogram(
            self.waveform,
            self.NEW_SR,
            self.NFFT,
            self.HOP_LENGTH,
            self.CLIPPING)

    def select_file(self):
        """
        A function that calls a new window to select a wavefile.
        Then rebuilds the interface using the newly selected file.

        ...

        Returns
        -------
        None : If a new file is selected, saves current coordinates to json file
        and rebuilds the annotation interface for the new file.
        """
        new_wavefile = FileExplorer(self.WAVEFILE).file
        # Cancelled dialog: nothing to do.
        if not new_wavefile:
            return

        # save current coords before switching files
        self._save_contours()
        self.WAVEFILE = new_wavefile

        # Tear down the previous interface (layout() below recreates every
        # widget, so keeping the old ones would only stack copies in the
        # same grid cells), then show the loading label.
        loading_screen = self.loading_screen
        for widget in self.root.winfo_children():
            if widget is not loading_screen:
                widget.destroy()
        loading_screen.grid(row=1, column=1, rowspan=14)
        # Tk only repaints when idle; force it now, otherwise the label is
        # never visible during the (slow) loading below.
        self.root.update_idletasks()

        # load new data
        self.setup()
        self.load_audio()
        self.create_canvas()
        self.NAME0 = 0
        self.NAME1 = self.MAX_C

        # display new data
        self.klicker = self._create_klicker()
        self.axis.set_position(self._default_bounds)
        self.figure.canvas.mpl_disconnect(self.klicker.key_press)

        # update interface. layout() creates a *new* loading label, so the
        # one shown above must be destroyed explicitly.
        self.entry_setup()
        self.layout()
        loading_screen.destroy()

    def setup(self):
        """
        A function to create variables based on default values
        (NFFT, HOP_LENGTH and CLIPPING).
        """
        self.NFFT = self._default_nfft
        self.HOP_LENGTH = self._default_hop_length
        self.CLIPPING = self._default_clipping

    def submit(self):
        """
        A function that fetches the new values in entry fields.
        Updates spectrogram accordingly.

        ---

        Returns
        -------
        None : Updates spectrogram and data_showed
        according to new fft, hop_length and clipping values.
        Does nothing if an entry does not hold a valid number.
        """
        try:
            nfft = self.FFT_IN.get()
            hop_length = self.HOP_IN.get()
            clipping = self.CLIP_IN.get()
        except (TclError, ValueError):
            # Non-numeric text in an entry: keep the current display.
            return

        if (nfft, hop_length, clipping) != (self.NFFT, self.HOP_LENGTH, self.CLIPPING):
            self.NFFT = nfft
            self.HOP_LENGTH = hop_length
            self.CLIPPING = clipping

            self.spectrogram, self.audio_duration = wave_to_spectrogram(
                self.waveform,
                self.NEW_SR,
                self.NFFT,
                self.HOP_LENGTH,
                self.CLIPPING,
                as_pcen=self._showing_pcen())
            self._refresh_image()

    def switch(self):
        """
        Updates spectrogram displayed to PCEN (and conversely).

        ---

        Returns
        -------
        None : Updates switch_view_button, spectrogram and data_showed.
        """
        to_pcen = not self._showing_pcen()
        self.spectrogram, _ = wave_to_spectrogram(
            self.waveform,
            self.NEW_SR,
            self.NFFT,
            self.HOP_LENGTH,
            self.CLIPPING,
            as_pcen=to_pcen)
        # The button advertises the view it will switch *to* next.
        self.switch_view_button['text'] = (
            "Switch to Spectrogram" if to_pcen else "Switch to PCEN")
        self._refresh_image()

    def _create_klicker(self, **kwargs):
        """
        Creates the clicker attached to the current axis, with one line
        name per index in range(NAME0, NAME1).

        Parameters
        ----------
        **kwargs
            Extra keyword arguments forwarded to `clicker`
            (e.g. `coords` to preload a previous annotation).

        Returns
        -------
        clicker instance
        """
        return clicker(
            axis=self.axis,
            names=["Line" + str(i+1) for i in range(self.NAME0, self.NAME1)],
            bspline='quadratic', maxlines=99, legend_bbox=(2,0.5),
            **kwargs)

    def _frame_listbox_scroll(self):
        """
        Just a callable part of "layout": builds the listbox of lines
        and its scrollbar inside frame_list.
        """
        self.frame_list = Frame(
            self.root,
            width=self._default_left_panel_width,
            height=50)
        self.frame_list.grid(row=3, column=0)

        self.listbox = Listbox(
            self.frame_list,
            height=9,
            width=self._default_left_panel_width,
            selectmode=SINGLE,
            listvariable=self.OPTIONS)
        for idx in range(len(self.klicker.legend_labels)):
            self._style_listbox_item(idx)
        self.listbox.pack(side="left", fill="y")
        self.listbox.bind("<<ListboxSelect>>", self.link_select)
        self.listbox.select_set(0)

        self.scrollbar = Scrollbar(self.frame_list, orient="vertical")
        self.scrollbar.config(command=self.listbox.yview)
        self.scrollbar.pack(side="right", fill="y")
        self.listbox.config(yscrollcommand=self.scrollbar.set)

    def _quit(self):
        """
        A function that saves coordinates of lines before closing app.

        ...

        Returns
        -------
        None : Saves coords in a json file, quits and destroys root.
        """
        self._save_contours()

        self.root.quit()
        self.root.destroy()

    def _refresh_image(self):
        """
        Pushes the current spectrogram into data_showed, rescales its
        colour limits to the data range (ignoring NaNs) and redraws.
        """
        self.data_showed.set_data(self.spectrogram[::-1])
        self.data_showed.set_clim(
            vmin=np.nanmin(self.spectrogram),
            vmax=np.nanmax(self.spectrogram))
        self.canvas.draw()

    def _save_contours(self):
        """
        Saves klicker coordinates in DIR_OUT as
        '<wavefile name without extension>-contours.json'.
        """
        stem = os.path.splitext(os.path.basename(self.WAVEFILE))[0]
        save_dict(self.klicker.coords, self.DIR_OUT, stem + "-contours.json")

    def _showing_pcen(self):
        """
        Returns True when the PCEN view is displayed (the switch button
        then offers to go back to the plain spectrogram).
        """
        return self.switch_view_button['text'] != "Switch to PCEN"

    def _style_listbox_item(self, idx):
        """
        Colours listbox item `idx` with its line colour (cycling through
        klicker.colors), with a selection background made lighter by adding
        0.1 to each RGB channel (capped at 1).
        """
        color = self.klicker.colors[idx % len(self.klicker.colors)]
        highlight = mc.to_hex(tuple(min(0.1 + x, 1) for x in mc.to_rgb(color)))
        self.listbox.itemconfig(
            idx,
            {
            'bg': color,
            'selectbackground': highlight,
            'selectforeground': 'white'})