From f95db74adc5b94a72c21ede81abb4e6d10eddcca Mon Sep 17 00:00:00 2001
From: maxence <maxence.ferrari@gmail.com>
Date: Mon, 28 Nov 2022 11:52:59 +0100
Subject: [PATCH] Reformat code

---
 ipi_extract.py | 496 +++++++++++++++++++++++++------------------------
 1 file changed, 257 insertions(+), 239 deletions(-)

diff --git a/ipi_extract.py b/ipi_extract.py
index 6e0d5a0..41aaa4b 100644
--- a/ipi_extract.py
+++ b/ipi_extract.py
@@ -10,12 +10,13 @@ from fractions import Fraction
 from pydub import AudioSegment
 from pydub.playback import play
 import pandas as pd
+from typing import List, Union
 
 FSSR = 48_000  # Sampling rate of full signal plot
 FSPK = 0.1     # Max distance to detect a click in full sig in seconds
-IPIPK= 0.15    # Max distance to detect a IPI in milliseconds
+IPIPK = 0.15   # Max distance to detect a IPI in milliseconds
 SPSC = 80      # Spectrogram scale
-EMLN = {'p1_pos':np.nan, 'ipi_sig': np.nan,
+EMLN = {'p1_pos': np.nan, 'ipi_sig': np.nan,
         'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan,
         'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan,
         'ind_number': np.nan}  # Empty dataline
@@ -138,10 +139,8 @@ class MyMultiCursor(AxesWidget):
 
         if self.useblit:
             lineprops['animated'] = True
-        #self.lineh = []
         self.linev = []
         for i in range(self.n_axes):
-            #self.lineh.append(axes[i].axhline(axes[i].get_ybound()[0], visible=False, **lineprops))
             self.linev.append(axes[i].axvline(axes[i].get_xbound()[0], visible=False, **lineprops))
 
         self.background = self.n_axes*[None]
@@ -155,7 +154,6 @@ class MyMultiCursor(AxesWidget):
             if self.useblit:
                 self.background[i] = self.canvas.copy_from_bbox(self.axes[i].bbox)
             self.linev[i].set_visible(False)
-            #self.lineh[i].set_visible(False)
 
     def onmove(self, event):
         """Internal event handler to draw the cursor when the mouse moves."""
@@ -166,8 +164,6 @@ class MyMultiCursor(AxesWidget):
         if event.inaxes not in self.axes:
             for i in range(self.n_axes):
                 self.linev[i].set_visible(False)
-                #self.lineh[i].set_visible(False)
-
             if self.needclear:
                 self.canvas.draw()
                 self.needclear = False
@@ -197,9 +193,7 @@ class MyMultiCursor(AxesWidget):
                         self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
                     else:
                         self.linev[i].set_xdata((ipi_pos, ipi_pos))
-                    # self.lineh[i].set_ydata((event.ydata, event.ydata))
                     self.linev[i].set_visible(self.visible and self.vertOn)
-                    # self.lineh[i].set_visible(self.visible and self.horizOn)
                 self._update()
                 return
         else:
@@ -219,10 +213,7 @@ class MyMultiCursor(AxesWidget):
                 self.linev[i].set_visible(self.visible and self.vertOn)
             if i > 1:
                 self.linev[i].set_xdata((ipi_pos, ipi_pos))
-                #self.lineh[i].set_ydata((event.ydata, event.ydata))
                 self.linev[i].set_visible(self.visible and self.vertOn)
-                #self.lineh[i].set_visible(self.visible and self.horizOn)
-
         self._update()
 
         if restore:
@@ -235,7 +226,6 @@ class MyMultiCursor(AxesWidget):
                 if self.background[i] is not None:
                     self.canvas.restore_region(self.background[i])
                 self.axes[i].draw_artist(self.linev[i])
-                #self.axes[i].draw_artist(self.lineh[i])
                 self.canvas.blit(self.axes[i].bbox)
         else:
             self.canvas.draw_idle()
@@ -251,12 +241,12 @@ class Callback(object):
         self.song_resample = song_resample
         self.sr = sr
         self.fax = full_ax
-        self.view_ax = None
-        self.view_data = None
+        self.num_view = 3
+        self.view_data: [AxesGroup, None] = None
         self.curr = 0  # current view selected
-        self.scat = self.fax.scatter([],[], marker='x', c=[[0.7, 0.2, 0.5, 1]])
-        self.offset = np.zeros((0,2))
-        self.curr_ind = 3*[None]  # Ind of click for each plot
+        self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
+        self.offset = np.zeros((0, 2))
+        self.curr_ind: List[Union[int, None]] = 3*[None]  # Indices of click for each plot
         self.curr_vert = 3*[0]    # Current vertical line of sig/spec for each  plot
         self.cursor = None
         self.f_cursor = None
@@ -284,130 +274,132 @@ class Callback(object):
         self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
         plt.draw()
 
-
-    def on_clicked(self, event):
-        if event.inaxes == self.fax:  # Click on full signal plot
-            pos = int(FSSR*event.xdata)
-            mpos = np.argmax(self.song_resample[max(pos-int(FSSR*FSPK),0):pos+int(FSSR*FSPK)]) + max(pos-int(FSSR*FSPK),0)
-            if self.curr_ind[self.curr] is not None:
-                if np.all(pd.isnull(list(self.df[self.offset[self.curr_ind[self.curr], 0]].values()))):
-                    self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1]
-                else:
-                    self.scat._facecolors[self.curr_ind[self.curr]] = [0, 0, 0, 1]
-            if mpos/FSSR not in self.offset[:,0]:
-                self.offset = np.concatenate([self.offset, [[mpos/FSSR, self.song_resample[mpos]]]], axis=0)
-                self.scat.set_offsets(self.offset)
-                self.df[mpos/FSSR] = EMLN.copy()
-                c = [[0, 0, 0, 1]]
-                c[0][self.curr] = 1
-                if len(self.offset) == 1:
-                    self.scat.set_color(c)
-                else:
-                    self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
-                self.curr_ind[self.curr] = len(self.offset) - 1
-                self.curr_vert[self.curr] = 0
-                self.cursor[self.curr].linev[0].set_linestyle('--')
-                self.cursor[self.curr].linev[1].set_linestyle('--')
-                for i in range(4):
-                    if i < 2:
-                        self.view_data[self.curr][i][1][0].set_visible(False)
-                        self.view_data[self.curr][i][1][1].set_visible(False)
-                    else:
-                        self.view_data[self.curr][i][1].set_visible(False)
+    def _on_clicked_fax(self, event):
+        # Click on full signal plot
+        pos = int(FSSR * event.xdata)
+        mpos = np.argmax(self.song_resample[max(pos - int(FSSR * FSPK), 0):
+                                            pos + int(FSSR * FSPK)]) + max(pos - int(FSSR * FSPK), 0)
+        ax_group = self.view_data[self.curr]
+        if self.curr_ind[self.curr] is not None:
+            if np.all(pd.isnull(list(self.df[self.offset[self.curr_ind[self.curr], 0]].values()))):
+                self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1]
             else:
-                self.curr_ind[self.curr] = np.argmax(mpos/FSSR == self.offset[:,0])
-                c = [0, 0, 0, 1]
-                k = 0
-                for i, v in enumerate(self.curr_ind):
-                    if v == self.curr_ind[self.curr]:
-                        c[i] = 1
-                        k += 1
-                for i in range(3):
-                    c[i] /=k
-                self.scat._facecolors[self.curr_ind[self.curr]] = c
-                self.scat.set_color(self.scat._facecolors)
-                row = self.df[mpos/FSSR]
-                if np.isnan(row['p1_pos']):
-                    self.view_data[self.curr][0][1][0].set_visible(False)
-                    self.view_data[self.curr][0][1][1].set_visible(False)
-                    self.view_data[self.curr][1][1][0].set_visible(False)
-                    self.view_data[self.curr][1][1][1].set_visible(False)
-                else:
-                    self.view_data[self.curr][0][1][0].set_xdata((row['p1_pos'], row['p1_pos']))
-                    self.view_data[self.curr][1][1][0].set_xdata((row['p1_pos']*1e-3, row['p1_pos']*1e-3))
-                    self.view_data[self.curr][0][1][0].set_visible(True)
-                    self.view_data[self.curr][1][1][0].set_visible(True)
-                    if np.isnan(row['ipi_sig']):
-                        self.view_data[self.curr][0][1][1].set_visible(False)
-                        self.view_data[self.curr][1][1][1].set_visible(False)
-                    else:
-                        p2 = row['p1_pos'] + row['ipi_sig']
-                        self.view_data[self.curr][0][1][1].set_xdata((p2, p2))
-                        self.view_data[self.curr][1][1][1].set_xdata((p2*1e-3, p2*1e-3))
-                        self.view_data[self.curr][0][1][1].set_visible(True)
-                        self.view_data[self.curr][1][1][1].set_visible(True)
-                if np.isnan(row['ipi_corr_man']):
-                    self.view_data[self.curr][2][1].set_visible(False)
+                self.scat._facecolors[self.curr_ind[self.curr]] = [0, 0, 0, 1]
+        if mpos / FSSR not in self.offset[:, 0]:
+            self.offset = np.concatenate([self.offset, [[mpos / FSSR, self.song_resample[mpos]]]], axis=0)
+            self.scat.set_offsets(self.offset)
+            self.df[mpos / FSSR] = EMLN.copy()
+            c = [[0, 0, 0, 1]]
+            c[0][self.curr] = 1
+            if len(self.offset) == 1:
+                self.scat.set_color(c)
+            else:
+                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
+            self.curr_ind[self.curr] = len(self.offset) - 1
+            self.curr_vert[self.curr] = 0
+            self.cursor[self.curr].linev[0].set_linestyle('--')
+            self.cursor[self.curr].linev[1].set_linestyle('--')
+            for i in range(4):
+                if i < 2:
+                    ax_group[i][1][0].set_visible(False)
+                    ax_group[i][1][1].set_visible(False)
                 else:
-                    self.view_data[self.curr][2][1].set_xdata((row['ipi_corr_man'], row['ipi_corr_man']))
-                    self.view_data[self.curr][2][1].set_visible(True)
-                if np.isnan(row['ipi_ceps_man']):
-                    self.view_data[self.curr][3][1].set_visible(False)
+                    ax_group[i][1].set_visible(False)
+        else:
+            self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0])
+            c = [0, 0, 0, 1]
+            k = 0
+            for i, v in enumerate(self.curr_ind):
+                if v == self.curr_ind[self.curr]:
+                    c[i] = 1
+                    k += 1
+            for i in range(self.num_view):
+                c[i] /= k
+            self.scat._facecolors[self.curr_ind[self.curr]] = c
+            self.scat.set_color(self.scat._facecolors)
+            row = self.df[mpos / FSSR]
+            if np.isnan(row['p1_pos']):
+                ax_group.signal.set_visible(False)
+                ax_group.spectrogram.set_visible(False)
+            else:
+                ax_group.signal.cursors[0].set_xdata((row['p1_pos'], row['p1_pos']))
+                ax_group.spectrogram.cursors[0].set_xdata((row['p1_pos'] * 1e-3, row['p1_pos'] * 1e-3))
+                ax_group.signal.cursors[0].set_visible(True)
+                ax_group.spectrogram.cursors[0].set_visible(True)
+                if np.isnan(row['ipi_sig']):
+                    ax_group.signal.cursors[1].set_visible(False)
+                    ax_group.spectrogram.cursors[1].set_visible(False)
                 else:
-                    self.view_data[self.curr][3][1].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man']))
-                    self.view_data[self.curr][3][1].set_visible(True)
-            self.ind_select = False
-            self.ind_b.color = '0.85'
-            self.ind_b.howercolor = '0.95'
-            self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos/FSSR]["ind_number"]}')
-            click = self.song[max(int(mpos*self.sr/FSSR-10e-3*self.sr),0):int(mpos*self.sr/FSSR+10e-3*self.sr)]
-            if len(click) != 2*int(10e-3*self.sr):
-                np.pad(click, (0, 2*int(10e-3*self.sr) - len(click)), mode='constant')
-            self.view_data[self.curr][0][0].set_ydata(norm(click))
-            spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1)[0]))
-            self.view_data[self.curr][1][0].set_data(spec)
-            self.view_data[self.curr][1][0].set_clim(spec.max()-SPSC, spec.max())
-            self.view_data[self.curr][2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3*self.sr):]))
-            self.view_data[self.curr][3][0].set_ydata(norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3*self.sr)])))
-            self._set_label()
-            plt.draw()
-            return
-        for i in range(3):  # Look if a click plot was clicked and which one
-            for j in range(4):
-                if event.inaxes == self.view_ax[i][j]:
-                    break
+                    p2 = row['p1_pos'] + row['ipi_sig']
+                    ax_group.signal.cursors[1].set_xdata((p2, p2))
+                    ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3))
+                    ax_group.signal.cursors[1].set_visible(True)
+                    ax_group.spectrogram.cursors[1].set_visible(True)
+            if np.isnan(row['ipi_corr_man']):
+                ax_group.correlation.set_visible(False)
             else:
-                continue
-            break
-        else:
-            return
-        if j < 2:
-            pos = event.xdata * 10**(3*j)
-            self.view_data[i][0][1][self.curr_vert[i]].set_xdata((pos, pos))
-            self.view_data[i][1][1][self.curr_vert[i]].set_xdata((pos*10**-3, pos*10**-3))
-            self.view_data[i][0][1][self.curr_vert[i]].set_visible(True)
-            self.view_data[i][1][1][self.curr_vert[i]].set_visible(True)
-            if self.curr_vert[i] == 0:
-                self.df[self.offset[self.curr_ind[i],0]]['p1_pos'] = pos
-            self.curr_vert[i] ^= 1
-            self.cursor[i].linev[0].set_linestyle(['--','-'][self.curr_vert[i]])
-            self.cursor[i].linev[1].set_linestyle(['--','-'][self.curr_vert[i]])
-            if self.view_data[i][0][1][1].get_visible():
-                ipi_man = self.view_data[i][0][1][1].get_xdata()[0] - self.view_data[i][0][1][0].get_xdata()[0]
-                self.df[self.offset[self.curr_ind[i], 0]]['ipi_sig'] = ipi_man
-                self.view_ax[i][0].set_xlabel(f'Sig man:{ipi_man:.5f}')
+                ax_group.correlation.cursors[0].set_xdata((row['ipi_corr_man'], row['ipi_corr_man']))
+                ax_group.correlation.set_visible(True)
+            if np.isnan(row['ipi_ceps_man']):
+                ax_group.cepstrum.set_visible(False)
+            else:
+                ax_group.cepstrum.cursors[0].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man']))
+                ax_group.cepstrum.set_visible(True)
+        self.ind_select = False
+        self.ind_b.color = '0.85'
+        self.ind_b.howercolor = '0.95'
+        self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos / FSSR]["ind_number"]}')
+        click = self.song[
+                max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)]
+        if len(click) != 2 * int(10e-3 * self.sr):
+            np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant')
+        ax_group[0][0].set_ydata(norm(click))
+        self._update_spectrogram(ax_group)
+        ax_group[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
+        ax_group[3][0].set_ydata(
+            norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)])))
+        self._set_label()
+        plt.draw()
+
+    def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
+        if cell_num < 2:
+            pos = event.xdata if cell_num else event.xdata * 1000
+            cur_type = self.curr_vert[group_num]
+            ax_group.signal.cursors[cur_type].set_xdata((pos, pos))
+            ax_group.spectrogram.cursors[cur_type].set_xdata((pos*10**-3, pos*10**-3))
+            ax_group.signal.cursor[cur_type].set_visible(True)
+            ax_group.spectrogram.cursors[cur_type].set_visible(True)
+            if cur_type == 0:
+                self.df[self.offset[self.curr_ind[group_num], 0]]['p1_pos'] = pos
+            self.curr_vert[group_num] ^= 1
+            self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type])
+            self.cursor[group_num].linev[1].set_linestyle(['--', '-'][cur_type])
+            if ax_group.signal.cursors[1].get_visible():
+                ipi_man = ax_group.signal_ipi
+                self.df[self.offset[self.curr_ind[group_num], 0]]['ipi_sig'] = ipi_man
+                ax_group[0].set_xlabel(f'Sig man:{ipi_man:.5f}')
         else:
-            self.view_data[i][j][1].set_xdata((event.xdata, event.xdata))
-            self.view_data[i][j][1].set_visible(True)
-            ipi_auto = np.argmax(self.view_data[i][j][0].get_data()[1][
-                       max(int(self.sr/1e3*(event.xdata-IPIPK)),0):int(self.sr/1e3*(event.xdata+IPIPK))])\
-                       + max(int(self.sr/1e3*(event.xdata-IPIPK)),0)
-            col = 'ipi_' + ('corr' if j == 2 else 'ceps')
-            self.df[self.offset[self.curr_ind[i],0]][col + '_man'] = event.xdata
-            self.df[self.offset[self.curr_ind[i],0]][col + '_auto'] = ipi_auto*1e3/self.sr
-            self.view_ax[i][j].set_xlabel(f'{"Corr" if j == 2 else "Ceps"} man:{event.xdata:.3f} auto:{ipi_auto*1e3/self.sr:.3f}')
+            cell = ax_group[cell_num]
+            cell.cursor.set_xdata((event.xdata, event.xdata))
+            cell.cursor.set_visible(True)
+            lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0)
+            ipi_auto = np.argmax(cell.axes.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min
+            col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
+            self.df[self.offset[self.curr_ind[group_num], 0]][col + '_man'] = event.xdata
+            self.df[self.offset[self.curr_ind[group_num], 0]][col + '_auto'] = ipi_auto*1e3/self.sr
+            cell.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} '
+                            f'auto:{ipi_auto*1e3/self.sr:.3f}')
         plt.draw()
 
+    def on_clicked(self, event):
+        if event.inaxes == self.fax:
+            self._on_clicked_fax(event)
+        for i, ax_group in enumerate(self.view_data):  # Look if a click plot was clicked and which one
+            for j in range(len(ax_group)):
+                if event.inaxes == ax_group[j].axes:
+                    self._on_clicked_ax_group(event, ax_group, i, j)
+                    return
+
     def change_curr(self, label):
         self.curr = int(label[-1])
         self.f_cursor.linev.set_color('rgb'[self.curr])
@@ -428,7 +420,7 @@ class Callback(object):
 
     def key_pressed(self, event):
         if self.ind_select:
-            row = self.df[self.offset[self.curr_ind[self.curr],0]]
+            row = self.df[self.offset[self.curr_ind[self.curr], 0]]
             if event.key == 'backspace':
                 row['ind_number'] = row['ind_number'][:-1]
             elif event.key in ['shift', 'control', 'alt']:
@@ -441,8 +433,6 @@ class Callback(object):
             if event.key in '012':
                 self.change_curr(event.key)
                 self.r_button.set_active(int(event.key))
-                pass
-
 
     def play(self, event):
         sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
@@ -458,34 +448,40 @@ class Callback(object):
         self.fax.get_figure().set_constrained_layout(False)
 
     def onaxis(self, event):
-        if self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] == -1:
-            self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] = 1
+        row = self.df[self.offset[self.curr_ind[self.curr], 0]]
+        if row['onaxis'] == -1:
+            row['onaxis'] = 1
         else:
-            self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] ^= 1
+            row['onaxis'] ^= 1
         self.onaxis_b.label.set_text('On-axis'
-                                     if self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] else 'Off-axis')
+                                     if row['onaxis'] else 'Off-axis')
         plt.draw()
 
     def increase_freq(self, event):
-        self.view_ax[self.curr][1].set_ylim(0, self.view_ax[self.curr][1].get_ylim()[1]+1e3)
+        lim = self.view_data[0].specgram.get_ylim()[1] + 1e3
+        for ax_group in self.view_data:
+            ax_group.specgram.set_ylim(0, lim)
         plt.draw()
 
     def decrease_freq(self, event):
-        self.view_ax[self.curr][1].set_ylim(0, max(self.view_ax[self.curr][1].get_ylim()[1]-1e3,1e3))
+        lim = max(self.view_data[0].specgram.get_ylim()[1] + 1e3, 1e3)
+        for ax_group in self.view_data:
+            ax_group.specgram.set_ylim(0, lim)
         plt.draw()
 
