diff --git a/ipi_extract.py b/ipi_extract.py index 933aac06c4b86e1d4cc127ca6ad71ffc480165f1..27d5cb41092949d0bce290810e85ec0ed2f03bc3 100644 --- a/ipi_extract.py +++ b/ipi_extract.py @@ -244,7 +244,7 @@ class MyMultiCursor(AxesWidget): class Callback(object): - def __init__(self, line, song, song_resample, sr, full_ax, num_view): + def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length): self.p = 0 self.df = dict() self.line = line @@ -253,6 +253,7 @@ class Callback(object): self.sr = sr self.fax = full_ax self.num_view = num_view + self.after_length = after_length self.view_data: [AxesGroup, None] = None self.curr = 0 # current view selected self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]]) @@ -312,10 +313,9 @@ class Callback(object): self.cursor[self.curr].linev[1].set_linestyle('--') for i in range(4): if i < 2: - ax_group[i][1][0].set_visible(False) - ax_group[i][1][1].set_visible(False) + ax_group[i].set_visible(False) else: - ax_group[i][1].set_visible(False) + ax_group[i].set_visible(False) else: self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0]) c = np.array([0, 0, 0, 0.]) @@ -360,10 +360,10 @@ class Callback(object): self.ind_b.color = '0.85' self.ind_b.howercolor = '0.95' self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos / FSSR]["ind_number"]}') - click = self.song[ - max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)] - if len(click) != 2 * int(10e-3 * self.sr): - np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant') + click = self.song[max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0): + int(mpos * self.sr / FSSR + self.after_length * self.sr)] + if len(click) < 2 * int((10e-3 + self.after_length) * self.sr): + np.pad(click, (0, 2 * int((10e-3 + self.after_length) * self.sr) - len(click)), mode='constant') ax_group.signal.im.set_ydata(norm(click)) self._update_spectrogram(ax_group) ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):])) @@ -374,25 +374,26 @@ class Callback(object): def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num): if cell_num < 2: - pos = event.xdata if cell_num else event.xdata * 1000 + pos = event.xdata if cell_num == 0 else event.xdata * 1000 cur_type = self.curr_vert[group_num] - ax_group.signal.cursors[cur_type].set_xdata((pos, pos)) - ax_group.spectrogram.cursors[cur_type].set_xdata((pos*10**-3, pos*10**-3)) - ax_group.signal.cursor[cur_type].set_visible(True) + ax_group.signal.cursors[cur_type].set_xdata(2*(pos,)) + ax_group.spectrogram.cursors[cur_type].set_xdata(2*(pos*10**-3,)) + ax_group.signal.cursors[cur_type].set_visible(True) ax_group.spectrogram.cursors[cur_type].set_visible(True) if cur_type == 0: self.df[self.offset[self.curr_ind[group_num], 0]]['p1_pos'] = pos self.curr_vert[group_num] ^= 1 + cur_type ^= 1 self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type]) - self.cursor[group_num].linev[1].set_linestyle(['--', '-'][cur_type]) + self.cursor[group_num].linev[self.cursor[group_num].num_cur].set_linestyle(['--', '-'][cur_type]) if ax_group.signal.cursors[1].get_visible(): ipi_man = ax_group.signal_ipi self.df[self.offset[self.curr_ind[group_num], 0]]['ipi_sig'] = ipi_man - ax_group[0].set_xlabel(f'Sig man:{ipi_man:.5f}') + ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}') else: cell = ax_group[cell_num] - cell.cursor.set_xdata((event.xdata, event.xdata)) - cell.cursor.set_visible(True) + cell.cursors.set_xdata((event.xdata, event.xdata)) + cell.cursors.set_visible(True) lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0) ipi_auto = np.argmax(cell.axes.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps') @@ -555,14 +556,15 @@ class Callback(object): self.curr_ind = self.num_view * [None] # Ind of click for each plot self.curr_vert = self.num_view * [0] # Current vertical line of sig/spec for each plot for ax_group in self.view_data: - ax_group[0][0].set_ydata(np.zeros(int(20e-3 * sr))) - ax_group[2][0].set_ydata(np.zeros(int(10e-3 * sr))) - ax_group[3][0].set_ydata(np.zeros(int(10e-3 * sr))) + ax_group.signal.im.set_ydata(np.zeros(int((10e-3 + self.after_length) * sr))) + ax_group.correlation.im.set_ydata(np.zeros(int(10e-3 * sr))) + ax_group.cepstrum.im.set_ydata(np.zeros(int(10e-3 * sr))) if sr_update: - ax_group[0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False)) - ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) - ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) - ax_group[1][0].set_clim(2000, 2100) + ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length)*1e3, + int((10e-3 + self.after_length)*sr), False)) + ax_group.correlation.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) + ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) + ax_group.spectrogram.im.set_clim(2000, 2100) for i in range(self.num_view): self._set_label(i, EMLN) self._set_visible(i) @@ -593,17 +595,18 @@ class AxesWithCursor: class AxesGroup: __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum'] - def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2): + def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3): self.signal = AxesWithCursor(ax_sig, - ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0], + ax_sig.plot(np.linspace(0, 10 + after_length*1e3, int((10e-3 + after_length)*sr), + False), np.zeros(int((10e-3 + after_length)*sr)))[0], (ax_sig.axvline(10, c='k', linestyle='--'),) + tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur))) - ax_sig.set_xlim(0, 20) + ax_sig.set_xlim(0, 10 + after_length*1e3) ax_sig.set_xlabel('IPI man:None') ax_sig.set_ylim(-1, 1) ax_sig.set_yticks([]) self.spectrogram = AxesWithCursor(ax_spec, - ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)), + ax_spec.specgram(np.random.normal(0, 1e-6, int((10e-3 + after_length)*sr)), Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1], (ax_spec.axvline(0.01, c='k', linestyle='--'),) + tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur))) @@ -652,14 +655,14 @@ class AxesGroup: return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0] -def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3): +def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3): song, sr, song_resample = load_file(in_path, channel, low, high) fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True) gs = fig.add_gridspec(12, 20) full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1]) # Type annotation needed with plt 3.1 callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0], - song, song_resample, sr, full_sig, num_graph) + song, song_resample, sr, full_sig, num_graph, after_length) callback.fax = full_sig full_sig.set_xlim(0, 20) lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2 @@ -695,8 +698,8 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3): ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]), plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), sr, num_cur) - for i in range(num_graph)] + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), + sr, num_cur, after_length) for i in range(num_graph)] play_b_ax = plt.subplot(gs[-1, :2]) play_b = Button(play_b_ax, 'Play\ncurrent segment') @@ -768,13 +771,13 @@ def main(args): return 1 EMLN['onaxis'] = -1 - ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click) + ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after*1e-3) if args.resume: df = pd.read_hdf(outpath) if 'onaxis' not in df.columns: df['onaxis'] = -1 ref_dict['callback'].df = df.to_dict(orient='index') - ref_dict['callback'].offset = np.tile(np.array(list(ref_dict['callback'].df.keys()))[:, np.newaxis], (1,2)) + ref_dict['callback'].offset = np.tile(np.array(list(ref_dict['callback'].df.keys()))[:, np.newaxis], (1, 2)) ref_dict['callback'].offset[:, 1] = ref_dict['callback'].song_resample[(ref_dict['callback'].offset[:, 1] * FSSR).astype(int)] ref_dict['callback'].scat.set_offsets(ref_dict['callback'].offset) colors = np.array(len(ref_dict['callback'].offset)*[[0.7, 0.2, 0.5, 1]]) @@ -810,6 +813,7 @@ if __name__ == '__main__': parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.") parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.") parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.") + parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.") parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter') parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter') group = parser.add_mutually_exclusive_group() @@ -840,6 +844,7 @@ if __name__ == '__main__': self.channel = int(input("Which channel do you want to use starting from 0? ")) self.pulse = int(input("How many pulses do you want to display (> 1) ")) self.click = int(input("How many clicks do you want to display (> 1) ")) + self.after = float(input("What duration do you want after P1?")) self.low = float(input("What is the lower frequency of the bandpass? ")) self.up = float(input("What is the upper frequency of the bandpass? ")) self.erase = ask("Do you want to erase the out file if it already exist?")