diff --git a/ipi_extract.py b/ipi_extract.py
index 8d43664f6e5870c899d79bca3aa226476f693089..e55a8d80c75535d3d45f6241d65566323923768d 100644
--- a/ipi_extract.py
+++ b/ipi_extract.py
@@ -123,36 +123,41 @@ class MyRadioButtons(RadioButtons):
 
 
 class MyMultiCursor(AxesWidget):
-    def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False,
+    def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False,
                  **lineprops):
         AxesWidget.__init__(self, axes[0])
         self.axes = axes
-        self.n_axes = len(axes)
         self.p1 = p1
+        self.num_cur = num_cur
         self.connect_event('motion_notify_event', self.onmove)
         self.connect_event('draw_event', self.clear)
 
         self.visible = True
-        self.horizOn = horizOn
         self.vertOn = vertOn
         self.useblit = useblit and self.canvas.supports_blit
 
         if self.useblit:
             lineprops['animated'] = True
         self.linev = []
-        for i in range(self.n_axes):
-            self.linev.append(axes[i].axes.axvline(axes[i].axes.get_xbound()[0], visible=False, **lineprops))
-
-        self.background = self.n_axes*[None]
+        for ax in self.axes:
+            for line in ax.cursors:
+                if line.get_linestyle() == '--':
+                    continue
+                self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
+        self.linev[0].set_linestyle('--')
+        self.linev[self.num_cur].set_linestyle('--')
+        self.n_axes = len(self.linev)
+
+        self.background = len(self.axes)*[None]
         self.needclear = False
 
     def clear(self, event):
         """Internal event handler to clear the cursor."""
         if self.ignore(event):
             return
-        for i in range(self.n_axes):
+        for i, ax in enumerate(self.axes):
             if self.useblit:
-                self.background[i] = self.canvas.copy_from_bbox(self.axes[i].axes.bbox)
+                self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
             self.linev[i].set_visible(False)
 
     def onmove(self, event):
@@ -181,18 +186,18 @@ class MyMultiCursor(AxesWidget):
                 pos = event.xdata
             if self.linev[0].get_linestyle() == '--':
                 self.linev[0].set_xdata((pos, pos))
-                self.linev[1].set_xdata((pos / 1e3, pos / 1e3))
+                self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3))
                 self.linev[0].set_visible(self.visible and self.vertOn)
-                self.linev[1].set_visible(self.visible and self.vertOn)
+                self.linev[self.num_cur].set_visible(self.visible and self.vertOn)
                 self._update()
                 return
             else:
                 ipi_pos = abs(pos - self.p1.get_xdata()[0])
                 for i in range(self.n_axes):
-                    if i == 0:
-                        self.linev[i].set_xdata((pos, pos))
-                    elif i == 1:
-                        self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
+                    if i < self.num_cur:
+                        self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
+                    elif i < 2*self.num_cur:
+                        self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                     else:
                         self.linev[i].set_xdata((ipi_pos, ipi_pos))
                     self.linev[i].set_visible(self.visible and self.vertOn)
@@ -202,33 +207,36 @@ class MyMultiCursor(AxesWidget):
             ipi_pos = event.xdata
             pos = self.p1.get_xdata()[0] + ipi_pos
             restore = False
+        if self.linev[0].get_linestyle() == '--':
+            restore = True
+            self.linev[0].set_linestyle('-')
+            self.linev[self.num_cur].set_linestyle('-')
         for i in range(self.n_axes):
             if self.p1.get_visible():
-                if i == 0:
-                    if self.linev[0].get_linestyle() == '--':
-                        restore = True
-                        self.linev[0].set_linestyle('-')
-                        self.linev[1].set_linestyle('-')
-                    self.linev[i].set_xdata((pos, pos))
-                elif i == 1:
-                    self.linev[i].set_xdata((pos/1e3, pos/1e3))
+                if i < self.num_cur:
+                    self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
+                elif i < 2 * self.num_cur:
+                    self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                 self.linev[i].set_visible(self.visible and self.vertOn)
-            if i > 1:
+            if i >= 2 * self.num_cur:
                 self.linev[i].set_xdata((ipi_pos, ipi_pos))
                 self.linev[i].set_visible(self.visible and self.vertOn)
         self._update()
 
         if restore:
             self.linev[0].set_linestyle('--')
-            self.linev[1].set_linestyle('--')
+            self.linev[self.num_cur].set_linestyle('--')
 
     def _update(self):
         if self.useblit:
-            for i in range(self.n_axes):
-                if self.background[i] is not None:
-                    self.canvas.restore_region(self.background[i])
-                self.axes[i].axes.draw_artist(self.linev[i])
-                self.canvas.blit(self.axes[i].axes.bbox)
+            k = 0
+            for bk, ax in zip(self.background, self.axes):
+                if bk is not None:
+                    self.canvas.restore_region(bk)
+                for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1):
+                    ax.axes.draw_artist(self.linev[k])
+                k += 1
+                self.canvas.blit(ax.axes.bbox)
         else:
             self.canvas.draw_idle()
         return False
@@ -355,10 +363,10 @@ class Callback(object):
                 max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)]
         if len(click) != 2 * int(10e-3 * self.sr):
             np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant')
-        ax_group[0][0].set_ydata(norm(click))
+        ax_group.signal.im.set_ydata(norm(click))
         self._update_spectrogram(ax_group)