-    def _update_spectrogram(self, num):
-        click = self.view_data[num][0][0].get_ydata()
-        spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1)[0]))
-        self.view_data[num][1][0].set_data(spec)
-        self.view_data[num][1][0].set_clim(spec.max()-SPSC, spec.max())
+    def _update_spectrogram(self, ax_group):
+        click = ax_group.signal.axes.get_ydata()
+        spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1,
+                                                       pad_to=2*self.nfft)[0]))
+        ax_group.spectrogram.im.set_data(spec)
+        ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max())
 
     def increase_res(self, event):
         if self.nfft > int(10e-3*self.sr):
             return
         self.nfft *= 2
-        for i in range(3):
-            self._update_spectrogram(i)
+        for ax_group in self.view_data:
+            self._update_spectrogram(ax_group)
         if self.nfft > int(10e-3*self.sr):
             self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
         else:
@@ -494,13 +490,13 @@ class Callback(object):
         plt.draw()
 
     def decrease_res(self, event):
-        if self.nfft <8:
+        if self.nfft < 8:
             return
         self.nfft //= 2
-        for i in range(3):
-            self._update_spectrogram(i)
+        for ax_group in self.view_data:
+            self._update_spectrogram(ax_group)
         self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
-        if self.nfft <8:
+        if self.nfft < 8:
             self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
         else:
             self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
@@ -511,9 +507,10 @@ class Callback(object):
             ind = self.curr
         if dic is None:
             dic = self.df[self.offset[self.curr_ind[ind], 0]]
