# ipi_extract.py
import argparse
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sg
import soundfile as sf
import os
from matplotlib.widgets import Button, Cursor, MultiCursor, CheckButtons, RadioButtons, AxesWidget, TextBox
from matplotlib.patches import Circle
from fractions import Fraction
from pydub import AudioSegment
from pydub.playback import play
import pandas as pd
# Module-level tuning constants.
FSSR = 48_000  # Sampling rate (Hz) the filtered signal is resampled to for the full-signal plot
FSPK = 0.1  # Half-width (s) of the window searched for the click maximum around a click on the full-signal plot
IPIPK= 0.15  # Half-width (ms) of the lag window searched for the automatic IPI peak
SPSC = 80  # Dynamic range (dB) of the spectrogram colour scale, measured down from its maximum
def read(file_path, always_2d=True):
    """Read an audio file and return ``(samples, sample_rate)``.

    soundfile is tried first. If it raises anything, the file is decoded with pydub
    through ``load_anysound`` instead; that fallback ignores *always_2d* and always
    returns a 2-D ``(n_frames, n_channels)`` array.
    """
    try:
        result = sf.read(file_path, always_2d=always_2d)
    except Exception:
        # Best-effort fallback for formats libsndfile cannot open.
        result = load_anysound(file_path)
    return result
def load_anysound(file_path):
    """Decode *file_path* with pydub and return ``(samples, frame_rate)``.

    The interleaved samples from pydub are reshaped to ``(n_frames, n_channels)``.
    NOTE(review): the values are pydub's raw sample values, not scaled to [-1, 1]
    like soundfile's float output — confirm downstream code does not rely on the scale.
    """
    segment = AudioSegment.from_file(file_path)
    samples = np.array(segment.get_array_of_samples())
    return samples.reshape(-1, segment.channels), segment.frame_rate
def norm(x):
    """Scale *x* by its peak absolute value (a tiny epsilon avoids division by zero)."""
    peak = np.max(np.abs(x))
    return x / (peak + 1e-10)
def norm_std(x, alpha=1.5):
    """Scale *x* by ``alpha`` times its standard deviation.

    Parameters
    ----------
    x : array_like
        Values to scale.
    alpha : float
        Multiplier applied to ``np.std(x)``. It was previously ignored (1.5 was
        hard-coded), so the default keeps the old behaviour.
    """
    # The epsilon avoids division by zero for constant input.
    return x / (alpha * np.std(x) + 1e-10)
