diff --git a/global_ipi.py b/global_ipi.py new file mode 100644 index 0000000000000000000000000000000000000000..67182aa1cacf0d8a5c574b723ab5df534af9083f --- /dev/null +++ b/global_ipi.py @@ -0,0 +1,246 @@ +import argparse +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as sg +from scipy. stats import gaussian_kde +from fractions import Fraction +import os +import sys +from pydub import AudioSegment +import soundfile as sf +from matplotlib.widgets import Button, Cursor, RadioButtons, AxesWidget + + +def norm(x, axis=None, eps=1e-18): + return (x-x.mean(axis, keepdims=axis is not None))/(x.std(axis, keepdims=axis is not None) + eps) + + +def norm_abs(x, axis=None, eps=1e-18): + return (x-x.mean(axis, keepdims=axis is not None))/(np.abs(x).max(axis, keepdims=axis is not None) + eps) + +def read(file_path, always_2d=True): + try: + return sf.read(file_path, always_2d=always_2d) + except Exception as e: + return load_anysound(file_path) + + +def load_anysound(file_path): + tmp = AudioSegment.from_file(file_path) + return np.array(tmp.get_array_of_samples()).reshape(-1, tmp.channels), tmp.frame_rate + + +def load_file(in_path, channel, low, high): + print(f'Loading and processing {in_path}') + song, sr = read(in_path, always_2d=True) + song = song[:, channel] + sos = sg.butter(3, [low, high], 'bandpass', fs=sr, output='sos') + song = sg.sosfiltfilt(sos, song) + print('Done processing') + return song, sr + + +class MyRadioButtons(RadioButtons): + + def __init__(self, ax, labels, active=0, activecolor='blue', size=49, + orientation="vertical", **kwargs): + """ + Add radio buttons to an `~.axes.Axes`. + Parameters + ---------- + ax : `~matplotlib.axes.Axes` + The axes to add the buttons to. + labels : list of str + The button labels. + active : int + The index of the initially selected button. + activecolor : color + The color of the selected button. + size : float + Size of the radio buttons + orientation : str + The orientation of the buttons: 'vertical' (default), or 'horizontal'. + Further parameters are passed on to `Legend`. + """ + AxesWidget.__init__(self, ax) + self._activecolor = activecolor + axcolor = ax.get_facecolor() + self.value_selected = None + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_navigate(False) + circles = [] + for i, label in enumerate(labels): + if i == active: + self.value_selected = label + facecolor = self.activecolor + else: + facecolor = axcolor + p = ax.scatter([],[], s=size, marker="o", edgecolor='black', + facecolor=facecolor) + circles.append(p) + if orientation == "horizontal": + kwargs.update(ncol=len(labels), mode="expand") + kwargs.setdefault("frameon", False) + self.box = ax.legend(circles, labels, loc="center", **kwargs) + self.labels = self.box.texts + self.circles = self.box.legendHandles + for c in self.circles: + c.set_picker(5) + self.cnt = 0 + self.observers = {} + self.connect_event('pick_event', self._clicked) + + def _clicked(self, event): + if self.ignore(event) or event.mouseevent.button != 1 or event.mouseevent.inaxes != self.ax: + return + if event.artist in self.circles: + self.set_active(self.circles.index(event.artist)) + + @property + def activecolor(self): + if hasattr(self._activecolor, '__getitem__') and not isinstance(self._activecolor, str): + return self._activecolor[int(self.value_selected[-1])] + else: + return self._activecolor + +def main(args): + song, sr = load_file(args.input, args.channel, args.low, args.high) + mean = np.ones(int(6e-3*sr)) + mean /= len(mean) + pos, prev = sg.find_peaks(np.log10(np.correlate(song ** 2, mean, 'same'))[int(sr*17e-3):int(-sr*17e-3)], distance=int(sr * 20e-3), wlen=int(sr * 20e-3), + prominence=0) + pos += int(sr*17e-3) + prev = prev['prominences'] + kde = gaussian_kde(prev, 0.02) + pointers = {'mask': prev > 0.65} + all_clicks = song[pos[:, None] + np.arange(int(-sr*17e-3), int(sr*17e-3))] + fine_pos = np.argmax(all_clicks[:,int(sr*10e-3):int(sr*24e-3)], -1) + int(sr*10e-3) + fine_pos_glob = fine_pos + pos + int(-sr*17e-3) + all_clicks = norm_abs(all_clicks[np.arange(len(all_clicks))[:,None], fine_pos[:,None]+np.arange(int(-sr*10e-3), int(sr*10e-3))], -1) + all_autocorr = norm_abs(np.vstack([np.correlate(c, c, 'same') for c in all_clicks]), -1)[:, int(sr * 10e-3):] + all_cepstrum = np.fft.ifftshift(np.abs(np.fft.irfft(np.log10( + np.abs(np.fft.rfft(all_clicks, axis=-1))+1e-18), axis=-1)), axes=-1)[:, int(sr*10e-3):] + all_cepstrum /= all_cepstrum.max(-1, keepdims=True) + 1e-18 + + fig = plt.figure('IPI of ' + args.input.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True) + gs = fig.add_gridspec(12, 20) + full_sig = plt.subplot(gs[:2, :]) + full_sig.plot(np.arange(len(song))/sr, song, c='k') + scat = full_sig.scatter(fine_pos_glob[pointers['mask']]/sr, song[fine_pos_glob[pointers['mask']]], 50, marker='x', c='r') + full_sig.set_xlim(0, len(song)/sr) + pointers['clicks'] = all_clicks[pointers['mask']] + pointers['autocorr'] = all_autocorr[pointers['mask']] + pointers['cepstrum'] = all_cepstrum[pointers['mask']] + raster = plt.subplot(gs[2:-1, :-6]) + im = raster.imshow(pointers['clicks'].T, aspect='auto', origin='lower', cmap='jet', extent=[0, len(pointers['clicks']), -10, 10]) + mean_raster = plt.subplot(gs[2:-1, -6:-4]) + line, = mean_raster.plot(pointers['clicks'].sum(0), np.arange(int(-sr*10e-3), int(sr*10e-3))/sr*1e3) + mean_raster.set_ylim(-10, 10) + r_button_ax = plt.subplot(gs[-1:, :5]) + + r_button = MyRadioButtons(r_button_ax, ['signal', 'autocorr', 'cepstrum'], orientation='horizontal', + size=666) + r_button_ax.axis('off') + + + spec_ax = plt.subplot(gs[2:-1, -4:]) + spec = spec_ax.specgram(pointers['clicks'].mean(0), Fs=sr, NFFT=128, noverlap=127, cmap='jet')[-1] + spec.set_clim(spec.get_clim()[1] - 80, spec.get_clim()[1]) + + def change_graph(label): + if label == 'signal': + im.set_data(pointers['clicks'].T) + im.set_clim(pointers['clicks'].min(), pointers['clicks'].max()) + im.set_extent([0, len(pointers['clicks']), -10, 10]) + raster.set_xlim(0, len(pointers['clicks'])) + line.set_xdata(pointers['clicks'].sum(0)) + line.set_ydata(np.arange(int(-sr*10e-3), int(sr*10e-3))/sr*1e3) + mean_raster.set_ylim(-10, 10) + elif label == 'autocorr': + im.set_data(pointers['autocorr'].T) + im.set_clim(pointers['autocorr'].min(), pointers['autocorr'].max()) + im.set_extent([0, len(pointers['clicks']), 0, 10]) + line.set_xdata(pointers['autocorr'].sum(0)) + line.set_ydata(np.arange(0, int(sr*10e-3))/sr*1e3) + mean_raster.set_ylim(0, 10) + elif label == 'cepstrum': + im.set_data(pointers['cepstrum'].T) + im.set_clim(0, 2*np.nanstd(pointers['cepstrum'])) + im.set_extent([0, len(pointers['clicks']), 0, 10]) + line.set_xdata(np.nansum(pointers['cepstrum'], 0)) + line.set_ydata(np.arange(0, int(sr*10e-3))/sr*1e3) + mean_raster.set_ylim(0, 10) + plt.draw() + + def resize(event): + fig.set_constrained_layout(True) + plt.draw() + plt.pause(0.2) + fig.set_constrained_layout(False) + + r_button.on_clicked(change_graph) + + resize_b_ax = plt.subplot(gs[-1:, 5:7]) + resize_b = Button(resize_b_ax, 'Resize plot') + resize_b.on_clicked(resize) + + hist_ax = plt.subplot(gs[-1:, 7:-4]) + hist_ax.plot(np.linspace(0, prev.max()+0.5, 1024), kde(np.linspace(0, prev.max()+0.5, 1024))) + hist_ax.set_xlim(0, prev.max()+0.5) + hist_ax.set_ylim(0, kde(np.linspace(0, prev.max()+0.5, 1024)).std()*3.3) + vlined = hist_ax.axvline(0.65, c='k') + vlineu = hist_ax.axvline(0.65, c='k') + down = [0.65] + + def onclick_down(event): + if event.inaxes != hist_ax: + return + down[0] = event.xdata + + def onclick_up(event): + if event.inaxes != hist_ax: + return + if abs(event.xdata - down[0]) < 0.025: + pointers['mask'] = prev > event.xdata + vlined.set_xdata(event.xdata) + vlineu.set_xdata(event.xdata) + elif event.xdata > down[0]: + pointers['mask'] = (prev > down[0]) & (prev < event.xdata) + vlined.set_xdata(down[0]) + vlineu.set_xdata(event.xdata) + else: + pointers['mask'] = (prev < down[0]) & (prev > event.xdata) + vlineu.set_xdata(down[0]) + vlined.set_xdata(event.xdata) + pointers['clicks'] = all_clicks[pointers['mask']] + pointers['autocorr'] = all_autocorr[pointers['mask']] + pointers['cepstrum'] = all_cepstrum[pointers['mask']] + scat.set_offsets(np.vstack((fine_pos_glob[pointers['mask']]/sr, song[fine_pos_glob[pointers['mask']]])).T) + spectro = np.flipud(20 * np.log10( + plt.mlab.specgram(pointers['clicks'].mean(0), Fs=sr, NFFT=128, noverlap=127)[0])) + spec.set_data(spectro) + spec.set_clim(spectro.max() - 80, spectro.max()) + change_graph(r_button.value_selected) + + cur = Cursor(hist_ax, horizOn=False, vertOn=True, useblit=True, c='r') + cur.connect_event('button_press_event', onclick_down) + cur.connect_event('button_release_event', onclick_up) + + plt.draw() + plt.pause(0.2) + fig.set_constrained_layout(False) + plt.show() + return 0 + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("input", type=str, help="Input file") + parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0") + parser.add_argument("--low", type=int, default=2_000, help="Low frequency cut of the bandpass") + parser.add_argument("--high", type=int, default=20_000, help="High frequency cut of the bandpass") + + args = parser.parse_args() + + sys.exit(main(args)) \ No newline at end of file