diff --git a/ipi_extract.py b/ipi_extract.py
index 933aac06c4b86e1d4cc127ca6ad71ffc480165f1..27d5cb41092949d0bce290810e85ec0ed2f03bc3 100644
--- a/ipi_extract.py
+++ b/ipi_extract.py
@@ -244,7 +244,7 @@ class MyMultiCursor(AxesWidget):
 
 
 class Callback(object):
-    def __init__(self, line, song, song_resample, sr, full_ax, num_view):
+    def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length):
         self.p = 0
         self.df = dict()
         self.line = line
@@ -253,6 +253,7 @@ class Callback(object):
         self.sr = sr
         self.fax = full_ax
         self.num_view = num_view
+        self.after_length = after_length
         self.view_data: [AxesGroup, None] = None
         self.curr = 0  # current view selected
         self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
@@ -312,10 +313,9 @@ class Callback(object):
             self.cursor[self.curr].linev[1].set_linestyle('--')
             for i in range(4):
                 if i < 2:
-                    ax_group[i][1][0].set_visible(False)
-                    ax_group[i][1][1].set_visible(False)
+                    ax_group[i].set_visible(False)
                 else:
-                    ax_group[i][1].set_visible(False)
+                    ax_group[i].set_visible(False)
         else:
             self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0])
             c = np.array([0, 0, 0, 0.])
@@ -360,10 +360,10 @@ class Callback(object):
         self.ind_b.color = '0.85'
         self.ind_b.howercolor = '0.95'
         self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos / FSSR]["ind_number"]}')
-        click = self.song[
-                max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)]
-        if len(click) != 2 * int(10e-3 * self.sr):
-            np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant')
+        click = self.song[max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):
+                          int(mpos * self.sr / FSSR + self.after_length * self.sr)]
+        if len(click) < 2 * int((10e-3 + self.after_length) * self.sr):
+            np.pad(click, (0, 2 * int((10e-3 + self.after_length) * self.sr) - len(click)), mode='constant')
         ax_group.signal.im.set_ydata(norm(click))
         self._update_spectrogram(ax_group)
         ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
@@ -374,25 +374,26 @@ class Callback(object):
 
     def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
         if cell_num < 2:
-            pos = event.xdata if cell_num else event.xdata * 1000
+            pos = event.xdata if cell_num == 0 else event.xdata * 1000
             cur_type = self.curr_vert[group_num]
-            ax_group.signal.cursors[cur_type].set_xdata((pos, pos))
-            ax_group.spectrogram.cursors[cur_type].set_xdata((pos*10**-3, pos*10**-3))
-            ax_group.signal.cursor[cur_type].set_visible(True)
+            ax_group.signal.cursors[cur_type].set_xdata(2*(pos,))
+            ax_group.spectrogram.cursors[cur_type].set_xdata(2*(pos*10**-3,))
+            ax_group.signal.cursors[cur_type].set_visible(True)
             ax_group.spectrogram.cursors[cur_type].set_visible(True)
             if cur_type == 0:
                 self.df[self.offset[self.curr_ind[group_num], 0]]['p1_pos'] = pos
             self.curr_vert[group_num] ^= 1
+            cur_type ^= 1
             self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type])
-            self.cursor[group_num].linev[1].set_linestyle(['--', '-'][cur_type])
+            self.cursor[group_num].linev[self.cursor[group_num].num_cur].set_linestyle(['--', '-'][cur_type])
             if ax_group.signal.cursors[1].get_visible():
                 ipi_man = ax_group.signal_ipi
                 self.df[self.offset[self.curr_ind[group_num], 0]]['ipi_sig'] = ipi_man
-                ax_group[0].set_xlabel(f'Sig man:{ipi_man:.5f}')
+                ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}')
         else:
             cell = ax_group[cell_num]