class MyRadioButtons(RadioButtons):
    """Radio buttons drawn as legend entries, with optional horizontal layout and per-button colors.

    ``RadioButtons.__init__`` is deliberately not called. The circles and the callback
    registry are built here, so ``set_active``, ``on_clicked`` and ``disconnect`` are
    implemented locally instead of relying on base-class private attributes, which
    differ between Matplotlib versions.
    """

    def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                 orientation="vertical", **kwargs):
        """
        Add radio buttons to an `~.axes.Axes`.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to add the buttons to.
        labels : list of str
            The button labels.
        active : int
            The index of the initially selected button.
        activecolor : color or sequence of colors
            The color of the selected button. If this is not a single color, it is
            indexed by the position of the selected button.
        size : float
            Size of the radio buttons.
        orientation : str
            The orientation of the buttons: 'vertical' (default), or 'horizontal'.

        Further parameters are passed on to `Legend`.
        """
        AxesWidget.__init__(self, ax)
        self._activecolor = activecolor
        self._active_index = None
        self._click_observers = {}
        self._next_cid = 0
        axcolor = ax.get_facecolor()
        self.value_selected = None
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_navigate(False)
        circles = []
        for i, label in enumerate(labels):
            if i == active:
                self.value_selected = label
                self._active_index = i
                facecolor = self.activecolor
            else:
                facecolor = axcolor
            # Empty scatters only serve as legend handles.
            circles.append(ax.scatter([], [], s=size, marker="o", edgecolor='black',
                                      facecolor=facecolor))
        if orientation == "horizontal":
            kwargs.update(ncol=len(labels), mode="expand")
        kwargs.setdefault("frameon", False)
        self.box = ax.legend(circles, labels, loc="center", **kwargs)
        self.labels = self.box.texts
        # `legendHandles` was renamed `legend_handles` in Matplotlib 3.7; the old name
        # was deprecated there and later removed.
        handles = getattr(self.box, 'legend_handles', None)
        self.circles = handles if handles is not None else self.box.legendHandles
        for c in self.circles:
            c.set_picker(5)
        self.connect_event('pick_event', self._clicked)

    def _clicked(self, event):
        """Select the button whose circle was picked with the left mouse button."""
        if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
            return
        if event.artist in self.circles:
            self.set_active(self.circles.index(event.artist))

    def set_active(self, index):
        """Select button *index*, recolor the circles, redraw, and notify the observers."""
        if index not in range(len(self.labels)):
            raise ValueError(f'Invalid RadioButton index: {index}')
        self.value_selected = self.labels[index].get_text()
        self._active_index = index
        axcolor = self.ax.get_facecolor()
        for i, circle in enumerate(self.circles):
            circle.set_facecolor(self.activecolor if i == index else axcolor)
        if self.drawon:
            self.ax.figure.canvas.draw_idle()
        if self.eventson:
            for func in list(self._click_observers.values()):
                func(self.value_selected)

    def on_clicked(self, func):
        """Call ``func(label)`` whenever a button is selected. Returns a connection id."""
        cid = self._next_cid
        self._click_observers[cid] = func
        self._next_cid += 1
        return cid

    def disconnect(self, cid):
        """Remove the observer registered with connection id *cid*."""
        self._click_observers.pop(cid, None)

    @property
    def activecolor(self):
        """Fill color of the selected button.

        A single color (e.g. 'blue' or an RGB tuple) is returned as is. Otherwise the
        value is treated as a per-button sequence and indexed by the selected button.
        Previously a plain string such as the default 'blue' was indexed character by
        character.
        """
        from matplotlib.colors import is_color_like
        if is_color_like(self._activecolor):
            return self._activecolor
        return self._activecolor[self._active_index]
class MyMultiCursor(AxesWidget):
    """Vertical cursor lines kept in sync across the axes of one click view.

    Only vertical lines are drawn, one per axis; horizontal lines are disabled.
    Coordinate conventions used by this class:
      * axes[0] and axes[1] share a time axis. axes[1] uses units 1e3 times larger than
        axes[0] (``init`` passes a waveform in ms and a spectrogram in s).
      * axes[2:] show a lag measured from the first x value of the line ``p1``.
    The linestyle of ``linev[0]`` doubles as a mode flag:
      * '--': moving over axes[0]/axes[1] only moves the time cursors.
        ``Callback`` sets '--' while P1 is the next marker to place.
      * '-': the lag ``|x - p1|`` is also mirrored on axes[2:].
    """

    def __init__(self, axes, p1, horizOn=True, vertOn=True, useblit=False,
                 **lineprops):
        """
        Parameters
        ----------
        axes : list of Axes
            The axes to draw cursors in. It must be a list, because ``onmove`` calls
            ``list.index``.
        p1 : Line2D
            Reference line; its first x value is the origin of the lag shown on axes[2:].
        horizOn : bool
            Stored but unused (horizontal lines are disabled).
        vertOn : bool
            Whether vertical lines are shown.
        useblit : bool
            Use blitting if the canvas supports it.
        **lineprops
            Passed to ``axvline``.
        """
        AxesWidget.__init__(self, axes[0])
        self.axes = axes
        self.n_axes = len(axes)
        self.p1 = p1
        self.connect_event('motion_notify_event', self.onmove)
        self.connect_event('draw_event', self.clear)
        self.visible = True
        self.horizOn = horizOn
        self.vertOn = vertOn
        self.useblit = useblit and self.canvas.supports_blit
        if self.useblit:
            lineprops['animated'] = True
        # One hidden vertical line per axes, moved around by onmove.
        self.linev = []
        for i in range(self.n_axes):
            self.linev.append(axes[i].axvline(axes[i].get_xbound()[0], visible=False, **lineprops))
        # Per-axes background saved on each draw, restored before blitting the lines.
        self.background = self.n_axes*[None]
        self.needclear = False

    def clear(self, event):
        """Internal draw-event handler: save each axes background (when blitting) and hide the cursors."""
        if self.ignore(event):
            return
        for i in range(self.n_axes):
            if self.useblit:
                self.background[i] = self.canvas.copy_from_bbox(self.axes[i].bbox)
            self.linev[i].set_visible(False)

    def onmove(self, event):
        """Internal event handler to draw the cursor when the mouse moves."""
        if self.ignore(event):
            return
        if not self.canvas.widgetlock.available(self):
            return
        if event.inaxes not in self.axes:
            # Mouse left this view: hide every line and redraw once to erase them.
            for i in range(self.n_axes):
                self.linev[i].set_visible(False)
            if self.needclear:
                self.canvas.draw()
                self.needclear = False
            return
        self.needclear = True
        if not self.visible:
            return
        ax_idx = self.axes.index(event.inaxes)
        if ax_idx < 2:
            # Mouse on one of the two time axes: express the position in axes[0] units.
            if ax_idx == 1:
                pos = event.xdata*1e3
            else:
                pos = event.xdata
            if self.linev[0].get_linestyle() == '--':
                # Dashed mode: only move the two time cursors.
                self.linev[0].set_xdata((pos, pos))
                self.linev[1].set_xdata((pos / 1e3, pos / 1e3))
                self.linev[0].set_visible(self.visible and self.vertOn)
                self.linev[1].set_visible(self.visible and self.vertOn)
                self._update()
                return
            else:
                # Solid mode: also show the lag from p1 on the lag axes.
                ipi_pos = abs(pos - self.p1.get_xdata()[0])
                for i in range(self.n_axes):
                    if i == 0:
                        self.linev[i].set_xdata((pos, pos))
                    elif i == 1:
                        self.linev[i].set_xdata((pos / 1e3, pos / 1e3))
                    else:
                        self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                self._update()
                return
        else:
            # Mouse on a lag axis: project the lag back to time p1 + lag on the time axes.
            ipi_pos = event.xdata
            pos = self.p1.get_xdata()[0] + ipi_pos
            restore = False
            for i in range(self.n_axes):
                if self.p1.get_visible():
                    if i == 0:
                        if self.linev[0].get_linestyle() == '--':
                            # Draw the projected position as a solid line for this frame only;
                            # the dashed style is restored after drawing.
                            restore = True
                            self.linev[0].set_linestyle('-')
                            self.linev[1].set_linestyle('-')
                        self.linev[i].set_xdata((pos, pos))
                    elif i == 1:
                        self.linev[i].set_xdata((pos/1e3, pos/1e3))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                if i > 1:
                    self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
            self._update()
            if restore:
                self.linev[0].set_linestyle('--')
                self.linev[1].set_linestyle('--')

    def _update(self):
        """Redraw the cursor lines: blit each axes when blitting, otherwise request an idle redraw."""
        if self.useblit:
            for i in range(self.n_axes):
                if self.background[i] is not None:
                    self.canvas.restore_region(self.background[i])
                self.axes[i].draw_artist(self.linev[i])
                self.canvas.blit(self.axes[i].bbox)
        else:
            self.canvas.draw_idle()
        return False
class Callback(object):
    """State and event handlers of the IPI annotation figure built by ``init``.

    The full-signal axes shows a 20 s window of ``song_resample`` (sampled at FSSR).
    Clicking it snaps to the largest sample within +/-FSPK s. A 20 ms excerpt of
    ``song`` around that point is then shown in the current view ``curr`` (one of 3).

    Each view has 4 axes:
      * waveform (ms)
      * spectrogram (s)
      * autocorrelation lag (ms)
      * cepstrum quefrency (ms)

    Clicking in a view records annotations in ``df``, a dict keyed by the click time
    in seconds.
    """

    def __init__(self, line, song, song_resample, sr, full_ax):
        self.p = 0  # start index (in song_resample samples) of the 20 s window shown
        self.df = dict()  # click time (s) -> annotation row
        self.line = line  # Line2D of the full-signal plot
        self.song = song  # filtered signal at the native rate sr
        self.song_resample = song_resample  # same signal resampled to FSSR
        self.sr = sr
        self.fax = full_ax
        self.view_ax = None  # 3 x 4 axes, set by init
        self.view_data = None  # artists of each view axes, set by init
        self.curr = 0  # current view selected
        self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
        self.offset = np.zeros((0, 2))  # (time, amplitude) of every selected click
        self.curr_ind = 3 * [None]  # Ind of click for each plot
        self.curr_vert = 3 * [0]  # Marker placed by the next signal click per view: 0 = P1, 1 = P2
        self.cursor = None  # one MyMultiCursor per view, set by init
        self.f_cursor = None  # cursor of the full-signal plot, set by init

    def shift_left(self, event):
        """Move the full-signal window 13 s back, stopping at the start of the file."""
        self.p = max(0, self.p - FSSR * 13)
        self._shift()

    def shift_right(self, event):
        """Move the full-signal window 13 s forward, keeping a full 20 s in view."""
        # The outer max() keeps p valid when the file is shorter than 20 s.
        self.p = max(0, min(len(self.song_resample) - FSSR * 20, self.p + FSSR * 13))
        self._shift()

    def _shift(self):
        """Redraw the full-signal plot for the window starting at ``self.p``."""
        segment = self.song_resample[self.p:self.p + FSSR * 20]
        start = self.p / FSSR
        # Size x to the segment so the data stay consistent even for short files.
        self.line.set_data(np.linspace(start, start + len(segment) / FSSR, len(segment), False), segment)
        lim = np.abs(segment).max() * 1.2
        self.fax.set_ylim(-lim, lim)
        self.fax.set_xlim(start, start + 20)
        plt.draw()

    def on_clicked(self, event):
        """Handle a mouse-button release anywhere on the figure."""
        if event.inaxes == self.fax:  # Click on full signal plot
            self._select_click(event)
            return
        for i in range(3):  # Look if a click plot was clicked and which one
            for j in range(4):
                if event.inaxes == self.view_ax[i][j]:
                    break
            else:
                continue
            break
        else:
            return
        if self.curr_ind[i] is None:
            # No click is selected in this view yet, so there is nothing to annotate
            # (previously this crashed on self.offset[None, 0]).
            return
        if j < 2:
            self._set_vertical(i, j, event.xdata)
        else:
            self._set_lag(i, j, event.xdata)
        plt.draw()

    def _select_click(self, event):
        """Snap a click on the full-signal plot to the local maximum and show it in the current view."""
        curr = self.curr
        pos = int(FSSR * event.xdata)
        start = max(pos - int(FSSR * FSPK), 0)
        mpos = np.argmax(self.song_resample[start:pos + int(FSSR * FSPK)]) + start
        key = mpos / FSSR
        if self.curr_ind[curr] is not None:
            # Restore the default marker color of the click previously shown in this view.
            self.scat._facecolors[self.curr_ind[curr]] = [0.7, 0.2, 0.5, 1]
        if key not in self.offset[:, 0]:
            # New click: add a marker and an empty annotation row.
            self.offset = np.concatenate([self.offset, [[key, self.song_resample[mpos]]]], axis=0)
            self.scat.set_offsets(self.offset)
            self.df[key] = {'p1_pos': np.nan, 'ipi_sig': np.nan,
                            'ipi_corr_man': np.nan, 'ipi_corr_auto': np.nan,
                            'ipi_ceps_man': np.nan, 'ipi_ceps_auto': np.nan,
                            'ind_number': np.nan}
            c = [[0, 0, 0, 1]]
            c[0][curr] = 1  # red / green / blue for view 0 / 1 / 2
            if len(self.offset) == 1:
                self.scat.set_color(c)
            else:
                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
            self.curr_ind[curr] = len(self.offset) - 1
            self._set_placement_state(curr, 0)
            for k in range(4):
                if k < 2:
                    self.view_data[curr][k][1][0].set_visible(False)
                    self.view_data[curr][k][1][1].set_visible(False)
                else:
                    self.view_data[curr][k][1].set_visible(False)
        else:
            # Existing click: mix the colors of every view that shows it.
            self.curr_ind[curr] = np.argmax(key == self.offset[:, 0])
            c = [0, 0, 0, 1]
            k = 0
            for idx, v in enumerate(self.curr_ind):
                if v == self.curr_ind[curr]:
                    c[idx] = 1
                    k += 1
            for idx in range(3):
                c[idx] /= k
            self.scat._facecolors[self.curr_ind[curr]] = c
            self.scat.set_color(self.scat._facecolors)
            row = self.df[key]
            self._restore_markers(curr, row)
            # Resume where this click's annotation stopped. Place P1 first if it is missing,
            # otherwise P2. Previously the view's stale state was kept, so a P2 click could
            # compute ipi_sig from a hidden P1 belonging to another click.
            self._set_placement_state(curr, 0 if np.isnan(row['p1_pos']) else 1)
        self._show_click(curr, mpos)
        plt.draw()

    def _set_placement_state(self, view, vert):
        """Set which signal marker (0 = P1, 1 = P2) the next click in *view* places, and the cursor style."""
        self.curr_vert[view] = vert
        style = ['--', '-'][vert]
        self.cursor[view].linev[0].set_linestyle(style)
        self.cursor[view].linev[1].set_linestyle(style)

    def _restore_markers(self, view, row):
        """Show or hide the marker lines of *view* according to the saved annotation *row*."""
        sig_p1, sig_p2 = self.view_data[view][0][1]
        spec_p1, spec_p2 = self.view_data[view][1][1]
        if np.isnan(row['p1_pos']):
            for line in (sig_p1, sig_p2, spec_p1, spec_p2):
                line.set_visible(False)
        else:
            sig_p1.set_xdata((row['p1_pos'], row['p1_pos']))
            spec_p1.set_xdata((row['p1_pos'] * 1e-3, row['p1_pos'] * 1e-3))
            sig_p1.set_visible(True)
            spec_p1.set_visible(True)
        if np.isnan(row['ipi_sig']):
            sig_p2.set_visible(False)
            spec_p2.set_visible(False)
        else:
            p2 = row['p1_pos'] + row['ipi_sig']
            sig_p2.set_xdata((p2, p2))
            spec_p2.set_xdata((p2 * 1e-3, p2 * 1e-3))
            sig_p2.set_visible(True)
            spec_p2.set_visible(True)
        for k, col in ((2, 'ipi_corr_man'), (3, 'ipi_ceps_man')):
            line = self.view_data[view][k][1]
            if np.isnan(row[col]):
                line.set_visible(False)
            else:
                line.set_xdata((row[col], row[col]))
                line.set_visible(True)

    @staticmethod
    def _click_window(signal, center, half):
        """Return ``signal[center - half:center + half]``, zero-padded to exactly ``2 * half`` samples.

        Padding goes on the side(s) where the window runs past the signal, so the click
        stays centred. Previously the ``np.pad`` result was discarded, which left short
        windows near the file edges.
        """
        start, stop = center - half, center + half
        window = signal[max(start, 0):max(stop, 0)]
        pad_before = min(max(0, -start), 2 * half)
        pad_after = 2 * half - pad_before - len(window)
        return np.pad(window, (pad_before, pad_after), mode='constant')

    def _show_click(self, view, mpos):
        """Draw waveform, spectrogram, autocorrelation and cepstrum of the click at FSSR index *mpos*."""
        half = int(10e-3 * self.sr)  # samples in 10 ms
        click = self._click_window(self.song, int(mpos * self.sr / FSSR), half)
        data = self.view_data[view]
        # Set x as well, so its length always matches the click.
        data[0][0].set_data(np.linspace(0, 20, len(click), False), norm(click))
        spec = np.flipud(20 * np.log10(plt.mlab.specgram(click, Fs=self.sr, NFFT=128, noverlap=127)[0]))
        data[1][0].set_data(spec)
        data[1][0].set_clim(spec.max() - SPSC, spec.max())
        # Keep the non-negative-lag half of the 'same' autocorrelation.
        data[2][0].set_ydata(norm(np.correlate(click, click, 'same')[-half:]))
        log_spectrum = np.log10(np.abs(np.fft.rfft(click)))
        cepstrum = np.abs(np.fft.irfft(log_spectrum)[:half])
        data[3][0].set_ydata(norm_std(cepstrum))

    def _set_vertical(self, i, j, xdata):
        """Place the next P1/P2 marker of view *i* from a click on its waveform (j=0) or spectrogram (j=1)."""
        pos = xdata * 10**(3 * j)  # the spectrogram x axis is in s; convert to ms
        vert = self.curr_vert[i]
        sig_line = self.view_data[i][0][1][vert]
        spec_line = self.view_data[i][1][1][vert]
        sig_line.set_xdata((pos, pos))
        spec_line.set_xdata((pos * 10**-3, pos * 10**-3))
        sig_line.set_visible(True)
        spec_line.set_visible(True)
        key = self.offset[self.curr_ind[i], 0]
        if vert == 0:
            self.df[key]['p1_pos'] = pos
        self._set_placement_state(i, vert ^ 1)  # alternate between P1 and P2
        if self.view_data[i][0][1][1].get_visible():
            ipi_man = self.view_data[i][0][1][1].get_xdata()[0] - self.view_data[i][0][1][0].get_xdata()[0]
            self.df[key]['ipi_sig'] = ipi_man
            self.view_ax[i][0].set_xlabel(f'IPI man:{ipi_man:.5f}')

    def _set_lag(self, i, j, xdata):
        """Record a manual IPI on the autocorrelation (j=2) or cepstrum (j=3) of view *i*, plus the nearby peak."""
        lag_line = self.view_data[i][j][1]
        lag_line.set_xdata((xdata, xdata))
        lag_line.set_visible(True)
        # Automatic estimate: the maximum within +/-IPIPK ms of the clicked lag.
        lo = max(int(self.sr / 1e3 * (xdata - IPIPK)), 0)
        hi = int(self.sr / 1e3 * (xdata + IPIPK))
        ipi_auto = np.argmax(self.view_data[i][j][0].get_data()[1][lo:hi]) + lo
        col = 'ipi_' + ('corr' if j == 2 else 'ceps')
        key = self.offset[self.curr_ind[i], 0]
        self.df[key][col + '_man'] = xdata
        self.df[key][col + '_auto'] = ipi_auto * 1e3 / self.sr
        self.view_ax[i][j].set_xlabel(f'IPI man:{xdata:.3f} auto:{ipi_auto*1e3/self.sr:.3f}')

    def change_curr(self, label):
        """Radio-button callback: the last character of *label* is the index of the new current view."""
        self.curr = int(label[-1])
        self.f_cursor.linev.set_color('rgb'[self.curr])

    def play(self, event):
        """Play the 20 s segment shown in the full-signal plot."""
        segment = self.song_resample[self.p:self.p + FSSR * 20]
        sound = (norm(segment) * (2**15 - 1)).astype(np.int16)
        try:
            # The segment is sampled at FSSR, not at the native rate self.sr
            # (using self.sr played it at the wrong speed and pitch).
            play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
        except KeyboardInterrupt:
            pass

    def resize(self, event):
        """Re-run the constrained layout once, then freeze it again."""
        self.fax.get_figure().set_constrained_layout(True)
        plt.draw()
        plt.pause(0.2)
        self.fax.get_figure().set_constrained_layout(False)