-        ax_group[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
-        ax_group[3][0].set_ydata(
+        ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
+        ax_group.cepstrum.im.set_ydata(
             norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)])))
         self._set_label()
         plt.draw()
@@ -472,7 +480,7 @@ class Callback(object):
         plt.draw()
 
     def _update_spectrogram(self, ax_group):
-        click = ax_group.signal.axes.get_ydata()
+        click = ax_group.signal.im.get_ydata()
         spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1,
                                                        pad_to=2*self.nfft)[0]))
         ax_group.spectrogram.im.set_data(spec)
@@ -510,9 +518,9 @@ class Callback(object):
         if dic is None:
             dic = self.df[self.offset[self.curr_ind[ind], 0]]
         ax_group = self.view_data[ind]
-        ax_group[0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
-        ax_group[2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
-        ax_group[3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
+        ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
+        ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
+        ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
         self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
         if self.onaxis_b is not None:
             self.onaxis_b.label.set_text(['Off', 'On', '?'][dic['onaxis']] + '-axis')
@@ -584,10 +592,11 @@ class AxesWithCursor:
 class AxesGroup:
     __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']
 
-    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr):
+    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2):
         self.signal = AxesWithCursor(ax_sig,
                                      ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0],
-                                     (ax_sig.axvline(10, c='k', linestyle='--'), ax_sig.axvline(10, c='k')))
+                                     (ax_sig.axvline(10, c='k', linestyle='--'),) +
+                                     tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur)))
         ax_sig.set_xlim(0, 20)
         ax_sig.set_xlabel('IPI man:None')
         ax_sig.set_ylim(-1, 1)
@@ -595,7 +604,8 @@ class AxesGroup:
         self.spectrogram = AxesWithCursor(ax_spec,
                                           ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
                                                            Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1],
-                                          (ax_spec.axvline(0.01, c='k', linestyle='--'), ax_spec.axvline(0.01, c='k')))
+                                          (ax_spec.axvline(0.01, c='k', linestyle='--'),) +
+                                          tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur)))
         self.spectrogram.im.set_clim(2000, 2100)
         ax_spec.set_ylim(0, min(20e3, sr/2))
         ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))
@@ -610,11 +620,11 @@ class AxesGroup:
 
         self.cepstrum = AxesWithCursor(ax_cep,
                                        ax_cep.plot(np.linspace(0, 10, int(10e-3 * sr), False),
-                                                   np.zeros(int(10e-3 * sr))),
+                                                   np.zeros(int(10e-3 * sr)))[0],
                                        (ax_cep.axvline(10, c='k'),))
         ax_cep.set_xlim(0, 10)
         ax_cep.set_xlabel('IPI man:None auto:None')
-        ax_cep.set_ylim(-1, 1)
+        ax_cep.set_ylim(0, 0.33)
         ax_cep.set_yticks([])
         self.set_visible(False)
 
@@ -641,7 +651,7 @@ class AxesGroup:
         return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
 
 
-def init(in_path, channel, low=2e3, high=20e3):
+def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
     song, sr, song_resample = load_file(in_path, channel, low, high)
     fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
     gs = fig.add_gridspec(12, 20)
@@ -684,15 +694,15 @@ def init(in_path, channel, low=2e3, high=20e3):
     ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
                           plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
                           plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr, num_cur),
                 AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
                           plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
                           plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr, num_cur),
                 AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
                           plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
                           plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr)]
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr, num_cur)]
 
     play_b_ax = plt.subplot(gs[-1, :2])
     play_b = Button(play_b_ax, 'Play\ncurrent segment')
@@ -732,9 +742,7 @@ def init(in_path, channel, low=2e3, high=20e3):
     m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
     callback.cursor = m_cursor
     for i in range(len(ax_group)):
-        m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], horizOn=False, useblit=True, c='k')
-        m_cursor[i].linev[0].set_linestyle('--')
-        m_cursor[i].linev[1].set_linestyle('--')
+        m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], num_cur, useblit=True, c='k')
     callback.view_data = ax_group
     return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
             {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
@@ -758,9 +766,12 @@ def main(args):
     elif args.resume:
         print(f'Out file {outpath} does not already exist and resume option is set.')
         return 1
+    if args.num < 1:
+        print(f'{args.num} is an invalid number of pulses.')
+        return 1
 
     EMLN['onaxis'] = -1
-    ref_dict = init(args.input, args.channel, args.low, args.up)
+    ref_dict = init(args.input, args.channel, args.low, args.up, args.num)
     if args.resume:
         df = pd.read_hdf(outpath)
         if 'onaxis' not in df.columns:
@@ -800,6 +811,7 @@ if __name__ == '__main__':
     parser.add_argument("input", type=str, help="Input file")
     parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
     parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
+    parser.add_argument("--num", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
     parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
     parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
     group = parser.add_mutually_exclusive_group()
@@ -828,6 +840,7 @@ if __name__ == '__main__':
                 self.input = input("What is the input file path? ")
                 self.out = input("What is the out file path? (Leave empty for default)")
                 self.channel = int(input("Which channel do you want to use starting from 0? "))
+                self.num = int(input("How many pulse do you want to display (> 1) "))
                 self.low = float(input("What is the lower frequency of the bandpass? "))
                 self.up = float(input("What is the upper frequency of the bandpass? "))
                 self.erase = ask("Do you want to erase the out file if it already exist?")