-            cell.cursor.set_xdata((event.xdata, event.xdata))
-            cell.cursor.set_visible(True)
+            cell.cursors.set_xdata((event.xdata, event.xdata))
+            cell.cursors.set_visible(True)
             lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0)
             ipi_auto = np.argmax(cell.axes.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min
             col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
@@ -555,14 +556,15 @@ class Callback(object):
         self.curr_ind = self.num_view * [None]  # Ind of click for each plot
         self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each  plot
         for ax_group in self.view_data:
-            ax_group[0][0].set_ydata(np.zeros(int(20e-3 * sr)))
-            ax_group[2][0].set_ydata(np.zeros(int(10e-3 * sr)))
-            ax_group[3][0].set_ydata(np.zeros(int(10e-3 * sr)))
+            ax_group.signal.im.set_ydata(np.zeros(int((10e-3 + self.after_length) * sr)))
+            ax_group.correlation.im.set_ydata(np.zeros(int(10e-3 * sr)))
+            ax_group.cepstrum.im.set_ydata(np.zeros(int(10e-3 * sr)))
             if sr_update:
-                ax_group[0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False))
-                ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
-                ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
-            ax_group[1][0].set_clim(2000, 2100)
+                ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length)*1e3,
+                                                     int((10e-3 + self.after_length)*sr), False))
+                ax_group.correlation.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
+                ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
+            ax_group.spectrogram.im.set_clim(2000, 2100)
         for i in range(self.num_view):
             self._set_label(i, EMLN)
             self._set_visible(i)
@@ -593,17 +595,18 @@ class AxesWithCursor:
 class AxesGroup:
     __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']
 
-    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2):
+    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3):
         self.signal = AxesWithCursor(ax_sig,
-                                     ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0],
+                                     ax_sig.plot(np.linspace(0, 10 + after_length*1e3, int((10e-3 + after_length)*sr),
+                                                             False), np.zeros(int((10e-3 + after_length)*sr)))[0],
                                      (ax_sig.axvline(10, c='k', linestyle='--'),) +
                                      tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur)))
-        ax_sig.set_xlim(0, 20)
+        ax_sig.set_xlim(0, 10 + after_length*1e3)
         ax_sig.set_xlabel('IPI man:None')
         ax_sig.set_ylim(-1, 1)
         ax_sig.set_yticks([])
         self.spectrogram = AxesWithCursor(ax_spec,
-                                          ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
+                                          ax_spec.specgram(np.random.normal(0, 1e-6, int((10e-3 + after_length)*sr)),
                                                            Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1],
                                           (ax_spec.axvline(0.01, c='k', linestyle='--'),) +
                                           tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur)))
@@ -652,14 +655,14 @@ class AxesGroup:
         return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
 
 
-def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3):
+def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3):
     song, sr, song_resample = load_file(in_path, channel, low, high)
     fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
     gs = fig.add_gridspec(12, 20)
 
     full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
     callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
-                        song, song_resample, sr, full_sig, num_graph)
+                        song, song_resample, sr, full_sig, num_graph, after_length)
     callback.fax = full_sig
     full_sig.set_xlim(0, 20)
     lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
@@ -695,8 +698,8 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3):
     ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                           plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                           plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), sr, num_cur)
-                for i in range(num_graph)]
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
+                          sr, num_cur, after_length) for i in range(num_graph)]
 
     play_b_ax = plt.subplot(gs[-1, :2])
     play_b = Button(play_b_ax, 'Play\ncurrent segment')
@@ -768,13 +771,13 @@ def main(args):
         return 1
 
     EMLN['onaxis'] = -1
-    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click)
+    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after*1e-3)
     if args.resume:
         df = pd.read_hdf(outpath)
         if 'onaxis' not in df.columns:
             df['onaxis'] = -1
         ref_dict['callback'].df = df.to_dict(orient='index')
-        ref_dict['callback'].offset = np.tile(np.array(list(ref_dict['callback'].df.keys()))[:, np.newaxis], (1,2))
+        ref_dict['callback'].offset = np.tile(np.array(list(ref_dict['callback'].df.keys()))[:, np.newaxis], (1, 2))
         ref_dict['callback'].offset[:, 1] = ref_dict['callback'].song_resample[(ref_dict['callback'].offset[:, 1] * FSSR).astype(int)]
         ref_dict['callback'].scat.set_offsets(ref_dict['callback'].offset)
         colors = np.array(len(ref_dict['callback'].offset)*[[0.7, 0.2, 0.5, 1]])
@@ -810,6 +813,7 @@ if __name__ == '__main__':
     parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
     parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
     parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
+    parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.")
     parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
     parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
     group = parser.add_mutually_exclusive_group()
@@ -840,6 +844,7 @@ if __name__ == '__main__':
                 self.channel = int(input("Which channel do you want to use starting from 0? "))
                 self.pulse = int(input("How many pulses do you want to display (> 1) "))
                 self.click = int(input("How many clicks do you want to display (> 1) "))
+                self.after = float(input("What duration do you want after P1?"))
                 self.low = float(input("What is the lower frequency of the bandpass? "))
                 self.up = float(input("What is the upper frequency of the bandpass? "))
                 self.erase = ask("Do you want to erase the out file if it already exist?")