Skip to content
Snippets Groups Projects
Commit cd277078 authored by maxence's avatar maxence
Browse files

Add number of pulses argument

parent 05fea75c
No related branches found
No related tags found
No related merge requests found
......@@ -123,36 +123,41 @@ class MyRadioButtons(RadioButtons):
class MyMultiCursor(AxesWidget):
def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False,
def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False,
**lineprops):
AxesWidget.__init__(self, axes[0])
self.axes = axes
self.n_axes = len(axes)
self.p1 = p1
self.num_cur = num_cur
self.connect_event('motion_notify_event', self.onmove)
self.connect_event('draw_event', self.clear)
self.visible = True
self.horizOn = horizOn
self.vertOn = vertOn
self.useblit = useblit and self.canvas.supports_blit
if self.useblit:
lineprops['animated'] = True
self.linev = []
for i in range(self.n_axes):
self.linev.append(axes[i].axes.axvline(axes[i].axes.get_xbound()[0], visible=False, **lineprops))
for ax in self.axes:
for line in ax.cursors:
if line.get_linestyle() == '--':
continue
self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
self.linev[0].set_linestyle('--')
self.linev[self.num_cur].set_linestyle('--')
self.n_axes = len(self.linev)
self.background = self.n_axes*[None]
self.background = len(self.axes)*[None]
self.needclear = False
def clear(self, event):
"""Internal event handler to clear the cursor."""
if self.ignore(event):
return
for i in range(self.n_axes):
for i, ax in enumerate(self.axes):
if self.useblit:
self.background[i] = self.canvas.copy_from_bbox(self.axes[i].axes.bbox)
self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
self.linev[i].set_visible(False)
def onmove(self, event):
......@@ -181,18 +186,18 @@ class MyMultiCursor(AxesWidget):
pos = event.xdata
if self.linev[0].get_linestyle() == '--':
self.linev[0].set_xdata((pos, pos))
self.linev[1].set_xdata((pos / 1e3, pos / 1e3))
self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3))
self.linev[0].set_visible(self.visible and self.vertOn)
self.linev[1].set_visible(self.visible and self.vertOn)
self.linev[self.num_cur].set_visible(self.visible and self.vertOn)
self._update()
return
else:
ipi_pos = abs(pos - self.p1.get_xdata()[0])
for i in range(self.n_axes):
if i == 0:
self.linev[i].set_xdata((pos, pos))
elif i == 1:
self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
if i < self.num_cur:
self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
elif i < 2*self.num_cur:
self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
else:
self.linev[i].set_xdata((ipi_pos, ipi_pos))
self.linev[i].set_visible(self.visible and self.vertOn)
......@@ -202,33 +207,36 @@ class MyMultiCursor(AxesWidget):
ipi_pos = event.xdata
pos = self.p1.get_xdata()[0] + ipi_pos
restore = False
for i in range(self.n_axes):
if self.p1.get_visible():
if i == 0:
if self.linev[0].get_linestyle() == '--':
restore = True
self.linev[0].set_linestyle('-')
self.linev[1].set_linestyle('-')
self.linev[i].set_xdata((pos, pos))
elif i == 1:
self.linev[i].set_xdata((pos/1e3, pos/1e3))
self.linev[self.num_cur].set_linestyle('-')
for i in range(self.n_axes):
if self.p1.get_visible():
if i < self.num_cur:
self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
elif i < 2 * self.num_cur:
self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
self.linev[i].set_visible(self.visible and self.vertOn)
if i > 1:
if i >= 2 * self.num_cur:
self.linev[i].set_xdata((ipi_pos, ipi_pos))
self.linev[i].set_visible(self.visible and self.vertOn)
self._update()
if restore:
self.linev[0].set_linestyle('--')
self.linev[1].set_linestyle('--')
self.linev[self.num_cur].set_linestyle('--')
def _update(self):
if self.useblit:
for i in range(self.n_axes):
if self.background[i] is not None:
self.canvas.restore_region(self.background[i])
self.axes[i].axes.draw_artist(self.linev[i])
self.canvas.blit(self.axes[i].axes.bbox)
k = 0
for bk, ax in zip(self.background, self.axes):
if bk is not None:
self.canvas.restore_region(bk)
for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1):
ax.axes.draw_artist(self.linev[k])
k += 1
self.canvas.blit(ax.axes.bbox)
else:
self.canvas.draw_idle()
return False
......@@ -355,10 +363,10 @@ class Callback(object):
max(int(mpos * self.sr / FSSR - 10e-3 * self.sr), 0):int(mpos * self.sr / FSSR + 10e-3 * self.sr)]
if len(click) != 2 * int(10e-3 * self.sr):
np.pad(click, (0, 2 * int(10e-3 * self.sr) - len(click)), mode='constant')
ax_group[0][0].set_ydata(norm(click))
ax_group.signal.im.set_ydata(norm(click))
self._update_spectrogram(ax_group)
ax_group[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
ax_group[3][0].set_ydata(
ax_group.correlation.im.set_ydata(norm(np.correlate(click, click, 'same')[-int(10e-3 * self.sr):]))
ax_group.cepstrum.im.set_ydata(
norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:int(10e-3 * self.sr)])))
self._set_label()
plt.draw()
......@@ -472,7 +480,7 @@ class Callback(object):
plt.draw()
def _update_spectrogram(self, ax_group):
click = ax_group.signal.axes.get_ydata()
click = ax_group.signal.im.get_ydata()
spec = np.flipud(20*np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1,
pad_to=2*self.nfft)[0]))
ax_group.spectrogram.im.set_data(spec)
......@@ -510,9 +518,9 @@ class Callback(object):
if dic is None:
dic = self.df[self.offset[self.curr_ind[ind], 0]]
ax_group = self.view_data[ind]
ax_group[0].set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
ax_group[2].set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
ax_group[3].set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
if self.onaxis_b is not None:
self.onaxis_b.label.set_text(['Off', 'On', '?'][dic['onaxis']] + '-axis')
......@@ -584,10 +592,11 @@ class AxesWithCursor:
class AxesGroup:
__slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']
def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr):
def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2):
self.signal = AxesWithCursor(ax_sig,
ax_sig.plot(np.linspace(0, 20, int(20e-3*sr), False), np.zeros(int(20e-3*sr)))[0],
(ax_sig.axvline(10, c='k', linestyle='--'), ax_sig.axvline(10, c='k')))
(ax_sig.axvline(10, c='k', linestyle='--'),) +
tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur)))
ax_sig.set_xlim(0, 20)
ax_sig.set_xlabel('IPI man:None')
ax_sig.set_ylim(-1, 1)
......@@ -595,7 +604,8 @@ class AxesGroup:
self.spectrogram = AxesWithCursor(ax_spec,
ax_spec.specgram(np.random.normal(0, 1e-6, int(20e-3 * sr)),
Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1],
(ax_spec.axvline(0.01, c='k', linestyle='--'), ax_spec.axvline(0.01, c='k')))
(ax_spec.axvline(0.01, c='k', linestyle='--'),) +
tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur)))
self.spectrogram.im.set_clim(2000, 2100)
ax_spec.set_ylim(0, min(20e3, sr/2))
ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))
......@@ -610,11 +620,11 @@ class AxesGroup:
self.cepstrum = AxesWithCursor(ax_cep,
ax_cep.plot(np.linspace(0, 10, int(10e-3 * sr), False),
np.zeros(int(10e-3 * sr))),
np.zeros(int(10e-3 * sr)))[0],
(ax_cep.axvline(10, c='k'),))
ax_cep.set_xlim(0, 10)
ax_cep.set_xlabel('IPI man:None auto:None')
ax_cep.set_ylim(-1, 1)
ax_cep.set_ylim(0, 0.33)
ax_cep.set_yticks([])
self.set_visible(False)
......@@ -641,7 +651,7 @@ class AxesGroup:
return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
def init(in_path, channel, low=2e3, high=20e3):
def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
song, sr, song_resample = load_file(in_path, channel, low, high)
fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
gs = fig.add_gridspec(12, 20)
......@@ -684,15 +694,15 @@ def init(in_path, channel, low=2e3, high=20e3):
ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr, num_cur),
AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr, num_cur),
AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr)]
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr, num_cur)]
play_b_ax = plt.subplot(gs[-1, :2])
play_b = Button(play_b_ax, 'Play\ncurrent segment')
......@@ -732,9 +742,7 @@ def init(in_path, channel, low=2e3, high=20e3):
m_cursor: List[Union[MyMultiCursor, None]] = len(ax_group) * [None]
callback.cursor = m_cursor
for i in range(len(ax_group)):
m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], horizOn=False, useblit=True, c='k')
m_cursor[i].linev[0].set_linestyle('--')
m_cursor[i].linev[1].set_linestyle('--')
m_cursor[i] = MyMultiCursor(ax_group[i], ax_group[i].signal.cursors[0], num_cur, useblit=True, c='k')
callback.view_data = ax_group
return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
{'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
......@@ -758,9 +766,12 @@ def main(args):
elif args.resume:
print(f'Out file {outpath} does not already exist and resume option is set.')
return 1
if args.num < 1:
print(f'{args.num} is an invalid number of pulses.')
return 1
EMLN['onaxis'] = -1
ref_dict = init(args.input, args.channel, args.low, args.up)
ref_dict = init(args.input, args.channel, args.low, args.up, args.num)
if args.resume:
df = pd.read_hdf(outpath)
if 'onaxis' not in df.columns:
......@@ -800,6 +811,7 @@ if __name__ == '__main__':
parser.add_argument("input", type=str, help="Input file")
parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
parser.add_argument("--num", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
group = parser.add_mutually_exclusive_group()
......@@ -828,6 +840,7 @@ if __name__ == '__main__':
self.input = input("What is the input file path? ")
self.out = input("What is the out file path? (Leave empty for default)")
self.channel = int(input("Which channel do you want to use starting from 0? "))
self.num = int(input("How many pulse do you want to display (> 1) "))
self.low = float(input("What is the lower frequency of the bandpass? "))
self.up = float(input("What is the upper frequency of the bandpass? "))
self.erase = ask("Do you want to erase the out file if it already exist?")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment