# Select Git revision
# gettingStarted.md
# -
# Franck Dary authored
# ipi_extract.py 37.21 KiB
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import scipy.signal as sg
import soundfile as sf
import os
import sys
from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget
from fractions import Fraction
from pydub import AudioSegment
from pydub.playback import play
import pandas as pd
from typing import List, Union
FSSR = 48000  # Sampling rate (Hz) of the resampled signal shown in the full-signal plot
FSPK = 0.1  # Half-width (s) of the window searched for the click peak around a click on the full-signal plot
IPIPK = 0.15  # Half-width (ms) of the window searched for the automatic IPI around a manual pick
SPSC = 80  # Range shown below the spectrogram maximum (in the units of the 20*log10 spectrogram image)
# Template of an empty annotation row; main() additionally adds an 'onaxis' key.
EMLN = dict(
    p1_pos=np.nan, ipi_sig=np.nan,
    ipi_corr_man=np.nan, ipi_corr_auto=np.nan,
    ipi_ceps_man=np.nan, ipi_ceps_auto=np.nan,
    ind_number=np.nan,
)
def read(file_path, always_2d=True):
    """Read an audio file and return ``(samples, sample_rate)``.

    soundfile is tried first. If it fails for any reason, the file is decoded
    with pydub through `load_anysound` instead.

    Parameters
    ----------
    file_path : str
        Path of the audio file.
    always_2d : bool
        If True, samples are always shaped (frames, channels). If False, a
        mono file is returned as a 1-D array. This is now honoured by the
        pydub fallback as well.

    Returns
    -------
    tuple
        ``(samples, sample_rate)``. NOTE(review): the pydub fallback returns
        the raw sample values from ``get_array_of_samples``, which are not
        rescaled like soundfile's float output. Confirm this is acceptable
        wherever absolute amplitude matters.
    """
    try:
        return sf.read(file_path, always_2d=always_2d)
    except Exception:
        # Deliberately broad: any soundfile failure (e.g. an unsupported
        # container) triggers the pydub fallback.
        samples, sample_rate = load_anysound(file_path)
        if not always_2d and samples.shape[1] == 1:
            samples = samples[:, 0]
        return samples, sample_rate
def load_anysound(file_path):
    """Decode *file_path* with pydub and return ``(samples, frame_rate)``.

    The samples are reshaped to (n_frames, n_channels). This assumes pydub
    returns them interleaved frame by frame, which is what the reshape
    relies on.
    """
    segment = AudioSegment.from_file(file_path)
    flat = np.array(segment.get_array_of_samples())
    frames = flat.reshape(-1, segment.channels)
    return frames, segment.frame_rate
def load_file(in_path, channel, low, high):
    """Load one channel of an audio file, band-pass filter it and resample it for display.

    Parameters
    ----------
    in_path : str
        Audio file to read (see `read`).
    channel : int
        Index of the channel to keep, starting from 0.
    low, high : float
        Cut-off frequencies (Hz) of the 3rd-order Butterworth band-pass. It is
        applied forward and backward with ``sosfiltfilt``. If *high* is not
        below the Nyquist frequency it is lowered to 0.99 * Nyquist, because
        the filter design rejects such cut-offs.

    Returns
    -------
    song : numpy.ndarray
        Filtered signal at the file's sampling rate, padded to at least 20 s.
    sr
        Sampling rate of *song* (Hz), as returned by `read`.
    song_resample : numpy.ndarray
        *song* resampled to FSSR, used by the full-signal plot.
    """
    print(f'Loading and processing {in_path}')
    song, sr = read(in_path, always_2d=True)
    song = song[:, channel]
    nyquist = sr / 2
    if high >= nyquist:
        # butter() raises for a critical frequency at or above Nyquist; keep a small margin.
        clipped = 0.99 * nyquist
        print(f'Upper cut-off {high} Hz is not below the Nyquist frequency ({nyquist} Hz); '
              f'using {clipped} Hz instead.')
        high = clipped
    sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos')
    song = sg.sosfiltfilt(sos, song)
    min_len = int(20 * sr)  # the full-signal view always spans 20 s
    if len(song) < min_len:
        # Pad with silence only up to 20 s, so the view cannot scroll into an
        # all-silent region. The single leading sample is kept so click
        # timestamps match annotation files written by earlier runs.
        song = np.pad(song, (1, min_len - len(song)), mode='constant')
    frac = Fraction(FSSR, sr)
    song_resample = sg.resample_poly(song, frac.numerator, frac.denominator)
    print('Done processing')
    return song, sr, song_resample
def norm(x):
    """Scale *x* so its largest absolute value becomes (almost exactly) 1.

    A 1e-10 epsilon in the divisor keeps an all-zero input from dividing by zero.
    """
    peak = np.max(np.abs(x))
    return x / (peak + 1e-10)
def norm_std(x, alpha=1.5):
    """Scale *x* by ``alpha`` standard deviations.

    Returns ``x / (alpha * std(x) + 1e-10)``. The epsilon guards against a
    constant input. Previously *alpha* was ignored and 1.5 was always used;
    the default keeps that behaviour.
    """
    return x / (alpha * np.std(x) + 1e-10)
class MyRadioButtons(RadioButtons):
    """RadioButtons with configurable marker size, horizontal layout and per-button active colours.

    NOTE(review): __init__ does not call RadioButtons.__init__. It sets by hand
    the attributes the inherited methods use (``labels``, ``circles``, ``cnt``,
    ``observers``, ``value_selected``). Verify this against the installed
    matplotlib version.
    """

    def __init__(self, ax, labels, active=0, activecolor='blue', size=49,
                 orientation="vertical", **kwargs):
        """
        Add radio buttons to an `~.axes.Axes`.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to add the buttons to.
        labels : list of str
            The button labels.
        active : int
            The index of the initially selected button.
        activecolor : color or list of colors
            The color of the selected button. If a list is given, the colour
            used is ``activecolor[int(label[-1])]``, i.e. indexed by the last
            character of the selected label.
        size : float
            Size of the radio buttons
        orientation : str
            The orientation of the buttons: 'vertical' (default), or 'horizontal'.

        Further parameters are passed on to `Legend`.
        """
        AxesWidget.__init__(self, ax)
        self._activecolor = activecolor
        axcolor = ax.get_facecolor()
        self.value_selected = None
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_navigate(False)
        circles = []
        for i, label in enumerate(labels):
            if i == active:
                # value_selected must be set before reading self.activecolor,
                # because a per-button palette is indexed from it.
                self.value_selected = label
                facecolor = self.activecolor
            else:
                facecolor = axcolor
            p = ax.scatter([], [], s=size, marker="o", edgecolor='black',
                           facecolor=facecolor)
            circles.append(p)
        if orientation == "horizontal":
            kwargs.update(ncol=len(labels), mode="expand")
        kwargs.setdefault("frameon", False)
        # The buttons are the legend handles, so the legend lays them out.
        self.box = ax.legend(circles, labels, loc="center", **kwargs)
        self.labels = self.box.texts
        self.circles = self.box.legendHandles
        for c in self.circles:
            c.set_picker(5)
        self.cnt = 0
        self.observers = {}
        self.connect_event('pick_event', self._clicked)

    def _clicked(self, event):
        """Select the button whose circle was picked with the left mouse button."""
        if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax:
            return
        if event.artist in self.circles:
            self.set_active(self.circles.index(event.artist))

    @property
    def activecolor(self):
        """Colour used for the selected button.

        A single colour spec (name string or RGB(A) tuple) is returned as-is.
        Otherwise the value is treated as a palette indexed by the last
        character (a digit) of the selected label. Previously any indexable
        value was indexed, so e.g. ``'blue'`` turned into one of its letters.
        """
        from matplotlib.colors import is_color_like
        color = self._activecolor
        if is_color_like(color) or not hasattr(color, '__getitem__'):
            return color
        return color[int(self.value_selected[-1])]
class MyMultiCursor(AxesWidget):
    """Vertical guide lines that follow the mouse across the four panels of one AxesGroup.

    For every solid (non-dashed) cursor line of each panel in *axes*, one hidden
    guide line is created in that panel. ``self.linev`` therefore holds, in panel
    order, the guides of the signal panel, the spectrogram panel, and then the
    correlation and cepstrum panels.

    ``linev[0]`` (the first signal guide) and ``linev[num_cur]`` (the first
    spectrogram guide) are drawn dashed. A dashed guide means the next click
    places P1; a solid one means it places P2. Signal positions are in ms and
    spectrogram positions in s; values are divided by 1e3 when copied from one
    to the other.
    """

    def __init__(self, axes, p1, num_cur, vertOn=True, useblit=False,
                 **lineprops):
        """
        Parameters
        ----------
        axes : AxesGroup
            Panels to draw on. ``axes[0]`` (the signal panel) is passed to
            AxesWidget as the widget's own axes.
        p1 : matplotlib.lines.Line2D
            The P1 cursor of the signal panel. Its x position (ms) is the
            reference point for the IPI guides.
        num_cur : int
            Number of guides in the signal panel, and also in the spectrogram
            panel. NOTE(review): assumed equal to the number of solid cursors
            per panel in *axes*; ``linev[num_cur]`` relies on it.
        vertOn : bool
            Whether guides are shown at all.
        useblit : bool
            Use blitting if the canvas supports it.
        **lineprops
            Passed to ``axvline`` for every guide.
        """
        AxesWidget.__init__(self, axes[0])
        self.axes = axes
        self.p1 = p1
        self.num_cur = num_cur
        self.connect_event('motion_notify_event', self.onmove)
        self.connect_event('draw_event', self.clear)
        self.visible = True
        self.vertOn = vertOn
        self.useblit = useblit and self.canvas.supports_blit
        if self.useblit:
            lineprops['animated'] = True
        self.linev = []
        for ax in self.axes:
            for line in ax.cursors:
                # Dashed panel cursors (the stored P1 markers) get no guide.
                if line.get_linestyle() == '--':
                    continue
                self.linev.append(ax.axes.axvline(ax.axes.get_xbound()[0], visible=False, **lineprops))
        # Start in "place P1" mode: the first signal and spectrogram guides are dashed.
        self.linev[0].set_linestyle('--')
        self.linev[self.num_cur].set_linestyle('--')
        self.n_axes = len(self.linev)  # number of guide lines (not of panels)
        self.background = len(self.axes)*[None]  # one blit background per panel
        self.needclear = False

    def clear(self, event):
        """Internal draw_event handler: store blit backgrounds and hide guides.

        NOTE(review): only ``linev[0..len(self.axes)-1]`` is hidden here. When
        num_cur > 1 the remaining guides are left visible; confirm whether all
        of ``self.linev`` should be hidden.
        """
        if self.ignore(event):
            return
        for i, ax in enumerate(self.axes):
            if self.useblit:
                self.background[i] = self.canvas.copy_from_bbox(ax.axes.bbox)
            self.linev[i].set_visible(False)

    def onmove(self, event):
        """Internal motion_notify_event handler: move the guides to follow the mouse."""
        if self.ignore(event):
            return
        if not self.canvas.widgetlock.available(self):
            return
        if event.inaxes not in self.axes:
            # Mouse is outside every panel: hide all guides and redraw once to erase them.
            for i in range(self.n_axes):
                self.linev[i].set_visible(False)
            if self.needclear:
                self.canvas.draw()
                self.needclear = False
            return
        self.needclear = True
        if not self.visible:
            return
        # Index of the panel under the mouse: 0 signal, 1 spectrogram, 2 correlation, 3 cepstrum.
        for ax_idx, ax in enumerate(self.axes):
            if ax.axes == event.inaxes:
                break
        if ax_idx < 2:
            # Signal (x in ms) or spectrogram (x in s): convert to ms.
            if ax_idx == 1:
                pos = event.xdata*1e3
            else:
                pos = event.xdata
            if self.linev[0].get_linestyle() == '--':
                # Placing P1: only the dashed signal/spectrogram guides follow the mouse.
                self.linev[0].set_xdata((pos, pos))
                self.linev[self.num_cur].set_xdata((pos / 1e3, pos / 1e3))
                self.linev[0].set_visible(self.visible and self.vertOn)
                self.linev[self.num_cur].set_visible(self.visible and self.vertOn)
                self._update()
                return
            else:
                # Placing P2: the IPI is the distance from P1 to the mouse. Signal and
                # spectrogram guides go at pos + i*IPI; correlation/cepstrum guides at the IPI.
                ipi_pos = abs(pos - self.p1.get_xdata()[0])
                for i in range(self.n_axes):
                    if i < self.num_cur:
                        self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                    elif i < 2*self.num_cur:
                        self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                    else:
                        self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                self._update()
                return
        else:
            # Correlation/cepstrum panel: x is an IPI in ms, so the pulse guides go at P1 + k*IPI.
            ipi_pos = event.xdata
            pos = self.p1.get_xdata()[0] + ipi_pos
            # In P1 mode the first guides are dashed. Draw them solid while they show
            # pulse positions, then restore the dashed style after drawing.
            restore = False
            if self.linev[0].get_linestyle() == '--':
                restore = True
                self.linev[0].set_linestyle('-')
                self.linev[self.num_cur].set_linestyle('-')
            for i in range(self.n_axes):
                if self.p1.get_visible():
                    # Signal/spectrogram guides are only meaningful once P1 is shown.
                    if i < self.num_cur:
                        self.linev[i].set_xdata(2*(pos + i*ipi_pos,))
                    elif i < 2 * self.num_cur:
                        self.linev[i].set_xdata(2*((pos + (i-self.num_cur)*ipi_pos) / 1e3,))
                    self.linev[i].set_visible(self.visible and self.vertOn)
                if i >= 2 * self.num_cur:
                    self.linev[i].set_xdata((ipi_pos, ipi_pos))
                    self.linev[i].set_visible(self.visible and self.vertOn)
            self._update()
            if restore:
                self.linev[0].set_linestyle('--')
                self.linev[self.num_cur].set_linestyle('--')

    def _update(self):
        """Redraw the guides, blitting each panel when possible."""
        if self.useblit:
            # k walks through self.linev: num_cur guides for each of the first two
            # panels, then one guide for each of the remaining panels.
            k = 0
            for bk, ax in zip(self.background, self.axes):
                if bk is not None:
                    self.canvas.restore_region(bk)
                # The loop variable reuses k; afterwards k is this panel's last guide index.
                for k in range(k, k + self.num_cur if k < 2*self.num_cur else k+1):
                    ax.axes.draw_artist(self.linev[k])
                k += 1
                self.canvas.blit(ax.axes.bbox)
        else:
            self.canvas.draw_idle()
        return False
class Callback(object):
    """Event handlers and annotation state for the IPI annotation figure.

    Annotations live in ``self.df``: a dict mapping a click time (seconds,
    FSSR sample index / FSSR) to a row dict built from ``EMLN``.
    ``self.offset`` has one ``[time, amplitude]`` row per click; it drives the
    scatter markers on the full-signal plot. ``self.curr_ind[v]`` is the row
    of ``self.offset`` shown in view ``v``, or None if that view is empty.
    """

    _PALETTE = 'rgbkcmyrgbkcmy'  # one colour per view, shared with the radio buttons
    _EMPTY_COLOR = [0.7, 0.2, 0.5, 1]  # marker colour of a click without annotation
    _DONE_COLOR = [0, 0, 0, 1]  # marker colour of an annotated click

    def __init__(self, line, song, song_resample, sr, full_ax, num_view, after_length):
        self.p = 0  # start index (FSSR samples) of the 20 s full-signal window
        self.df = dict()
        self.line = line
        self.song = song
        self.song_resample = song_resample
        self.sr = sr
        self.fax = full_ax
        self.num_view = num_view
        self.after_length = after_length  # seconds shown after P1
        self.view_data: Union[List['AxesGroup'], None] = None
        self.curr = 0  # current view selected
        self.scat = self.fax.scatter([], [], marker='x', c=[self._EMPTY_COLOR])
        self.offset = np.zeros((0, 2))
        self.curr_ind: List[Union[int, None]] = num_view*[None]  # Indices of click for each plot
        self.curr_vert = num_view*[0]  # Current vertical line of sig/spec for each plot (0: P1, 1: P2)
        self.cursor = None
        self.f_cursor = None
        self.reset_b = None
        self.r_button = None
        self.ind_b = None
        self.onaxis_b = None
        self.spec_b = None
        self.nfft = 128
        self.ind_select = False

    # ----- helpers -----

    @staticmethod
    def _is_empty(row):
        """Return True if *row* holds no annotation yet.

        'onaxis' uses -1 (not NaN) as its "unknown" value, so -1 counts as empty.
        Without this, every row looks annotated as soon as main() adds the column.
        """
        return all(pd.isnull(v) or (k == 'onaxis' and v == -1) for k, v in row.items())

    def _selected_key(self, view=None):
        """Return the df key of the click shown in *view* (default: current view), or None."""
        if view is None:
            view = self.curr
        idx = self.curr_ind[view]
        if idx is None:
            return None
        return self.offset[idx, 0]

    def _extract_click(self, mpos):
        """Return the samples around click *mpos* (an FSSR index) at the original rate.

        The window starts 10 ms before the click and ends after_length after it.
        It always holds int((10e-3 + after_length) * sr) samples, the length of
        the signal plot's x-data, and is zero-padded where it runs past either
        end of the recording.
        """
        n_total = int((10e-3 + self.after_length) * self.sr)
        start = int(mpos * self.sr / FSSR - 10e-3 * self.sr)
        click = self.song[max(start, 0):max(start + n_total, 0)]
        pad_before = min(max(-start, 0), n_total)
        pad_after = max(n_total - pad_before - len(click), 0)
        return np.pad(click, (pad_before, pad_after), mode='constant')

    @staticmethod
    def _autocorrelation(click, n_lag):
        """Autocorrelation of *click* at lags 0 .. n_lag-1, zero-padded if the click is too short.

        The previous 'same'-mode slice only started at lag 0 when after_length was 10 ms.
        """
        full = np.correlate(click, click, 'full')
        lags = full[len(click) - 1:len(click) - 1 + n_lag]
        return np.pad(lags, (0, n_lag - len(lags)), mode='constant')

    def _show_row(self, ax_group, row):
        """Place or hide the cursors of *ax_group* according to annotation *row*."""
        if np.isnan(row['p1_pos']):
            ax_group.signal.set_visible(False)
            ax_group.spectrogram.set_visible(False)
        else:
            p1 = row['p1_pos']
            ax_group.signal.cursors[0].set_xdata((p1, p1))
            ax_group.spectrogram.cursors[0].set_xdata((p1 * 1e-3, p1 * 1e-3))
            ax_group.signal.cursors[0].set_visible(True)
            ax_group.spectrogram.cursors[0].set_visible(True)
        if np.isnan(row['ipi_sig']):
            ax_group.signal.cursors[1].set_visible(False)
            ax_group.spectrogram.cursors[1].set_visible(False)
        else:
            p2 = row['p1_pos'] + row['ipi_sig']
            ax_group.signal.cursors[1].set_xdata((p2, p2))
            ax_group.spectrogram.cursors[1].set_xdata((p2 * 1e-3, p2 * 1e-3))
            ax_group.signal.cursors[1].set_visible(True)
            ax_group.spectrogram.cursors[1].set_visible(True)
        for panel, col in ((ax_group.correlation, 'ipi_corr_man'), (ax_group.cepstrum, 'ipi_ceps_man')):
            if np.isnan(row[col]):
                panel.set_visible(False)
            else:
                panel.cursors[0].set_xdata((row[col], row[col]))
                panel.set_visible(True)

    # ----- full-signal navigation -----

    def shift_left(self, event):
        """Move the 20 s full-signal window 13 s to the left."""
        self.p = max(0, self.p - FSSR*13)
        self._shift()

    def shift_right(self, event):
        """Move the 20 s full-signal window 13 s to the right, without passing the end."""
        self.p = min(len(self.song_resample) - FSSR*20, self.p + FSSR*13)
        self._shift()

    def _shift(self):
        """Redraw the full-signal plot for the window starting at ``self.p``."""
        window = self.song_resample[self.p:self.p+FSSR*20]
        self.line.set_ydata(window)
        lim = np.abs(window).max()*1.2
        if lim == 0:  # silent window: avoid identical lower and upper y-limits
            lim = 1.0
        self.fax.set_ylim(-lim, lim)
        self.line.set_xdata(np.linspace(self.p/FSSR, self.p/FSSR+20, FSSR*20, False))
        self.fax.set_xlim(self.p/FSSR, self.p/FSSR+20)
        plt.draw()

    _shit = _shift  # former name, kept so existing callers keep working

    # ----- mouse handling -----

    def _on_clicked_fax(self, event):
        """Select the click nearest to a mouse click on the full-signal plot and show it in the current view.

        The position snaps to the sample of maximum value within ±FSPK s. A new
        position gets an empty annotation row; a known one has its stored
        annotation restored.
        """
        pos = int(FSSR * event.xdata)
        start = max(pos - int(FSSR * FSPK), 0)
        mpos = np.argmax(self.song_resample[start:pos + int(FSSR * FSPK)]) + start
        key = mpos / FSSR
        ax_group = self.view_data[self.curr]
        prev = self.curr_ind[self.curr]
        if prev is not None:
            # Return the marker previously highlighted by this view to its plain colour.
            prev_row = self.df[self.offset[prev, 0]]
            self.scat._facecolors[prev] = self._EMPTY_COLOR if self._is_empty(prev_row) else self._DONE_COLOR
        if key not in self.offset[:, 0]:
            # New click: add a marker and an empty annotation row.
            self.offset = np.concatenate([self.offset, [[key, self.song_resample[mpos]]]], axis=0)
            self.scat.set_offsets(self.offset)
            self.df[key] = EMLN.copy()
            # Use the current view's palette colour (the old [0,0,0,1]-based trick
            # raised IndexError for view indices >= 4).
            c = [list(to_rgba(self._PALETTE[self.curr]))]
            if len(self.offset) == 1:
                self.scat.set_color(c)
            else:
                self.scat.set_color(np.concatenate([self.scat._facecolors, c], axis=0))
            self.curr_ind[self.curr] = len(self.offset) - 1
            self.curr_vert[self.curr] = 0
            self.cursor[self.curr].linev[0].set_linestyle('--')
            self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
            ax_group.set_visible(False)
        else:
            self.curr_ind[self.curr] = int(np.argmax(self.offset[:, 0] == key))
            # The marker colour is the mean colour of every view showing this click (averaged once).
            shared = [i for i, v in enumerate(self.curr_ind) if v == self.curr_ind[self.curr]]
            c = np.mean([to_rgba(self._PALETTE[i]) for i in shared], axis=0)
            self.scat._facecolors[self.curr_ind[self.curr]] = c
            self.scat.set_color(self.scat._facecolors)
            self._show_row(ax_group, self.df[key])
        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text(f'Current individual:\n{self.df[key]["ind_number"]}')
        click = self._extract_click(mpos)
        ax_group.signal.im.set_ydata(norm(click))
        self._update_spectrogram(ax_group)
        n_lag = int(10e-3 * self.sr)  # correlation/cepstrum panels span 10 ms
        ax_group.correlation.im.set_ydata(norm(self._autocorrelation(click, n_lag)))
        ax_group.cepstrum.im.set_ydata(
            norm_std(np.abs(np.fft.irfft(np.log10(np.abs(np.fft.rfft(click))))[:n_lag])))
        self._set_label()
        plt.draw()

    def _on_clicked_ax_group(self, event, ax_group, group_num, cell_num):
        """Handle a click inside panel *cell_num* of view *group_num*.

        Clicks on the signal or spectrogram (cells 0-1) alternately place P1 and
        P2 and update 'p1_pos' / 'ipi_sig'. Clicks on the correlation or cepstrum
        (cells 2-3) store the manual IPI. They also store, as the automatic IPI,
        the position of the largest value within ±IPIPK ms of the click.
        """
        key = self._selected_key(group_num)
        if key is None:
            return  # no click loaded in this view yet: nothing to annotate
        row = self.df[key]
        if cell_num < 2:
            # The signal x-axis is in ms, the spectrogram x-axis in s.
            pos = event.xdata if cell_num == 0 else event.xdata * 1000
            cur_type = self.curr_vert[group_num]  # 0: placing P1, 1: placing P2
            ax_group.signal.cursors[cur_type].set_xdata(2*(pos,))
            ax_group.spectrogram.cursors[cur_type].set_xdata(2*(pos*10**-3,))
            ax_group.signal.cursors[cur_type].set_visible(True)
            ax_group.spectrogram.cursors[cur_type].set_visible(True)
            if cur_type == 0:
                row['p1_pos'] = pos
            self.curr_vert[group_num] ^= 1
            cur_type ^= 1
            style = ['--', '-'][cur_type]
            cursor = self.cursor[group_num]
            cursor.linev[0].set_linestyle(style)
            cursor.linev[cursor.num_cur].set_linestyle(style)
            if ax_group.signal.cursors[1].get_visible():
                ipi_man = ax_group.signal_ipi
                row['ipi_sig'] = ipi_man
                ax_group.signal.axes.set_xlabel(f'Sig man:{ipi_man:.5f}')
        else:
            cell = ax_group[cell_num]
            cell.cursors[0].set_xdata((event.xdata, event.xdata))
            cell.cursors[0].set_visible(True)
            lim_min = max(int(self.sr/1e3*(event.xdata-IPIPK)), 0)
            ipi_auto = np.argmax(cell.im.get_data()[1][lim_min:int(self.sr/1e3*(event.xdata+IPIPK))]) + lim_min
            col = 'ipi_' + ('corr' if cell_num == 2 else 'ceps')
            row[col + '_man'] = event.xdata
            row[col + '_auto'] = ipi_auto*1e3/self.sr
            cell.axes.set_xlabel(f'{"Corr" if cell_num == 2 else "Ceps"} man:{event.xdata:.3f} '
                                 f'auto:{ipi_auto*1e3/self.sr:.3f}')
        plt.draw()

    def on_clicked(self, event):
        """Dispatch a mouse release to the full-signal plot or to the clicked view panel."""
        if event.inaxes == self.fax:
            self._on_clicked_fax(event)
        for i, ax_group in enumerate(self.view_data):  # Look if a click plot was clicked and which one
            for j in range(len(ax_group)):
                if event.inaxes == ax_group[j].axes:
                    self._on_clicked_ax_group(event, ax_group, i, j)
                    return

    # ----- buttons and keys -----

    def change_curr(self, label):
        """Make the view named by the last character of *label* current."""
        self.curr = int(label[-1])
        self.f_cursor.linev.set_color(self._PALETTE[self.curr])
        self.reset_b.label.set_c(self._PALETTE[self.curr])
        plt.draw()

    def toggle_ind(self, event):
        """Toggle typing of the individual number for the current view's click."""
        key = self._selected_key()
        if key is None:
            return
        self.ind_select = not self.ind_select
        self.ind_b.color = 'limegreen' if self.ind_select else '0.85'
        self.ind_b.hovercolor = 'lime' if self.ind_select else '0.95'
        if self.ind_select:
            self.df[key]['ind_number'] = ''
        self.ind_b.label.set_text(f'Current individual:\n'
                                  f'{self.df[key]["ind_number"]}')
        plt.draw()

    def key_pressed(self, event):
        """Keyboard handler.

        While individual entry is active, single-character keys are appended to
        the current click's 'ind_number' and backspace deletes the last
        character. Named keys ('shift', 'enter', 'ctrl+a', ...) are ignored.
        Otherwise a digit key selects the view with that index, if it exists.
        """
        key = event.key
        if key is None:
            return
        if self.ind_select:
            sel = self._selected_key()
            if sel is None:
                return
            row = self.df[sel]
            text = row['ind_number'] if isinstance(row['ind_number'], str) else ''
            if key == 'backspace':
                row['ind_number'] = text[:-1]
            elif len(key) == 1:
                row['ind_number'] = text + key
            self.ind_b.label.set_text(f'Current individual:\n{row["ind_number"]}')
            plt.draw()
        elif len(key) == 1 and key in '0123456789' and int(key) < self.num_view:
            self.change_curr(key)
            self.r_button.set_active(int(key))

    def play(self, event):
        """Play the displayed 20 s window as 16-bit mono audio at FSSR (Ctrl-C stops playback)."""
        sound = (norm(self.song_resample[self.p:self.p+FSSR*20])*(2**15-1)).astype(np.int16)
        try:
            play(AudioSegment(sound.tobytes(), frame_rate=FSSR, sample_width=sound.dtype.itemsize, channels=1))
        except KeyboardInterrupt:
            pass

    def resize(self, event):
        """Briefly enable constrained layout so matplotlib re-fits the axes, then disable it."""
        self.fax.get_figure().set_constrained_layout(True)
        plt.draw()
        plt.pause(0.2)
        self.fax.get_figure().set_constrained_layout(False)

    def onaxis(self, event):
        """Cycle the current click's 'onaxis' flag: unknown (-1) -> on (1), then toggle on/off."""
        key = self._selected_key()
        if key is None:
            return
        row = self.df[key]
        if row['onaxis'] == -1:
            row['onaxis'] = 1
        else:
            row['onaxis'] = int(row['onaxis']) ^ 1
        self.onaxis_b.label.set_text('On-axis'
                                     if row['onaxis'] else 'Off-axis')
        plt.draw()

    def increase_freq(self, event):
        """Raise the spectrogram upper frequency limit of every view by 1 kHz."""
        lim = self.view_data[0].spectrogram.axes.get_ylim()[1] + 1e3
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def decrease_freq(self, event):
        """Lower the spectrogram upper frequency limit of every view by 1 kHz (minimum 1 kHz)."""
        lim = max(self.view_data[0].spectrogram.axes.get_ylim()[1] - 1e3, 1e3)
        for ax_group in self.view_data:
            ax_group.spectrogram.axes.set_ylim(0, lim)
        plt.draw()

    def _update_spectrogram(self, ax_group):
        """Recompute the spectrogram of *ax_group* from its signal plot using ``self.nfft``."""
        click = ax_group.signal.im.get_ydata()
        spec, _, t = plt.mlab.specgram(click, Fs=self.sr, NFFT=self.nfft, noverlap=self.nfft-1, pad_to=2*self.nfft)
        spec = np.flipud(20*np.log10(spec))
        ax_group.spectrogram.im.set_data(spec)
        ax_group.spectrogram.im.set_clim(spec.max() - SPSC, spec.max())
        ax_group.spectrogram.im.set_extent([t[0] - 1/self.sr/2, t[-1] + 1/self.sr/2, 0., self.sr/2])

    def increase_res(self, event):
        """Double the spectrogram FFT size, up to about 10 ms worth of samples."""
        if self.nfft > int(10e-3*self.sr):
            return
        self.nfft *= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        if self.nfft > int(10e-3*self.sr):
            self.spec_b['plus_res'].label.set_text('Can\'t go\nhigher')
        else:
            self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def decrease_res(self, event):
        """Halve the spectrogram FFT size, down to a minimum of 4."""
        if self.nfft < 8:
            return
        self.nfft //= 2
        for ax_group in self.view_data:
            self._update_spectrogram(ax_group)
        self.spec_b['plus_res'].label.set_text(f'{self.nfft*2}\nbins')
        if self.nfft < 8:
            self.spec_b['minus_res'].label.set_text('Can\'t go\nlower')
        else:
            self.spec_b['minus_res'].label.set_text(f'{self.nfft//2}\nbins')
        plt.draw()

    def _set_label(self, ind=None, dic=None):
        """Refresh the axis labels and buttons of view *ind* from annotation row *dic*."""
        if ind is None:
            ind = self.curr
        if dic is None:
            dic = self.df[self.offset[self.curr_ind[ind], 0]]
        ax_group = self.view_data[ind]
        ax_group.signal.axes.set_xlabel(f'Sig man:{dic["ipi_sig"]:.3f}')
        ax_group.correlation.axes.set_xlabel(f'Corr man:{dic["ipi_corr_man"]:.3f} auto:{dic["ipi_corr_auto"]:.3f}')
        ax_group.cepstrum.axes.set_xlabel(f'Ceps man:{dic["ipi_ceps_man"]:.3f} auto:{dic["ipi_ceps_auto"]:.3f}')
        self.ind_b.label.set_text(f'Current individual:\n{dic["ind_number"]}')
        if self.onaxis_b is not None:
            onaxis = dic.get('onaxis', -1)
            # -1 selects '?' (unknown); int() copes with values read back as floats.
            state = -1 if pd.isnull(onaxis) else int(onaxis)
            self.onaxis_b.label.set_text(['Off', 'On', '?'][state] + '-axis')

    def _set_visible(self, ind=None, state=False):
        """Show or hide all cursors of view *ind* (default: current view)."""
        if ind is None:
            ind = self.curr
        self.view_data[ind].set_visible(state)

    def reset_curr(self, event):
        """Clear the annotation of the click shown in the current view."""
        key = self._selected_key()
        if key is None:
            return
        self.df[key] = EMLN.copy()
        self._set_label()
        self._set_visible()
        self.curr_vert[self.curr] = 0
        self.cursor[self.curr].linev[0].set_linestyle('--')
        self.cursor[self.curr].linev[self.cursor[self.curr].num_cur].set_linestyle('--')
        plt.draw()

    def reset(self, song, sr, song_resample):
        """Swap in a new recording and clear all annotations and views."""
        self.p = 0
        self.df = dict()
        self.song = song
        self.song_resample = song_resample
        self._shift()
        sr_update = self.sr != sr
        if sr_update:
            self.sr = sr
        self.change_curr('0')  # reset current view to 0
        self.r_button.set_active(0)
        self.offset = np.zeros((0, 2))
        self.scat.set_offsets(self.offset)
        self.scat.set_color([[0, 0, 0, 1]])
        self.curr_ind = self.num_view * [None]  # Ind of click for each plot
        self.curr_vert = self.num_view * [0]  # Current vertical line of sig/spec for each plot
        for cursor in self.cursor:
            # Back to "place P1" mode, matching curr_vert == 0.
            cursor.linev[0].set_linestyle('--')
            cursor.linev[cursor.num_cur].set_linestyle('--')
        for ax_group in self.view_data:
            ax_group.signal.im.set_ydata(np.zeros(int((10e-3 + self.after_length) * sr)))
            ax_group.correlation.im.set_ydata(np.zeros(int(10e-3 * sr)))
            ax_group.cepstrum.im.set_ydata(np.zeros(int(10e-3 * sr)))
            if sr_update:
                ax_group.signal.im.set_xdata(np.linspace(0, (10e-3 + self.after_length)*1e3,
                                                         int((10e-3 + self.after_length)*sr), False))
                ax_group.correlation.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
                ax_group.cepstrum.im.set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
            ax_group.spectrogram.im.set_clim(2000, 2100)
        for i in range(self.num_view):
            self._set_label(i, EMLN)
            self._set_visible(i)
        self.ind_select = False
        self.ind_b.color = '0.85'
        self.ind_b.hovercolor = '0.95'
        self.ind_b.label.set_text(f'Current individual:\nnan')
        plt.draw()
class AxesWithCursor:
    """An Axes bundled with its main plotted artist (``im``) and its vertical cursor lines."""
    __slots__ = ['axes', 'im', 'cursors']

    def __init__(self, axes: plt.Axes, im=None, cursors=None):
        self.axes = axes
        self.im = im
        if cursors is None:
            cursors = list()
        self.cursors = cursors

    def set_visible(self, state):
        """Show or hide every cursor line; the main artist is left untouched."""
        for cursor_line in self.cursors:
            cursor_line.set_visible(state)

    @property
    def figure(self):
        """Figure of the wrapped axes.

        Presumably exposed so this object can stand in for an Axes where only
        ``.figure`` is needed (e.g. when handed to AxesWidget).
        """
        return self.axes.figure
class AxesGroup:
    """The four linked panels of one view: signal, spectrogram, autocorrelation and cepstrum.

    Indexing and iteration follow that order (0 signal ... 3 cepstrum).
    """
    __slots__ = ['signal', 'spectrogram', 'correlation', 'cepstrum']

    def __init__(self, ax_sig, ax_spec, ax_corr, ax_cep, sr, num_cur=2, after_length=10e-3):
        n_click = int((10e-3 + after_length)*sr)  # samples in the click window (10 ms before P1 + after_length)
        n_lag = int(10e-3 * sr)  # samples in the 10 ms correlation/cepstrum window
        t_end = 10 + after_length*1e3  # click window length in ms

        # Signal panel (x in ms): one dashed P1 cursor followed by num_cur solid cursors.
        sig_line = ax_sig.plot(np.linspace(0, t_end, n_click, False), np.zeros(n_click))[0]
        sig_cursors = (ax_sig.axvline(10, c='k', linestyle='--'),) + \
            tuple(ax_sig.axvline(10, c='k') for _ in range(num_cur))
        self.signal = AxesWithCursor(ax_sig, sig_line, sig_cursors)
        ax_sig.set_xlim(0, t_end)
        ax_sig.set_xlabel('IPI man:None')
        ax_sig.set_ylim(-1, 1)
        ax_sig.set_yticks([])

        # Spectrogram panel (x in s), seeded with tiny noise; same cursor layout as the signal.
        spec_image = ax_spec.specgram(np.random.normal(0, 1e-6, n_click),
                                      Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1]
        spec_cursors = (ax_spec.axvline(0.01, c='k', linestyle='--'),) + \
            tuple(ax_spec.axvline(0.01, c='k') for _ in range(num_cur))
        self.spectrogram = AxesWithCursor(ax_spec, spec_image, spec_cursors)
        self.spectrogram.im.set_clim(2000, 2100)
        ax_spec.set_ylim(0, min(20e3, sr/2))
        ax_spec.set_yticks(ax_spec.get_yticks())  # Needed, otherwise updating ylim doesn't update ticks properly
        ax_spec.set_yticklabels((ax_spec.get_yticks() / 1e3).astype(int))

        # Autocorrelation panel: 0-10 ms lags, a single IPI cursor.
        corr_line = ax_corr.plot(np.linspace(0, 10, n_lag, False), np.zeros(n_lag))[0]
        self.correlation = AxesWithCursor(ax_corr, corr_line, (ax_corr.axvline(10, c='k'),))
        ax_corr.set_xlim(0, 10)
        ax_corr.set_xlabel('IPI man:None auto:None')
        ax_corr.set_ylim(-1, 1)
        ax_corr.set_yticks([])

        # Cepstrum panel: 0-10 ms quefrencies, a single IPI cursor.
        cep_line = ax_cep.plot(np.linspace(0, 10, n_lag, False), np.zeros(n_lag))[0]
        self.cepstrum = AxesWithCursor(ax_cep, cep_line, (ax_cep.axvline(10, c='k'),))
        ax_cep.set_xlim(0, 10)
        ax_cep.set_xlabel('IPI man:None auto:None')
        ax_cep.set_ylim(0, 0.33)
        ax_cep.set_yticks([])

        self.set_visible(False)

    def __getitem__(self, item):
        panels = (self.signal, self.spectrogram, self.correlation, self.cepstrum)
        return panels[item]

    def __len__(self):
        return 4

    def __contains__(self, item):
        """Accept either an AxesWithCursor or a raw matplotlib Axes."""
        if isinstance(item, AxesWithCursor):
            return any(item == panel for panel in self)
        return any(item == panel.axes for panel in self)

    def set_visible(self, state):
        """Show or hide the cursors of all four panels."""
        for panel in self:
            panel.set_visible(state)

    @property
    def signal_ipi(self):
        """Distance (ms) between the P1 cursor and the first solid cursor on the signal panel."""
        p1_x = self.signal.cursors[0].get_xdata()[0]
        p2_x = self.signal.cursors[1].get_xdata()[0]
        return p2_x - p1_x
def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3, after_length=10e-3):
    """Build the annotation figure for *in_path* and wire every widget to a Callback.

    Parameters
    ----------
    in_path : str
        Audio file to annotate.
    channel : int
        Channel index to analyse.
    low, high : float
        Band-pass cut-off frequencies (Hz).
    num_cur : int
        Number of solid IPI cursors per signal/spectrogram panel.
    num_graph : int
        Number of click views shown side by side.
    after_length : float
        Seconds shown after P1 in each click view.

    Returns
    -------
    dict
        Keys 'callback', 'fig', 'gridspec' and 'buttons'. The widgets are kept
        in this dict so they stay referenced.

    Raises
    ------
    ValueError
        If *num_graph* is < 1 or too large for the 20-column grid (each view
        needs at least two columns).
    """
    hs = 1  # first grid column of the click views (column 0 holds the side buttons)
    if num_graph < 1 or (20 - hs) // (2 * num_graph) < 1:
        raise ValueError(f'Cannot lay out {num_graph} click views; use between 1 and {(20 - hs) // 2}.')
    song, sr, song_resample = load_file(in_path, channel, low, high)
    fig = plt.figure('IPI of ' + os.path.basename(in_path), figsize=[16, 9], constrained_layout=True)
    gs = fig.add_gridspec(12, 20)

    # Full-signal plot (top) with its scroll buttons.
    full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1])  # Type annotation needed with plt 3.1
    callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
                        song, song_resample, sr, full_sig, num_graph, after_length)
    full_sig.set_xlim(0, 20)
    lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
    if lim == 0:  # silent start: avoid identical lower and upper y-limits
        lim = 1.0
    full_sig.set_ylim(-lim, lim)
    full_sig.set_yticks([])
    callback.f_cursor = Cursor(full_sig, horizOn=False, useblit=True, c='r')
    cid1 = fig.canvas.mpl_connect('button_release_event', callback.on_clicked)
    b_left_ax = plt.subplot(gs[:2, 0])
    b_right_ax = plt.subplot(gs[:2, -1])
    b_left = Button(b_left_ax, '<|')
    b_right = Button(b_right_ax, '|>')
    b_left.on_clicked(callback.shift_left)
    b_right.on_clicked(callback.shift_right)

    # View selector. The last character of each label is used as the view index elsewhere.
    r_button_ax = plt.subplot(gs[10, 1:-1])
    r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
                              size=128, activecolor=list('rgbkcmyrgbkcmy'[:num_graph]))
    r_button_ax.axis('off')
    r_button.on_clicked(callback.change_curr)
    for i, c in enumerate('rgbkcmyrgbkcmy'[:num_graph]):
        r_button.labels[i].set_c(c)
    callback.r_button = r_button

    # Clear matplotlib's default key bindings so key presses reach Callback.key_pressed.
    # NOTE(review): an older comment here said "but fullscreen", yet 'fullscreen' is in
    # this list and is disabled too — confirm the intent.
    for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
        rc_key = f'keymap.{v}'
        if rc_key in plt.rcParams:  # some keymap entries no longer exist in newer matplotlib
            plt.rcParams[rc_key] = []
    cid2 = fig.canvas.mpl_connect('key_press_event', callback.key_pressed)

    # Click views: each takes two grid columns for signal/correlation and two for spectrogram/cepstrum.
    vfs = 4  # rows per panel
    vs = 2  # first row of the views
    hfs = (20 - hs) // (2 * num_graph)  # columns per panel
    ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                          plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
                          plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
                          sr, num_cur, after_length) for i in range(num_graph)]

    # Bottom button row.
    play_b_ax = plt.subplot(gs[-1, :2])
    play_b = Button(play_b_ax, 'Play\ncurrent segment')
    play_b.on_clicked(callback.play)
    resize_b_ax = plt.subplot(gs[-1, 2:4])
    resize_b = Button(resize_b_ax, 'Resize plot')
    resize_b.on_clicked(callback.resize)
    reset_b_ax = plt.subplot(gs[-1, 4:6])
    reset_b = Button(reset_b_ax, 'Reset current')
    reset_b.label.set_c('r')
    reset_b.on_clicked(callback.reset_curr)
    callback.reset_b = reset_b
    ind_b_ax = plt.subplot(gs[-1, 6:8])
    ind_b = Button(ind_b_ax, 'Current individual:\nnan')
    ind_b.on_clicked(callback.toggle_ind)
    callback.ind_b = ind_b

    # Spectrogram frequency-range and resolution buttons (left column).
    freq_p_b_ax = plt.subplot(gs[2, 0])
    freq_m_b_ax = plt.subplot(gs[3, 0])
    freq_res_p_b_ax = plt.subplot(gs[4, 0])
    freq_res_m_b_ax = plt.subplot(gs[5, 0])
    freq_p_b = Button(freq_p_b_ax, '+\n1kHz')
    freq_p_b.on_clicked(callback.increase_freq)
    freq_m_b = Button(freq_m_b_ax, '-\n1kHz')
    freq_m_b.on_clicked(callback.decrease_freq)
    freq_res_p_b = Button(freq_res_p_b_ax, '256\nbins')
    freq_res_p_b.on_clicked(callback.increase_res)
    freq_res_m_b = Button(freq_res_m_b_ax, '64\nbins')
    freq_res_m_b.on_clicked(callback.decrease_res)
    spec_button = {'plus': freq_p_b, 'minus': freq_m_b, 'plus_res': freq_res_p_b, 'minus_res': freq_res_m_b}
    callback.spec_b = spec_button

    # One mouse-following multi-cursor per view, anchored on that view's P1 cursor.
    callback.cursor = [MyMultiCursor(group, group.signal.cursors[0], num_cur, useblit=True, c='k')
                       for group in ax_group]
    callback.view_data = ax_group
    return {'callback': callback, 'fig': fig, 'gridspec': gs, 'buttons':
            {'b_left': b_left, 'b_right': b_right, 'play_b': play_b, 'resize_b': resize_b, 'r_button': r_button,
             'ind_b': ind_b, 'fs_click': cid1, 'key_pressed': cid2, 'reset_b': reset_b, 'spec_button': spec_button}}
def reset(callback, in_path, channel, low=2e3, high=20e3):
    """Load *in_path* and re-initialise *callback*'s plots and annotations with it."""
    loaded = load_file(in_path, channel, low, high)  # (song, sr, song_resample)
    callback.reset(*loaded)
def main(args):
    """Run the annotation GUI for ``args.input`` and save the annotations to an HDF5 file.

    Parameters
    ----------
    args
        Namespace-like object with attributes input, out, channel, pulse,
        click, after (ms), low, up, erase and resume.

    Returns
    -------
    int
        0 on success, 1 on invalid options or an out-file conflict.
    """
    if args.out == '':
        # splitext only strips a real extension, so dotted directory names are safe.
        outpath = os.path.splitext(args.input)[0] + '.pred.h5'
    else:
        outpath = args.out
    if os.path.isfile(outpath):
        if not (args.erase or args.resume):
            print(f'Out file {outpath} already exist and erase or resume option isn\'t set.')
            return 1
    elif args.resume:
        print(f'Out file {outpath} does not already exist and resume option is set.')
        return 1
    if args.pulse < 1:
        print(f'{args.pulse} is an invalid number of pulses.')
        return 1
    elif args.click < 1:
        print(f'{args.click} is an invalid number of clicks.')
        return 1
    EMLN['onaxis'] = -1  # every new annotation row gets an 'onaxis' field; -1 means unknown
    ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click, args.after*1e-3)
    cb = ref_dict['callback']
    if args.resume:
        df = pd.read_hdf(outpath)
        if 'onaxis' not in df.columns:
            df['onaxis'] = -1
        cb.df = df.to_dict(orient='index')
        times = np.array(list(cb.df.keys()), dtype=float)
        # Round (not truncate) back to the FSSR sample index to pick the marker height.
        heights = cb.song_resample[np.rint(times * FSSR).astype(int)]
        cb.offset = np.stack([times, heights], axis=1)
        cb.scat.set_offsets(cb.offset)
        colors = np.array(len(cb.offset)*[[0.7, 0.2, 0.5, 1]])
        for i, t in enumerate(cb.offset[:, 0]):
            row = cb.df[t]
            # A row is annotated unless every field is NaN; 'onaxis' == -1 also counts as empty.
            if not all(pd.isnull(v) or (k == 'onaxis' and v == -1) for k, v in row.items()):
                colors[i] = [0, 0, 0, 1]
        cb.scat.set_color(colors)
    onaxis_b_ax = plt.subplot(ref_dict['gridspec'][-1, 8:10])
    onaxis_b = Button(onaxis_b_ax, '?-axis')
    onaxis_b.on_clicked(cb.onaxis)
    cb.onaxis_b = onaxis_b
    plt.draw()
    plt.pause(0.2)
    ref_dict['fig'].set_constrained_layout(False)
    plt.show()
    df = pd.DataFrame.from_dict(cb.df, orient='index')
    df.to_hdf(outpath, 'df', format='table')
    return 0