def load_file(in_path, channel, low, high):
    """Load one channel of *in_path*, band-pass it, and resample a copy to FSSR.

    Parameters
    ----------
    in_path : str
        Audio file, read with ``read``.
    channel : int
        Column of the 2-D sample array to keep.
    low, high : float
        Band-pass cut-off frequencies in Hz. *high* is clamped below the Nyquist
        frequency, because ``scipy.signal.butter`` rejects cut-offs at or above it.

    Returns
    -------
    (song, sr, song_resample)
        The filtered signal at the native rate, that rate, and the filtered signal
        resampled to FSSR Hz.
    """
    print(f'Loading and processing {in_path}')
    song, sr = read(in_path, always_2d=True)
    song = song[:, channel]
    nyquist = sr / 2
    if high >= nyquist:
        # E.g. the default 20 kHz on a 32 kHz recording used to raise inside butter().
        print(f'High cut-off {high} Hz is not below Nyquist ({nyquist} Hz); using {0.95 * nyquist} Hz')
        high = 0.95 * nyquist
    sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
    song = sg.sosfiltfilt(sos, song)  # zero-phase, so click timing is not shifted
    frac = Fraction(FSSR, sr)
    song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
    print('Done processing')
    return song, sr, song_resample
def init(in_path, channel, low=2e3, high=20e3):
    """Build the interactive IPI figure for one channel of *in_path* and wire up its callbacks.

    Layout, on a 12 x 20 grid:
      * rows 0-1: the full-signal plot (first 20 s), with shift buttons on either side;
      * rows 2-9: three click views, each a 2 x 2 block of waveform, spectrogram,
        autocorrelation and cepstrum axes;
      * row 10: radio buttons selecting the current view;
      * last row: play and resize buttons.

    Returns a dict holding the Callback, the figure and the widgets. The widgets must
    stay referenced, or their callbacks stop working.
    NOTE(review): files shorter than 20 s give mismatched x/y data in the full-signal
    plot — confirm inputs are always longer.
    """
    song, sr, song_resample = load_file(in_path, channel, low, high)
    fig = plt.figure('IPI', figsize=[16, 9], constrained_layout=True)
    gs = fig.add_gridspec(12, 20)
    # Full-signal overview: the first 20 s of the resampled signal.
    full_sig = plt.subplot(gs[:2, 1:-1])
    callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                        song, song_resample, sr, full_sig)
    callback.fax = full_sig
    full_sig.set_xlim(0, 20)
    lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
    full_sig.set_ylim(-lim, lim)
    full_sig.set_yticks([])
    callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
    # Every mouse release on the figure goes to Callback.on_clicked, which dispatches on the axes.
    cid = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)
    b_left_ax = plt.subplot(gs[:2, 0])
    b_right_ax = plt.subplot(gs[:2, -1])
    b_left = Button(b_left_ax, '<|')
    b_right = Button(b_right_ax, '|>')
    b_left.on_clicked(callback.shift_left)
    b_right.on_clicked(callback.shift_right)
    # View selector: one button per view, colored red/green/blue like the view's markers.
    r_button_ax = plt.subplot(gs[10, 1:-1])
    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(3)], orientation='horizontal',
                              size=128, activecolor=list('rgb'))  # The last char of each label is used as the view index elsewhere
    r_button_ax.axis('off')
    r_button.on_clicked(callback.change_curr)
    for i, c in enumerate('rgb'):
        r_button.labels[i].set_c(c)
    # Geometry of the views: vs/hs are the start row/column, vfs/hfs the height/width of one axes.
    vfs = 4
    vs = 2
    hfs = 3
    hs = 1
    # ax_view[view][k]: k = 0 waveform, 1 spectrogram, 2 autocorrelation, 3 cepstrum.
    ax_view = [[plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
                plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs])],
               [plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
                plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs])],
               [plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
                plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
                plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs])]]
    callback.view_ax = ax_view
    play_b_ax = plt.subplot(gs[-1, :2])
    play_b = Button(play_b_ax, 'Play\ncurrent segment')
    play_b.on_clicked(callback.play)
    resize_b_ax = plt.subplot(gs[-1, 2:4])
    resize_b = Button(resize_b_ax, 'Resize plot')
    resize_b.on_clicked(callback.resize)
    # NOTE(review): an "individual number" TextBox was sketched here but never enabled;
    # the df column 'ind_number' is therefore never filled.
    # data_view[view][k] = [data artist, marker line(s)]. For k < 2 the markers are a
    # (P1 dashed, P2 solid) pair; for k >= 2 they are a single lag line.
    data_view = [[2 * [None] for _ in range(4)] for _ in range(3)]
    m_cursor = [None for _ in range(3)]
    callback.cursor = m_cursor
    for i in range(3):
        # Waveform: 20 ms excerpt, x in ms.
        data_view[i][0][1] = (ax_view[i][0].axvline(10, c='k', linestyle='--'), ax_view[i][0].axvline(10, c='k'))
        data_view[i][0][0] = ax_view[i][0].plot(np.linspace(0, 20, int(20e-3 * sr), False), np.zeros(int(20e-3 * sr)))[0]
        ax_view[i][0].set_xlim(0, 20)
        ax_view[i][0].set_xlabel('IPI man:None')
        ax_view[i][0].set_ylim(-1, 1)
        data_view[i][0][1][0].set_visible(False)
        data_view[i][0][1][1].set_visible(False)
        # Spectrogram: initialised with noise only to create the image artist; x in s.
        data_view[i][1][0] = ax_view[i][1].specgram(np.random.normal(0, 1, int(20e-3 * sr)),
                                                    Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
        data_view[i][1][1] = (ax_view[i][1].axvline(0.01, c='k', linestyle='--'), ax_view[i][1].axvline(0.01, c='k'))
        data_view[i][1][1][0].set_visible(False)
        data_view[i][1][1][1].set_visible(False)
        # Autocorrelation: 10 ms of lags, x in ms.
        data_view[i][2][1] = ax_view[i][2].axvline(10, c='k')
        data_view[i][2][0] = ax_view[i][2].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[0]
        ax_view[i][2].set_xlim(0, 10)
        ax_view[i][2].set_xlabel('IPI man:None auto:None')
        ax_view[i][2].set_ylim(-1, 1)
        data_view[i][2][1].set_visible(False)
        # Cepstrum: 10 ms of quefrencies, x in ms.
        data_view[i][3][1] = ax_view[i][3].axvline(10, c='k')
        data_view[i][3][0] = ax_view[i][3].plot(np.linspace(0, 10, int(10e-3 * sr), False), np.zeros(int(10e-3 * sr)))[0]
        ax_view[i][3].set_xlim(0, 10)
        ax_view[i][3].set_xlabel('IPI man:None auto:None')
        ax_view[i][3].set_ylim(0, 1)
        data_view[i][3][1].set_visible(False)
        for j in range(4):
            if j != 1:
                ax_view[i][j].set_yticks([])
            else:
                # Spectrogram frequency ticks shown in kHz.
                # NOTE(review): labels are fixed at the current tick positions, so they will not follow zooming.
                ax_view[i][j].set_yticklabels((ax_view[i][j].get_yticks() / 1e3).astype(int))
        # Synchronised cursors for the 4 axes of view i, referenced to its P1 marker line.
        m_cursor[i] = MyMultiCursor(ax_view[i], data_view[i][0][1][0], horizOn=False, useblit=True, c='k')
        m_cursor[i].linev[0].set_linestyle('--')
        m_cursor[i].linev[1].set_linestyle('--')
    callback.view_data = data_view
    # Let constrained layout run once, then freeze it so later redraws do not move the axes.
    plt.draw()
    plt.pause(0.2)
    fig.set_constrained_layout(False)
    return {'callback': callback, 'fig': fig, 'buttons':
            {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
             'fs_click': cid}}  # Needed to keep the callbacks alive
def reset(in_path, channel, low=2e3, high=20e3):
    """Reload and preprocess *in_path*, returning ``(song, sr, song_resample)``.

    Previously the loaded data were discarded, so the call had no effect besides the
    work done. It does not update an existing figure.
    NOTE(review): nothing in the visible code calls ``reset``; it looks like an
    unfinished hook for reloading the GUI with a new file.
    """
    return load_file(in_path, channel, low, high)
def main(args):
    """Run the annotation GUI for ``args.input`` and save the annotations as HDF5.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``input``, ``out``, ``channel`` and ``erase``.

    Returns
    -------
    int
        0 on success, 1 if the output file already exists and ``erase`` is not set.
    """
    if args.out == '':
        # Default to the input path with its extension replaced, as the --out help says.
        # This used args.out (always '' here), which gave the path '.pred.h5'.
        outpath = os.path.splitext(args.input)[0] + '.pred.h5'
    else:
        outpath = args.out
    if os.path.isfile(outpath) and not args.erase:
        print(f'{outpath} already exists; use --erase to overwrite it.')
        return 1
    ref_dict = init(args.input, args.channel)
    plt.show()  # blocks until the figure is closed
    df = pd.DataFrame.from_dict(ref_dict['callback'].df, orient='index')
    df.to_hdf(outpath, key='df')
    return 0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("input", type=str, help="Input file")
    parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
    parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
    parser.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                           " the computation will be halted")
    args = parser.parse_args()
    # Propagate main's status (1 when the output already exists) as the process exit code.
    raise SystemExit(main(args))