diff --git a/ipi_extract.py b/ipi_extract.py index 6e0d5a0ed17d1881076eaf916ba5ad5248831839..41aaa4b65d74ce9f1e803628c1ca2f32004afc5d 100644 --- a/ipi_extract.py +++ b/ipi_extract.py @@ -10,12 +10,13 @@ from fractions import Fraction from pydub import AudioSegment from pydub.playback import play import pandas as pd +from typing import List, Union FSSR = 48_000 # Sampling rate of full signal plot FSPK = 0.1 # Max distance to detect a click in full sig in seconds -IPIPK= 0.15 # Max distance to detect a IPI in milliseconds +IPIPK = 0.15 # Max distance to detect a IPI in milliseconds SPSC = 80 # Spectrogram scale -EMLN = {'p1_pos':np.nan, 'ipi_sig': np.nan, +EMLN = {'p1_pos': np.nan, 'ipi_sig': np.nan, 'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan, 'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan, 'ind_number': np.nan} # Empty dataline @@ -138,10 +139,8 @@ class MyMultiCursor(AxesWidget): if self.useblit: lineprops['animated'] = True - #self.lineh = [] self.linev = [] for i in range(self.n_axes): - #self.lineh.append(axes[i].axhline(axes[i].get_ybound()[0], visible=False, **lineprops)) self.linev.append(axes[i].axvline(axes[i].get_xbound()[0], visible=False, **lineprops)) self.background = self.n_axes*[None] @@ -155,7 +154,6 @@ class MyMultiCursor(AxesWidget): if self.useblit: self.background[i] = self.canvas.copy_from_bbox(self.axes[i].bbox) self.linev[i].set_visible(False) - #self.lineh[i].set_visible(False) def onmove(self, event): """Internal event handler to draw the cursor when the mouse moves.""" @@ -166,8 +164,6 @@ class MyMultiCursor(AxesWidget): if event.inaxes not in self.axes: for i in range(self.n_axes): self.linev[i].set_visible(False) - #self.lineh[i].set_visible(False) - if self.needclear: self.canvas.draw() self.needclear = False @@ -197,9 +193,7 @@ class MyMultiCursor(AxesWidget): self.linev[i].set_xdata((pos / 1e3, pos / 1e3)) else: self.linev[i].set_xdata((ipi_pos, ipi_pos)) - # self.lineh[i].set_ydata((event.ydata, event.ydata)) self.linev[i].set_visible(self.visible and self.vertOn) - # self.lineh[i].set_visible(self.visible and self.horizOn) self._update() return else: @@ -219,10 +213,7 @@ class MyMultiCursor(AxesWidget): self.linev[i].set_visible(self.visible and self.vertOn) if i > 1: self.linev[i].set_xdata((ipi_pos, ipi_pos)) - #self.lineh[i].set_ydata((event.ydata, event.ydata)) self.linev[i].set_visible(self.visible and self.vertOn) - #self.lineh[i].set_visible(self.visible and self.horizOn) - self._update() if restore: @@ -235,7 +226,6 @@ class MyMultiCursor(AxesWidget): if self.background[i] is not None: self.canvas.restore_region(self.background[i]) self.axes[i].draw_artist(self.linev[i]) - #self.axes[i].draw_artist(self.lineh[i]) self.canvas.blit(self.axes[i].bbox) else: self.canvas.draw_idle() @@ -251,12 +241,12 @@ class Callback(object): self.song_resample = song_resample self.sr = sr self.fax = full_ax - self.view_ax = None - self.view_data = None + self.num_view = 3 + self.view_data: [AxesGroup, None] = None self.curr = 0 # current view selected - self.scat = self.fax.scatter([],[], marker='x', c=[[0.7, 0.2, 0.5, 1]]) - self.offset = np.zeros((0,2)) - self.curr_ind = 3*[None] # Ind of click for each plot + self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]]) + self.offset = np.zeros((0, 2)) + self.curr_ind: List[Union[int, None]] = 3*[None] # Indices of click for each plot self.curr_vert = 3*[0] # Current vertical line of sig/spec for each plot self.cursor = None self.f_cursor = None @@ -284,130 +274,132 @@ class Callback(object): self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20) plt.draw() - - def on_clicked(self, event): - if event.inaxes == self.fax: # Click on full signal plot - pos = int(FSSR*event.xdata) - mpos = np.argmax(self.song_resample[max(pos-int(FSSR*FSPK),0):pos+int(FSSR*FSPK)]) + max(pos-int(FSSR*FSPK),0) - if self.curr_ind[self.curr] is not None: - if np.all(pd.isnull(list(self.df[self.offset[self.curr_ind[self.curr], 0]].values()))): - self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1] - else: - self.scat._facecolors[self.curr_ind[self.curr]] = [0, 0, 0, 1] - if mpos/FSSR not in self.offset[:,0]: - self.offset = np.concatenate([self.offset, [[mpos/FSSR, self.song_resample[mpos]]]], axis=0) - self.scat.set_offsets(self.offset) - self.df[mpos/FSSR] = EMLN.copy() - c = [[0, 0, 0, 1]] - c[0][self.curr] = 1 - if len(self.offset) == 1: - self.scat.set_color(c) - else: - self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0)) - self.curr_ind[self.curr] = len(self.offset) - 1 - self.curr_vert[self.curr] = 0 - self.cursor[self.curr].linev[0].set_linestyle('--') - self.cursor[self.curr].linev[1].set_linestyle('--') - for i in range(4): - if i < 2: - self.view_data[self.curr][i][1][0].set_visible(False) - self.view_data[self.curr][i][1][1].set_visible(False) - else: - self.view_data[self.curr][i][1].set_visible(False) + def _on_clicked_fax(self, event): + # Click on full signal plot + pos = int(FSSR * event.xdata) + mpos = np.argmax(self.song_resample[max(pos - int(FSSR * FSPK), 0): + pos + int(FSSR * FSPK)]) + max(pos - int(FSSR * FSPK), 0) + ax_group = self.view_data[self.curr] + if self.curr_ind[self.curr] is not None: + if np.all(pd.isnull(list(self.df[self.offset[self.curr_ind[self.curr], 0]].values()))): + self.scat._facecolors[self.curr_ind[self.curr]] = [0.7, 0.2, 0.5, 1] else: - self.curr_ind[self.curr] = np.argmax(mpos/FSSR == self.offset[:,0]) - c = [0, 0, 0, 1] - k = 0 - for i, v in enumerate(self.curr_ind): - if v == self.curr_ind[self.curr]: - c[i] = 1 - k += 1 - for i in range(3): - c[i] /=k - self.scat._facecolors[self.curr_ind[self.curr]] = c - self.scat.set_color(self.scat._facecolors) - row = self.df[mpos/FSSR] - if np.isnan(row['p1_pos']): - self.view_data[self.curr][0][1][0].set_visible(False) - self.view_data[self.curr][0][1][1].set_visible(False) - self.view_data[self.curr][1][1][0].set_visible(False) - self.view_data[self.curr][1][1][1].set_visible(False) - else: - self.view_data[self.curr][0][1][0].set_xdata((row['p1_pos'], row['p1_pos'])) - self.view_data[self.curr][1][1][0].set_xdata((row['p1_pos']*1e-3, row['p1_pos']*1e-3)) - self.view_data[self.curr][0][1][0].set_visible(True) - self.view_data[self.curr][1][1][0].set_visible(True) - if np.isnan(row['ipi_sig']): - self.view_data[self.curr][0][1][1].set_visible(False) - self.view_data[self.curr][1][1][1].set_visible(False) - else: - p2 = row['p1_pos'] + row['ipi_sig'] - self.view_data[self.curr][0][1][1].set_xdata((p2, p2)) - self.view_data[self.curr][1][1][1].set_xdata((p2*1e-3, p2*1e-3)) - self.view_data[self.curr][0][1][1].set_visible(True) - self.view_data[self.curr][1][1][1].set_visible(True) - if np.isnan(row['ipi_corr_man']): - self.view_data[self.curr][2][1].set_visible(False) + self.scat._facecolors[self.curr_ind[self.curr]] = [0, 0, 0, 1] + if mpos / FSSR not in self.offset[:, 0]: + self.offset = np.concatenate([self.offset, [[mpos / FSSR, self.song_resample[mpos]]]], axis=0) + self.scat.set_offsets(self.offset) + self.df[mpos / FSSR] = EMLN.copy() + c = [[0, 0, 0, 1]] + c[0][self.curr] = 1 + if len(self.offset) == 1: + self.scat.set_color(c) + else: + self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0)) + self.curr_ind[self.curr] = len(self.offset) - 1 + self.curr_vert[self.curr] = 0 + self.cursor[self.curr].linev[0].set_linestyle('--') + self.cursor[self.curr].linev[1].set_linestyle('--') + for i in range(4): + if i < 2: + ax_group[i][1][0].set_visible(False) + ax_group[i][1][1].set_visible(False) else: - self.view_data[self.curr][2][1].set_xdata((row['ipi_corr_man'], row['ipi_corr_man'])) - self.view_data[self.curr][2][1].set_visible(True) - if np.isnan(row['ipi_ceps_man']): - self.view_data[self.curr][3][1].set_visible(False) + ax_group[i][1].set_visible(False) + else: + self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0]) + c = [0, 0, 0, 1] + k = 0 + for i, v in enumerate(self.curr_ind): + if v == self.curr_ind[self.curr]: + c[i] = 1 + k += 1 + for i in range(self.num_view): + c[i] /= k + self.scat._facecolors[self.curr_ind[self.curr]] = c + self.scat.set_color(self.scat._facecolors) + row = self.df[mpos / FSSR] + if np.isnan(row['p1_pos']): + ax_group.signal.set_visible(False) + ax_group.spectrogram.set_visible(False) + else: + ax_group.signal.cursors[0].set_xdata((row['p1_pos'], row['p1_pos'])) + ax_group.spectrogram.cursors[0].set_xdata((row['p1_pos'] * 1e-3, row['p1_pos'] * 1e-3)) + ax_group.signal.cursors[0].set_visible(True) + ax_group.spectrogram.cursors[0].set_visible(True) + if np.isnan(row['ipi_sig']): + ax_group.signal.cursors[1].set_visible(False) + ax_group.spectrogram.cursors[1].set_visible(False) else: - self.view_data[self.curr][3][1].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man'])) - self.view_data[self.curr][3][1].set_visible(True) - self.ind_select = False - self.ind_b.color = '0.85' - self.ind_b.howercolor = '0.95' - self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos/FSSR]["ind_number"]}') - click = self.song[max(int(mpos*self.sr/FSSR-10e-3*self.sr),0):int(mpos*self.sr/FSSR+10e-3*self.sr)] - if len(click) != 2*int(10e-3*self.sr): - np.pad(click, (0, 2*int(10e-3*self.sr) - len(click)), mode='constant') - self.view_data[self.curr][0][0].set_ydata(norm(click)) - spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1)[0])) - self.view_data[self.curr][1][0].set_data(spec) - self.view_data[self.curr][1][0].set_clim(spec.max()-SPSC, spec.max()) - self.view_data[self.curr][2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3*self.sr):])) - self.view_data[self.curr][3][0].set_ydata(norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3*self.sr)]))) - self._set_label() - plt.draw() - return - for i in range(3): # Look if a click plot was clicked and which one - for j in range(4): - if event.inaxes == self.view_ax[i][j]: - break + p2 = row['p1_pos'] + row['ipi_sig'] + ax_group.signal.cursors[1].set_xdata((p2, p2)) + ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3)) + ax_group.signal.cursors[1].set_visible(True) + ax_group.spectrogram.cursors[1].set_visible(True) + if np.isnan(row['ipi_corr_man']): + ax_group.correlation.set_visible(False) else: - continue - break - else: - return - if j < 2: - pos = event.xdata * 10**(3*j) - self.view_data[i][0][1][self.curr_vert[i]].set_xdata((pos, pos)) - self.view_data[i][1][1][self.curr_vert[i]].set_xdata((pos*10**-3, pos*10**-3)) - self.view_data[i][0][1][self.curr_vert[i]].set_visible(True) - self.view_data[i][1][1][self.curr_vert[i]].set_visible(True) - if self.curr_vert[i] == 0: - self.df[self.offset[self.curr_ind[i],0]]['p1_pos'] = pos - self.curr_vert[i] ^= 1 - self.cursor[i].linev[0].set_linestyle(['--','-'][self.curr_vert[i]]) - self.cursor[i].linev[1].set_linestyle(['--','-'][self.curr_vert[i]]) - if self.view_data[i][0][1][1].get_visible(): - ipi_man = self.view_data[i][0][1][1].get_xdata()[0] - self.view_data[i][0][1][0].get_xdata()[0] - self.df[self.offset[self.curr_ind[i], 0]]['ipi_sig'] = ipi_man - self.view_ax[i][0].set_xlabel(f'Sig man:{ipi_man:.5f}') + ax_group.correlation.cursors[0].set_xdata((row['ipi_corr_man'], row['ipi_corr_man'])) + ax_group.correlation.set_visible(True) + if np.isnan(row['ipi_ceps_man']): + ax_group.cepstrum.set_visible(False) + else: + ax_group.cepstrum.cursors[0].set_xdata((row['ipi_ceps_man'], row['ipi_ceps_man'])) + ax_group.cepstrum.set_visible(True) + self.ind_select = False + self.ind_b.color = '0.85' + self.ind_b.howercolor = '0.95' + self.ind_b.label.set_text(f'Current individual:\n{self.df[mpos / FSSR]["ind_number"]}') + click = self.song[ + max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)] + if len(click) != 2 * int(10e-3 * self.sr): + np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant') + ax_group[0][0].set_ydata(norm(click)) + self._update_spectrogram(ax_group) + ax_group[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):])) + ax_group[3][0].set_ydata( + norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)]))) + self._set_label() + plt.draw() + + def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num): + if cell_num < 2: + pos = event.xdata if cell_num else event.xdata * 1000 + cur_type = self.curr_vert[group_num] + ax_group.signal.cursors[cur_type].set_xdata((pos, pos)) + ax_group.spectrogram.cursors[cur_type].set_xdata((pos*10**-3, pos*10**-3)) + ax_group.signal.cursor[cur_type].set_visible(True) + ax_group.spectrogram.cursors[cur_type].set_visible(True) + if cur_type == 0: + self.df[self.offset[self.curr_ind[group_num], 0]]['p1_pos'] = pos + self.curr_vert[group_num] ^= 1 + self.cursor[group_num].linev[0].set_linestyle(['--', '-'][cur_type]) + self.cursor[group_num].linev[1].set_linestyle(['--', '-'][cur_type]) + if ax_group.signal.cursors[1].get_visible(): + ipi_man = ax_group.signal_ipi + self.df[self.offset[self.curr_ind[group_num], 0]]['ipi_sig'] = ipi_man + ax_group[0].set_xlabel(f'Sig man:{ipi_man:.5f}') else: - self.view_data[i][j][1].set_xdata((event.xdata, event.xdata)) - self.view_data[i][j][1].set_visible(True) - ipi_auto = np.argmax(self.view_data[i][j][0].get_data()[1][ - max(int(self.sr/1e3*(event.xdata-IPIPK)),0):int(self.sr/1e3*(event.xdata+IPIPK))])\ - + max(int(self.sr/1e3*(event.xdata-IPIPK)),0) - col = 'ipi_' + ('corr' if j == 2 else 'ceps') - self.df[self.offset[self.curr_ind[i],0]][col + '_man'] = event.xdata - self.df[self.offset[self.curr_ind[i],0]][col + '_auto'] = ipi_auto*1e3/self.sr - self.view_ax[i][j].set_xlabel(f'{"Corr" if j == 2 else "Ceps"} man:{event.xdata:.3f} auto:{ipi_auto*1e3/self.sr:.3f}') + cell = ax_group[cell_num] + cell.cursor.set_xdata((event.xdata, event.xdata)) + cell.cursor.set_visible(True) + lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0) + ipi_auto = np.argmax(cell.axes.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min + col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps') + self.df[self.offset[self.curr_ind[group_num], 0]][col + '_man'] = event.xdata + self.df[self.offset[self.curr_ind[group_num], 0]][col + '_auto'] = ipi_auto*1e3/self.sr + cell.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} ' + f'auto:{ipi_auto*1e3/self.sr:.3f}') plt.draw() + def on_clicked(self, event): + if event.inaxes == self.fax: + self._on_clicked_fax(event) + for i, ax_group in enumerate(self.view_data): # Look if a click plot was clicked and which one + for j in range(len(ax_group)): + if event.inaxes == ax_group[j].axes: + self._on_clicked_ax_group(event, ax_group, i, j) + return + def change_curr(self, label): self.curr = int(label[-1]) self.f_cursor.linev.set_color('rgb'[self.curr]) @@ -428,7 +420,7 @@ class Callback(object): def key_pressed(self, event): if self.ind_select: - row = self.df[self.offset[self.curr_ind[self.curr],0]] + row = self.df[self.offset[self.curr_ind[self.curr], 0]] if event.key == 'backspace': row['ind_number'] = row['ind_number'][:-1] elif event.key in ['shift', 'control', 'alt']: @@ -441,8 +433,6 @@ class Callback(object): if event.key in '012': self.change_curr(event.key) self.r_button.set_active(int(event.key)) - pass - def play(self, event): sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16) @@ -458,34 +448,40 @@ class Callback(object): self.fax.get_figure().set_constrained_layout(False) def onaxis(self, event): - if self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] == -1: - self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] = 1 + row = self.df[self.offset[self.curr_ind[self.curr], 0]] + if row['onaxis'] == -1: + row['onaxis'] = 1 else: - self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] ^= 1 + row['onaxis'] ^= 1 self.onaxis_b.label.set_text('On-axis' - if self.df[self.offset[self.curr_ind[self.curr], 0]]['onaxis'] else 'Off-axis') + if row['onaxis'] else 'Off-axis') plt.draw() def increase_freq(self, event): - self.view_ax[self.curr][1].set_ylim(0, self.view_ax[self.curr][1].get_ylim()[1]+1e3) + lim = self.view_data[0].specgram.get_ylim()[1] + 1e3 + for ax_group in self.view_data: + ax_group.specgram.set_ylim(0, lim) plt.draw() def decrease_freq(self, event): - self.view_ax[self.curr][1].set_ylim(0, max(self.view_ax[self.curr][1].get_ylim()[1]-1e3,1e3)) + lim = max(self.view_data[0].specgram.get_ylim()[1] + 1e3, 1e3) + for ax_group in self.view_data: + ax_group.specgram.set_ylim(0, lim) plt.draw() - def _update_spectrogram(self, num): - click = self.view_data[num][0][0].get_ydata() - spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1)[0])) - self.view_data[num][1][0].set_data(spec) - self.view_data[num][1][0].set_clim(spec.max()-SPSC, spec.max()) + def _update_spectrogram(self, ax_group): + click = ax_group.signal.axes.get_ydata() + spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1, + pad_to=2*self.nfft)[0])) + ax_group.spectrogram.im.set_data(spec) + ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max()) def increase_res(self, event): if self.nfft > int(10e-3*self.sr): return self.nfft *= 2 - for i in range(3): - self._update_spectrogram(i) + for ax_group in self.view_data: + self._update_spectrogram(ax_group) if self.nfft > int(10e-3*self.sr): self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher') else: @@ -494,13 +490,13 @@ class Callback(object): plt.draw() def decrease_res(self, event): - if self.nfft <8: + if self.nfft < 8: return self.nfft //= 2 - for i in range(3): - self._update_spectrogram(i) + for ax_group in self.view_data: + self._update_spectrogram(ax_group) self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins') - if self.nfft <8: + if self.nfft < 8: self.spec_b['minus_res'].label.set_text('Can\'t go\nlower') else: self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins') @@ -511,9 +507,10 @@ class Callback(object): ind = self.curr if dic is None: dic = self.df[self.offset[self.curr_ind[ind], 0]] - self.view_ax[ind][0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}') - self.view_ax[ind][2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}') - self.view_ax[ind][3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}') + ax_group = self.view_data[ind] + ax_group[0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}') + ax_group[2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}') + ax_group[3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}') self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}') if self.onaxis_b is not None: self.onaxis_b.label.set_text(['Off', 'On', '?'][dic['onaxis']] + '-axis') @@ -521,12 +518,7 @@ class Callback(object): def _set_visible(self, ind=None, state=False): if ind is None: ind = self.curr - self.view_data[ind][0][1][0].set_visible(state) - self.view_data[ind][0][1][1].set_visible(state) - self.view_data[ind][1][1][0].set_visible(state) - self.view_data[ind][1][1][1].set_visible(state) - self.view_data[ind][2][1].set_visible(state) - self.view_data[ind][3][1].set_visible(state) + self.view_data[ind].set_visible(state) def reset_curr(self, event): self.df[self.offset[self.curr_ind[self.curr], 0]] = EMLN.copy() @@ -549,17 +541,17 @@ class Callback(object): self.offset = np.zeros((0, 2)) self.scat.set_offsets(self.offset) self.scat.set_color([[0, 0, 0, 1]]) - self.curr_ind = 3 * [None] # Ind of click for each plot - self.curr_vert = 3 * [0] # Current vertical line of sig/spec for each plot - for i in range(3): - self.view_data[i][0][0].set_ydata(np.zeros(int(20e-3 * sr))) - self.view_data[i][2][0].set_ydata(np.zeros(int(10e-3 * sr))) - self.view_data[i][3][0].set_ydata(np.zeros(int(10e-3 * sr))) + self.curr_ind = self.num_view * [None] # Ind of click for each plot + self.curr_vert = self.num_view * [0] # Current vertical line of sig/spec for each plot + for ax_group in self.view_data: + ax_group[0][0].set_ydata(np.zeros(int(20e-3 * sr))) + ax_group[2][0].set_ydata(np.zeros(int(10e-3 * sr))) + ax_group[3][0].set_ydata(np.zeros(int(10e-3 * sr))) if sr_update: - self.view_data[i][0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False)) - self.view_data[i][2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) - self.view_data[i][3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) - self.view_data[i][1][0].set_clim(2000,2100) + ax_group[0][0].set_xdata(np.linspace(0, 20, int(20e-3*sr), False)) + ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) + ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False)) + ax_group[1][0].set_clim(2000, 2100) for i in range(3): self._set_label(i, EMLN) self._set_visible(i) @@ -570,12 +562,79 @@ class Callback(object): plt.draw() +class AxesWithCursor: + __slots__ = ['axes', 'im', 'cursors'] + + def __init__(self, axes: plt.Axes, im=None, cursors=None): + self.axes = axes + self.im = im + self.cursors = list() if cursors is None else cursors + + def set_visible(self, state): + for cursor in self.cursors: + cursor.set_visible(state) + + +class AxesGroup: + __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum'] + + def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr): + self.signal = AxesWithCursor(ax_sig, + ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0], + (ax_sig.axvline(10, c='k', linestyle='--'), ax_sig.axvline(10, c='k'))) + ax_sig.set_xlim(0, 20) + ax_sig.set_xlabel('IPI man:None') + ax_sig.set_ylim(-1, 1) + ax_sig.set_yticks([]) + self.spectrogram = AxesWithCursor(ax_spec, + ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)), + Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1], + (ax_spec.axvline(0.01, c='k', linestyle='--'), ax_spec.axvline(0.01, c='k'))) + self.spectrogram.im.set_clim(2000, 2100) + ax_spec.set_ylim(0, min(20e3, sr/2)) + ax_spec.set_ytickslabels((ax_spec.get_yticks() / 1e3).astype(int)) + self.correlation = AxesWithCursor(ax_corr, + ax_corr.plot(np.linspace(0, 10, int(10e-3 * sr), False), + np.zeros(int(10e-3 * sr)))[0], + (ax_corr.axvline(10, c='k'),)) + ax_corr.set_xlim(0, 10) + ax_corr.set_xlabel('IPI man:None auto:None') + ax_corr.set_ylim(-1, 1) + ax_corr.set_yticks([]) + + self.cepstrum = AxesWithCursor(ax_cep, + ax_cep.plot(np.linspace(0, 10, int(10e-3 * sr), False), + np.zeros(int(10e-3 * sr))), + (ax_cep.axvline(10, c='k'),)) + ax_cep.set_xlim(0, 10) + ax_cep.set_xlabel('IPI man:None auto:None') + ax_cep.set_ylim(-1, 1) + ax_cep.set_yticks([]) + self.set_visible(False) + + def __getitem__(self, item): + return (self.signal, self.spectrogram, self.correlation, self.cepstrum)[item] + + def __len__(self): + return 4 + + def set_visible(self, state): + self.signal.set_visible(state) + self.spectrogram.set_visible(state) + self.correlation.set_visible(state) + self.cepstrum.set_visible(state) + + @property + def signal_ipi(self): + return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0] + + def init(in_path, channel, low=2e3, high=20e3): song, sr, song_resample = load_file(in_path, channel, low, high) fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True) gs = fig.add_gridspec(12, 20) - full_sig = plt.subplot(gs[:2, 1:-1]) + full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1]) # Type annotation needed with plt 3.1 callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0], song, song_resample, sr, full_sig) callback.fax = full_sig @@ -610,19 +669,18 @@ def init(in_path, channel, low=2e3, high=20e3): vs = 2 hfs = 3 hs = 1 - ax_view = [[plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]), - plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs])], - [plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]), - plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs])], - [plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]), - plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]), - plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs])]] - callback.view_ax = ax_view + ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]), + plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr), + AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]), + plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr), + AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]), + plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]), + plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr)] play_b_ax = plt.subplot(gs[-1, :2]) play_b = Button(play_b_ax, 'Play\ncurrent segment') @@ -659,60 +717,16 @@ def init(in_path, channel, low=2e3, high=20e3): spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b} callback.spec_b = spec_button - data_view = [[2 * [None] for _ in range(4)] for _ in range(3)] - m_cursor = [None for _ in range(3)] - # m_cursor2 = [[None for _ in range(4)] for _ in range(3)] + m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None] callback.cursor = m_cursor - for i in range(3): - data_view[i][0][1] = (ax_view[i][0].axvline(10, c='k', linestyle='--'), ax_view[i][0].axvline(10, c='k')) - data_view[i][0][0] = ax_view[i][0].plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0] - ax_view[i][0].set_xlim(0, 20) - ax_view[i][0].set_xlabel('IPI man:None') - ax_view[i][0].set_ylim(-1, 1) - data_view[i][0][1][0].set_visible(False) - data_view[i][0][1][1].set_visible(False) - - data_view[i][1][0] = ax_view[i][1].specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)), - Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1] - data_view[i][1][0].set_clim(2000,2100) - data_view[i][1][1] = (ax_view[i][1].axvline(0.01, c='k', linestyle='--'), ax_view[i][1].axvline(0.01, c='k')) - data_view[i][1][1][0].set_visible(False) - data_view[i][1][1][1].set_visible(False) - - data_view[i][2][1] = ax_view[i][2].axvline(10, c='k') - data_view[i][2][0] = ax_view[i][2].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[ - 0] - ax_view[i][2].set_xlim(0, 10) - ax_view[i][2].set_xlabel('IPI man:None auto:None') - ax_view[i][2].set_ylim(-1, 1) - data_view[i][2][1].set_visible(False) - - data_view[i][3][1] = ax_view[i][3].axvline(10, c='k') - data_view[i][3][0] = ax_view[i][3].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[ - 0] - ax_view[i][3].set_xlim(0, 10) - ax_view[i][3].set_xlabel('IPI man:None auto:None') - ax_view[i][3].set_ylim(0, 1) - data_view[i][3][1].set_visible(False) - - for j in range(4): - # m_cursor2[i][j] = Cursor(ax_view[i][j], horizOn=False, useblit=True, c='k') - if j != 1: - ax_view[i][j].set_yticks([]) - else: - ax_view[i][j].set_ylim(0, min(20e3, sr/2)) - ax_view[i][j].set_yticks(ax_view[i][j].get_yticks()) - ax_view[i][j].set_yticklabels((ax_view[i][j].get_yticks() / 1e3).astype(int)) - # m_cursor2[i][0].linev.set_linestyle('--') - # m_cursor2[i][1].linev.set_linestyle('--') - m_cursor[i] = MyMultiCursor(ax_view[i], data_view[i][0][1][0], horizOn=False, useblit=True, c='k') + for i in range(len(ax_group)): + m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], horizOn=False, useblit=True, c='k') m_cursor[i].linev[0].set_linestyle('--') m_cursor[i].linev[1].set_linestyle('--') - callback.view_data = data_view + callback.view_data = ax_group return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons': - {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button, - 'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b , - 'spec_button': spec_button}} # Needed to keep the callbacks alive + {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button, + 'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b, 'spec_button': spec_button}} def reset(callback, in_path, channel, low=2e3, high=20e3): @@ -734,7 +748,7 @@ def main(args): return 1 EMLN['onaxis'] = -1 - ref_dict = init(args.input, args.channel) + ref_dict = init(args.input, args.channel, args.low, args.up) if args.resume: df = pd.read_hdf(outpath) if 'onaxis' not in df.columns: @@ -774,6 +788,8 @@ if __name__ == '__main__': parser.add_argument("input", type=str, help="Input file") parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'") parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.") + parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter') + parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter') group = parser.add_mutually_exclusive_group() group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given," " the computation will be halted") @@ -800,6 +816,8 @@ if __name__ == '__main__': self.input = input("What is the input file path? ") self.out = input("What is the out file path? (Leave empty for default)") self.channel = int(input("Which channel do you want to use starting from 0? ")) + self.low = float(input("What is the lower frequency of the bandpass? ")) + self.up = float(input("What is the upper frequency of the bandpass? ")) self.erase = ask("Do you want to erase the out file if it already exist?") if not self.erase: self.resume = ask("Do you want to resume the out file if it already exist?")