if __name__ == '__main__':
    class ArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that raises ValueError when required arguments are missing.

        This lets the script offer interactive entry instead of exiting. All
        other parse errors keep argparse's default behaviour.
        """
        def error(self, message):
            if message.startswith('the following arguments are required:'):
                raise ValueError(message)
            super().error(message)

    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("input", type=str, help="Input file")
    parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
    parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
    parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be at least 1.")
    parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
    parser.add_argument("--after", type=float, default=10., help="Duration to plot after P1 in ms.")
    parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
    parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                            " the computation will be halted")
    group.add_argument("--resume", action='store_true', help="If out file exist and this option is given,"
                                                             " the previous annotation file will be loaded")
    try:
        args = parser.parse_args()
    except ValueError as e:
        print(f'Error while parsing the command line arguments: {e}')

        def ask(string):
            """Ask a yes/no question until a recognised (English or French) answer is given."""
            y = {'y', 'yes', 'o', 'oui'}
            a = {'y', 'yes', 'o', 'oui', 'n', 'no', 'non'}
            while True:
                ans = input(string + ' [y/n] ').lower()
                if ans in a:
                    return ans in y

        if not ask('Do you want to manually specify them?'):
            sys.exit(2)  # exit code of invalid argparse

        class VirtualArgParse(object):
            """Collect the command-line options interactively (same attributes as the parser)."""
            def __init__(self):
                self.input = input("What is the input file path? ")
                self.out = input("What is the out file path? (Leave empty for default)")
                self.channel = int(input("Which channel do you want to use starting from 0? "))
                self.pulse = int(input("How many pulses do you want to display (>= 1) "))
                self.click = int(input("How many clicks do you want to display (>= 1) "))
                self.after = float(input("What duration (in ms) do you want after P1? "))
                self.low = float(input("What is the lower frequency of the bandpass? "))
                self.up = float(input("What is the upper frequency of the bandpass? "))
                self.erase = ask("Do you want to erase the out file if it already exist?")
                if not self.erase:
                    self.resume = ask("Do you want to resume the out file if it already exist?")
                else:
                    self.resume = False

        args = VirtualArgParse()
    sys.exit(main(args))