diff --git a/ipi_extract.py b/ipi_extract.py
index e55a8d80c75535d3d45f6241d65566323923768d..933aac06c4b86e1d4cc127ca6ad71ffc480165f1 100644
--- a/ipi_extract.py
+++ b/ipi_extract.py
@@ -1,6 +1,7 @@
 import argparse
 import numpy as np
 import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
 import scipy.signal as sg
 import soundfile as sf
 import os
@@ -243,7 +244,7 @@ class MyMultiCursor(AxesWidget):
 
 
 class Callback(object):
-    def __init__(self, line, song, song_resample, sr, full_ax):
+    def __init__(self, line, song, song_resample, sr, full_ax, num_view):
         self.p = 0
         self.df = dict()
         self.line = line
@@ -251,13 +252,13 @@ class Callback(object):
         self.song_resample = song_resample
         self.sr = sr
         self.fax = full_ax
-        self.num_view = 3
+        self.num_view = num_view
         self.view_data: [AxesGroup, None] = None
         self.curr = 0  # current view selected
         self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
         self.offset = np.zeros((0, 2))
-        self.curr_ind: List[Union[int, None]] = 3*[None]  # Indices of click for each plot
-        self.curr_vert = 3*[0]    # Current vertical line of sig/spec for each  plot
+        self.curr_ind: List[Union[int, None]] = num_view*[None]  # Indices of click for each plot
+        self.curr_vert = num_view*[0]    # Current vertical line of sig/spec for each  plot
         self.cursor = None
         self.f_cursor = None
         self.reset_b = None
@@ -317,14 +318,14 @@ class Callback(object):
                     ax_group[i][1].set_visible(False)
         else:
             self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0])
-            c = [0, 0, 0, 1]
+            c = np.array([0, 0, 0, 0.])
             k = 0
             for i, v in enumerate(self.curr_ind):
                 if v == self.curr_ind[self.curr]:
-                    c[i] = 1
+                    c += to_rgba('rgbkcmyrgbkcmy'[i])
                     k += 1
             for i in range(self.num_view):
-                c[i] /= k
+                c /= k
             self.scat._facecolors[self.curr_ind[self.curr]] = c
             self.scat.set_color(self.scat._facecolors)
             row = self.df[mpos / FSSR]
@@ -412,8 +413,8 @@ class Callback(object):
 
     def change_curr(self, label):
         self.curr = int(label[-1])
-        self.f_cursor.linev.set_color('rgb'[self.curr])
-        self.reset_b.label.set_c('rgb'[self.curr])
+        self.f_cursor.linev.set_color('rgbkcmyrgbkcmy'[self.curr])
+        self.reset_b.label.set_c('rgbkcmyrgbkcmy'[self.curr])
         plt.draw()
 
     def toggle_ind(self, event):
@@ -562,7 +563,7 @@ class Callback(object):
                 ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                 ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
             ax_group[1][0].set_clim(2000, 2100)
-        for i in range(3):
+        for i in range(self.num_view):
             self._set_label(i, EMLN)
             self._set_visible(i)
         self.ind_select = False
@@ -651,14 +652,14 @@ class AxesGroup:
         return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
 
 
-def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
+def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3):
     song, sr, song_resample = load_file(in_path, channel, low, high)
     fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
     gs = fig.add_gridspec(12, 20)
 
     full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
     callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
-                        song, song_resample, sr, full_sig)
+                        song, song_resample, sr, full_sig, num_graph)
     callback.fax = full_sig
     full_sig.set_xlim(0, 20)
     lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
@@ -674,11 +675,11 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
     b_left.on_clicked(callback.shift_left)
     b_right.on_clicked(callback.shift_right)
     r_button_ax = plt.subplot(gs[10, 1:-1])
-    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(3)], orientation='horizontal',
-                              size=128, activecolor=list('rgb'))  # !Last char of labels is use as index in rest of code
+    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
+                              size=128, activecolor=list('rgbkcmyrgbkcmy'[:num_graph]))  # !Last char of labels is use as index in rest of code
     r_button_ax.axis('off')
     r_button.on_clicked(callback.change_curr)
-    for i, c in enumerate('rgb'):
+    for i, c in enumerate('rgbkcmyrgbkcmy'[:num_graph]):
         r_button.labels[i].set_c(c)
     callback.r_button = r_button
     for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
@@ -689,20 +690,13 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
 
     vfs = 4
     vs = 2
-    hfs = 3
     hs = 1
-    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
-                          plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr, num_cur),
-                AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
-                          plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr, num_cur),
-                AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
-                          plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
-                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr, num_cur)]
+    hfs = (20-hs)//(2*num_graph)
+    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
+                          plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
+                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), sr, num_cur)
+                for i in range(num_graph)]
 
     play_b_ax = plt.subplot(gs[-1, :2])
     play_b = Button(play_b_ax, 'Play\ncurrent segment')
@@ -766,12 +760,15 @@ def main(args):
     elif args.resume:
         print(f'Out file {outpath} does not already exist and resume option is set.')
         return 1
-    if args.num < 1:
-        print(f'{args.num} is an invalid number of pulses.')
+    if args.pulse < 1:
+        print(f'{args.pulse} is an invalid number of pulses.')
+        return 1
+    elif args.click < 1:
+        print(f'{args.click} is an invalid number of clicks.')
         return 1
 
     EMLN['onaxis'] = -1
-    ref_dict = init(args.input, args.channel, args.low, args.up, args.num)
+    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click)
     if args.resume:
         df = pd.read_hdf(outpath)
         if 'onaxis' not in df.columns:
@@ -811,7 +808,8 @@ if __name__ == '__main__':
     parser.add_argument("input", type=str, help="Input file")
     parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
     parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
-    parser.add_argument("--num", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
+    parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
+    parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
     parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
     parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
     group = parser.add_mutually_exclusive_group()
@@ -840,7 +838,8 @@ if __name__ == '__main__':
                 self.input = input("What is the input file path? ")
                 self.out = input("What is the out file path? (Leave empty for default)")
                 self.channel = int(input("Which channel do you want to use starting from 0? "))
-                self.num = int(input("How many pulse do you want to display (> 1) "))
+                self.pulse = int(input("How many pulses do you want to display (> 1) "))
+                self.click = int(input("How many clicks do you want to display (> 1) "))
                 self.low = float(input("What is the lower frequency of the bandpass? "))
                 self.up = float(input("What is the upper frequency of the bandpass? "))
                 self.erase = ask("Do you want to erase the out file if it already exist?")