-        self.view_ax[ind][0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
-        self.view_ax[ind][2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
-        self.view_ax[ind][3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
+        ax_group = self.view_data[ind]
+        ax_group[0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
+        ax_group[2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
+        ax_group[3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
         self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
         if self.onaxis_b is not None:
             self.onaxis_b.label.set_text(['Off', 'On', '?'][dic['onaxis']] + '-axis')
@@ -521,12 +518,7 @@ class Callback(object):
     def _set_visible(self, ind=None, state=False):
         if ind is None:
             ind = self.curr
-        self.view_data[ind][0][1][0].set_visible(state)
-        self.view_data[ind][0][1][1].set_visible(state)
-        self.view_data[ind][1][1][0].set_visible(state)
-        self.view_data[ind][1][1][1].set_visible(state)
-        self.view_data[ind][2][1].set_visible(state)
-        self.view_data[ind][3][1].set_visible(state)
+        self.view_data[ind].set_visible(state)
 
     def reset_curr(self, event):
         self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy()
@@ -549,17 +541,17 @@ class Callback(object):
         self.offset = np.zeros((0, 2))
         self.scat.set_offsets(self.offset)
         self.scat.set_color([[0, 0, 0, 1]])
-        self.curr_ind = 3 * [None]  # Ind of click for each plot
-        self.curr_vert = 3 * [0]  # Current vertical line of sig/spec for each  plot
-        for i in range(3):
-            self.view_data[i][0][0].set_ydata(np.zeros(int(20e-3 * sr)))
-            self.view_data[i][2][0].set_ydata(np.zeros(int(10e-3 * sr)))
-            self.view_data[i][3][0].set_ydata(np.zeros(int(10e-3 * sr)))
+        self.curr_ind = self.num_view * [None]  # Ind of click for each plot
+        self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each  plot
+        for ax_group in self.view_data:
+            ax_group[0][0].set_ydata(np.zeros(int(20e-3 * sr)))
+            ax_group[2][0].set_ydata(np.zeros(int(10e-3 * sr)))
+            ax_group[3][0].set_ydata(np.zeros(int(10e-3 * sr)))
             if sr_update:
-                self.view_data[i][0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False))
-                self.view_data[i][2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
-                self.view_data[i][3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
-            self.view_data[i][1][0].set_clim(2000,2100)
+                ax_group[0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False))
+                ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
+                ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
+            ax_group[1][0].set_clim(2000, 2100)
         for i in range(3):
             self._set_label(i, EMLN)
             self._set_visible(i)
@@ -570,12 +562,79 @@ class Callback(object):
         plt.draw()
 
 
+class AxesWithCursor:
+    __slots__ = ['axes', 'im', 'cursors']
+
+    def __init__(self, axes: plt.Axes, im=None, cursors=None):
+        self.axes = axes
+        self.im = im
+        self.cursors = list() if cursors is None else cursors
+
+    def set_visible(self, state):
+        for cursor in self.cursors:
+            cursor.set_visible(state)
+
+
+class AxesGroup:
+    __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']
+
+    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr):
+        self.signal = AxesWithCursor(ax_sig,
+                                     ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0],
+                                     (ax_sig.axvline(10, c='k', linestyle='--'), ax_sig.axvline(10, c='k')))
+        ax_sig.set_xlim(0, 20)
+        ax_sig.set_xlabel('IPI man:None')
+        ax_sig.set_ylim(-1, 1)
+        ax_sig.set_yticks([])
+        self.spectrogram = AxesWithCursor(ax_spec,
+                                          ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
+                                                           Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1],
+                                          (ax_spec.axvline(0.01, c='k', linestyle='--'), ax_spec.axvline(0.01, c='k')))
+        self.spectrogram.im.set_clim(2000, 2100)
+        ax_spec.set_ylim(0, min(20e3, sr/2))
+        ax_spec.set_ytickslabels((ax_spec.get_yticks() / 1e3).astype(int))
+        self.correlation = AxesWithCursor(ax_corr,
+                                          ax_corr.plot(np.linspace(0, 10, int(10e-3 * sr), False),
+                                                       np.zeros(int(10e-3 * sr)))[0],
+                                          (ax_corr.axvline(10, c='k'),))
+        ax_corr.set_xlim(0, 10)
+        ax_corr.set_xlabel('IPI man:None auto:None')
+        ax_corr.set_ylim(-1, 1)
+        ax_corr.set_yticks([])
+
+        self.cepstrum = AxesWithCursor(ax_cep,
+                                       ax_cep.plot(np.linspace(0, 10, int(10e-3 * sr), False),
+                                                   np.zeros(int(10e-3 * sr))),
+                                       (ax_cep.axvline(10, c='k'),))
+        ax_cep.set_xlim(0, 10)
+        ax_cep.set_xlabel('IPI man:None auto:None')
+        ax_cep.set_ylim(-1, 1)
+        ax_cep.set_yticks([])
+        self.set_visible(False)
+
+    def __getitem__(self, item):
+        return (self.signal, self.spectrogram, self.correlation, self.cepstrum)[item]
+
+    def __len__(self):
+        return 4
+
+    def set_visible(self, state):
+        self.signal.set_visible(state)
+        self.spectrogram.set_visible(state)
+        self.correlation.set_visible(state)
+        self.cepstrum.set_visible(state)
+
+    @property
+    def signal_ipi(self):
+        return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
+
+
 def init(in_path, channel, low=2e3, high=20e3):
     song, sr, song_resample = load_file(in_path, channel, low, high)
     fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
     gs = fig.add_gridspec(12, 20)
 
-    full_sig = plt.subplot(gs[:2, 1:-1])
+    full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
     callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                         song, song_resample, sr, full_sig)
     callback.fax = full_sig
@@ -610,19 +669,18 @@ def init(in_path, channel, low=2e3, high=20e3):
     vs = 2
     hfs = 3
     hs = 1
-    ax_view = [[plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
-                plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs])],
-               [plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
-                plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs])],
-               [plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
-                plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
-                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs])]]
-    callback.view_ax = ax_view
+    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
+                          plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr),
+                AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
+                          plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr),
+                AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
+                          plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr)]
 
     play_b_ax = plt.subplot(gs[-1, :2])
     play_b = Button(play_b_ax, 'Play\ncurrent segment')
@@ -659,60 +717,16 @@ def init(in_path, channel, low=2e3, high=20e3):
     spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
     callback.spec_b = spec_button
 
-    data_view = [[2 * [None] for _ in range(4)] for _ in range(3)]
-    m_cursor = [None for _ in range(3)]
-    # m_cursor2 = [[None for _ in range(4)] for _ in range(3)]
+    m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
     callback.cursor = m_cursor
