diff --git a/ipi_extract.py b/ipi_extract.py index 8d43664f6e5870c899d79bca3aa226476f693089..e55a8d80c75535d3d45f6241d65566323923768d 100644 --- a/ipi_extract.py +++ b/ipi_extract.py @@ -123,36 +123,41 @@ class MyRadioButtons(RadioButtons): class MyMultiCursor(AxesWidget): - def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False, + def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False, **lineprops): AxesWidget.__init__(self, axes[0]) self.axes = axes - self.n_axes = len(axes) self.p1 = p1 + self.num_cur = num_cur self.connect_event('motion_notify_event', self.onmove) self.connect_event('draw_event', self.clear) self.visible = True - self.horizOn = horizOn self.vertOn = vertOn self.useblit = useblit and self.canvas.supports_blit if self.useblit: lineprops['animated'] = True self.linev = [] - for i in range(self.n_axes): - self.linev.append(axes[i].axes.axvline(axes[i].axes.get_xbound()[0], visible=False, **lineprops)) - - self.background = self.n_axes*[None] + for ax in self.axes: + for line in ax.cursors: + if line.get_linestyle() == '--': + continue + self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops)) + self.linev[0].set_linestyle('--') + self.linev[self.num_cur].set_linestyle('--') + self.n_axes = len(self.linev) + + self.background = len(self.axes)*[None] self.needclear = False def clear(self, event): """Internal event handler to clear the cursor.""" if self.ignore(event): return - for i in range(self.n_axes): + for i, ax in enumerate(self.axes): if self.useblit: - self.background[i] = self.canvas.copy_from_bbox(self.axes[i].axes.bbox) + self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox) self.linev[i].set_visible(False) def onmove(self, event): @@ -181,18 +186,18 @@ class MyMultiCursor(AxesWidget): pos = event.xdata if self.linev[0].get_linestyle() == '--': self.linev[0].set_xdata((pos, pos)) - self.linev[1].set_xdata((pos / 1e3, pos / 1e3)) + self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3)) self.linev[0].set_visible(self.visible and self.vertOn) - self.linev[1].set_visible(self.visible and self.vertOn) + self.linev[self.num_cur].set_visible(self.visible and self.vertOn) self._update() return else: ipi_pos = abs(pos - self.p1.get_xdata()[0]) for i in range(self.n_axes): - if i == 0: - self.linev[i].set_xdata((pos, pos)) - elif i == 1: - self.linev[i].set_xdata((pos / 1e3, pos / 1e3)) + if i < self.num_cur: + self.linev[i].set_xdata(2*(pos + i*ipi_pos,)) + elif i < 2*self.num_cur: + self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,)) else: self.linev[i].set_xdata((ipi_pos, ipi_pos)) self.linev[i].set_visible(self.visible and self.vertOn) @@ -202,33 +207,36 @@ class MyMultiCursor(AxesWidget): ipi_pos = event.xdata pos = self.p1.get_xdata()[0] + ipi_pos restore = False + if self.linev[0].get_linestyle() == '--': + restore = True + self.linev[0].set_linestyle('-') + self.linev[self.num_cur].set_linestyle('-') for i in range(self.n_axes): if self.p1.get_visible(): - if i == 0: - if self.linev[0].get_linestyle() == '--': - restore = True - self.linev[0].set_linestyle('-') - self.linev[1].set_linestyle('-') - self.linev[i].set_xdata((pos, pos)) - elif i == 1: - self.linev[i].set_xdata((pos/1e3, pos/1e3)) + if i < self.num_cur: + self.linev[i].set_xdata(2*(pos + i*ipi_pos,)) + elif i < 2 * self.num_cur: + self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,)) self.linev[i].set_visible(self.visible and self.vertOn) - if i > 1: + if i >= 2 * self.num_cur: self.linev[i].set_xdata((ipi_pos, ipi_pos)) self.linev[i].set_visible(self.visible and self.vertOn) self._update() if restore: self.linev[0].set_linestyle('--') - self.linev[1].set_linestyle('--') + self.linev[self.num_cur].set_linestyle('--') def _update(self): if self.useblit: - for i in range(self.n_axes): - if self.background[i] is not None: - self.canvas.restore_region(self.background[i]) - self.axes[i].axes.draw_artist(self.linev[i]) - self.canvas.blit(self.axes[i].axes.bbox) + k = 0 + for bk, ax in zip(self.background, self.axes): + if bk is not None: + self.canvas.restore_region(bk) + for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1): + ax.axes.draw_artist(self.linev[k]) + k += 1 + self.canvas.blit(ax.axes.bbox) else: self.canvas.draw_idle() return False @@ -355,10 +363,10 @@ class Callback(object): max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)] if len(click) != 2 * int(10e-3 * self.sr): np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant') - ax_group[0][0].set_ydata(norm(click)) + ax_group.signal.im.set_ydata(norm(click)) self._update_spectrogram(ax_group) - ax_group[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):])) - ax_group[3][0].set_ydata( + ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):])) + ax_group.cepstrum.im.set_ydata( norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)]))) self._set_label() plt.draw() @@ -472,7 +480,7 @@ class Callback(object): plt.draw() def _update_spectrogram(self, ax_group): - click = ax_group.signal.axes.get_ydata() + click = ax_group.signal.im.get_ydata() spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1, pad_to=2*self.nfft)[0])) ax_group.spectrogram.im.set_data(spec) @@ -510,9 +518,9 @@ class Callback(object): if dic is None: dic = self.df[self.offset[self.curr_ind[ind], 0]] ax_group = self.view_data[ind] - ax_group[0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}') - ax_group[2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}') - ax_group[3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}') + ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}') + ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}') + ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}') self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}') if self.onaxis_b is not None: self.onaxis_b.label.set_text(['Off', 'On', '?'][dic['onaxis']] + '-axis') @@ -584,10 +592,11 @@ class AxesWithCursor: class AxesGroup: __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum'] - def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr): + def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2): self.signal = AxesWithCursor(ax_sig, ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0], - (ax_sig.axvline(10, c='k', linestyle='--'), ax_sig.axvline(10, c='k'))) + (ax_sig.axvline(10, c='k', linestyle='--'),) + + tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur))) ax_sig.set_xlim(0, 20) ax_sig.set_xlabel('IPI man:None') ax_sig.set_ylim(-1, 1) @@ -595,7 +604,8 @@ class AxesGroup: self.spectrogram = AxesWithCursor(ax_spec, ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)), Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1], - (ax_spec.axvline(0.01, c='k', linestyle='--'), ax_spec.axvline(0.01, c='k'))) + (ax_spec.axvline(0.01, c='k', linestyle='--'),) + + tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur))) self.spectrogram.im.set_clim(2000, 2100) ax_spec.set_ylim(0, min(20e3, sr/2)) ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int)) @@ -610,11 +620,11 @@ class AxesGroup: self.cepstrum = AxesWithCursor(ax_cep, ax_cep.plot(np.linspace(0, 10, int(10e-3 * sr), False), - np.zeros(int(10e-3 * sr))), + np.zeros(int(10e-3 * sr)))[0], (ax_cep.axvline(10, c='k'),)) ax_cep.set_xlim(0, 10) ax_cep.set_xlabel('IPI man:None auto:None') - ax_cep.set_ylim(-1, 1) + ax_cep.set_ylim(0, 0.33) ax_cep.set_yticks([]) self.set_visible(False) @@ -641,7 +651,7 @@ class AxesGroup: return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0] -def init(in_path, channel, low=2e3, high=20e3): +def init(in_path, channel, low=2e3, high=20e3, num_cur=1): song, sr, song_resample = load_file(in_path, channel, low, high) fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True) gs = fig.add_gridspec(12, 20) @@ -684,15 +694,15 @@ def init(in_path, channel, low=2e3, high=20e3): ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]), plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]), plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr, num_cur), AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]), plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]), plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr, num_cur), AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]), plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]), plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr)] + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr, num_cur)] play_b_ax = plt.subplot(gs[-1, :2]) play_b = Button(play_b_ax, 'Play\ncurrent segment') @@ -732,9 +742,7 @@ def init(in_path, channel, low=2e3, high=20e3): m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None] callback.cursor = m_cursor for i in range(len(ax_group)): - m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], horizOn=False, useblit=True, c='k') - m_cursor[i].linev[0].set_linestyle('--') - m_cursor[i].linev[1].set_linestyle('--') + m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], num_cur, useblit=True, c='k') callback.view_data = ax_group return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons': {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button, @@ -758,9 +766,12 @@ def main(args): elif args.resume: print(f'Out file {outpath} does not already exist and resume option is set.') return 1 + if args.num < 1: + print(f'{args.num} is an invalid number of pulses.') + return 1 EMLN['onaxis'] = -1 - ref_dict = init(args.input, args.channel, args.low, args.up) + ref_dict = init(args.input, args.channel, args.low, args.up, args.num) if args.resume: df = pd.read_hdf(outpath) if 'onaxis' not in df.columns: @@ -800,6 +811,7 @@ if __name__ == '__main__': parser.add_argument("input", type=str, help="Input file") parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'") parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.") + parser.add_argument("--num", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.") parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter') parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter') group = parser.add_mutually_exclusive_group() @@ -828,6 +840,7 @@ if __name__ == '__main__': self.input = input("What is the input file path? ") self.out = input("What is the out file path? (Leave empty for default)") self.channel = int(input("Which channel do you want to use starting from 0? ")) + self.num = int(input("How many pulse do you want to display (> 1) ")) self.low = float(input("What is the lower frequency of the bandpass? ")) self.up = float(input("What is the upper frequency of the bandpass? ")) self.erase = ask("Do you want to erase the out file if it already exist?")