-    for i in range(3):
-        data_view[i][0][1] = (ax_view[i][0].axvline(10, c='k', linestyle='--'), ax_view[i][0].axvline(10, c='k'))
-        data_view[i][0][0] = ax_view[i][0].plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0]
-        ax_view[i][0].set_xlim(0, 20)
-        ax_view[i][0].set_xlabel('IPI man:None')
-        ax_view[i][0].set_ylim(-1, 1)
-        data_view[i][0][1][0].set_visible(False)
-        data_view[i][0][1][1].set_visible(False)
-
-        data_view[i][1][0] = ax_view[i][1].specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
-                                                    Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
-        data_view[i][1][0].set_clim(2000,2100)
-        data_view[i][1][1] = (ax_view[i][1].axvline(0.01, c='k', linestyle='--'), ax_view[i][1].axvline(0.01, c='k'))
-        data_view[i][1][1][0].set_visible(False)
-        data_view[i][1][1][1].set_visible(False)
-
-        data_view[i][2][1] = ax_view[i][2].axvline(10, c='k')
-        data_view[i][2][0] = ax_view[i][2].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[
-            0]
-        ax_view[i][2].set_xlim(0, 10)
-        ax_view[i][2].set_xlabel('IPI man:None auto:None')
-        ax_view[i][2].set_ylim(-1, 1)
-        data_view[i][2][1].set_visible(False)
-
-        data_view[i][3][1] = ax_view[i][3].axvline(10, c='k')
-        data_view[i][3][0] = ax_view[i][3].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[
-            0]
-        ax_view[i][3].set_xlim(0, 10)
-        ax_view[i][3].set_xlabel('IPI man:None auto:None')
-        ax_view[i][3].set_ylim(0, 1)
-        data_view[i][3][1].set_visible(False)
-
-        for j in range(4):
-            # m_cursor2[i][j] = Cursor(ax_view[i][j], horizOn=False, useblit=True, c='k')
-            if j != 1:
-                ax_view[i][j].set_yticks([])
-            else:
-                ax_view[i][j].set_ylim(0, min(20e3, sr/2))
-                ax_view[i][j].set_yticks(ax_view[i][j].get_yticks())
-                ax_view[i][j].set_yticklabels((ax_view[i][j].get_yticks() / 1e3).astype(int))
-        # m_cursor2[i][0].linev.set_linestyle('--')
-        # m_cursor2[i][1].linev.set_linestyle('--')
-        m_cursor[i] = MyMultiCursor(ax_view[i], data_view[i][0][1][0], horizOn=False, useblit=True, c='k')
+    for i in range(len(ax_group)):
+        m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], horizOn=False, useblit=True, c='k')
         m_cursor[i].linev[0].set_linestyle('--')
         m_cursor[i].linev[1].set_linestyle('--')
-    callback.view_data = data_view
+    callback.view_data = ax_group
     return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
-        {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
-        'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b ,
-         'spec_button': spec_button}}  # Needed to keep the callbacks alive
+            {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
+             'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b, 'spec_button': spec_button}}
 
 
 def reset(callback, in_path, channel, low=2e3, high=20e3):
@@ -734,7 +748,7 @@ def main(args):
         return 1
 
     EMLN['onaxis'] = -1
-    ref_dict = init(args.input, args.channel)
+    ref_dict = init(args.input, args.channel, args.low, args.up)
     if args.resume:
         df = pd.read_hdf(outpath)
         if 'onaxis' not in df.columns:
@@ -774,6 +788,8 @@ if __name__ == '__main__':
     parser.add_argument("input", type=str, help="Input file")
     parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
     parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
+    parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
+    parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
     group = parser.add_mutually_exclusive_group()
     group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                              " the computation will be halted")
@@ -800,6 +816,8 @@ if __name__ == '__main__':
                 self.input = input("What is the input file path? ")
                 self.out = input("What is the out file path? (Leave empty for default)")
                 self.channel = int(input("Which channel do you want to use starting from 0? "))
+                self.low = float(input("What is the lower frequency of the bandpass? "))
+                self.up = float(input("What is the upper frequency of the bandpass? "))
                 self.erase = ask("Do you want to erase the out file if it already exist?")
                 if not self.erase:
                     self.resume = ask("Do you want to resume the out file if it already exist?")
-- 
GitLab