Skip to content
Snippets Groups Projects
Commit d6a68d08 authored by maxence's avatar maxence
Browse files

Add number of graphs argument

parent cd277078
No related branches found
No related tags found
No related merge requests found
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import scipy.signal as sg
import soundfile as sf
import os
......@@ -243,7 +244,7 @@ class MyMultiCursor(AxesWidget):
class Callback(object):
def __init__(self, line, song, song_resample, sr, full_ax):
def __init__(self, line, song, song_resample, sr, full_ax, num_view):
self.p = 0
self.df = dict()
self.line = line
......@@ -251,13 +252,13 @@ class Callback(object):
self.song_resample = song_resample
self.sr = sr
self.fax = full_ax
self.num_view = 3
self.num_view = num_view
self.view_data: [AxesGroup, None] = None
self.curr = 0 # current view selected
self.scat = self.fax.scatter([], [], marker='x', c=[[0.7, 0.2, 0.5, 1]])
self.offset = np.zeros((0, 2))
self.curr_ind: List[Union[int, None]] = 3*[None] # Indices of click for each plot
self.curr_vert = 3*[0] # Current vertical line of sig/spec for each plot
self.curr_ind: List[Union[int, None]] = num_view*[None] # Indices of click for each plot
self.curr_vert = num_view*[0] # Current vertical line of sig/spec for each plot
self.cursor = None
self.f_cursor = None
self.reset_b = None
......@@ -317,14 +318,14 @@ class Callback(object):
ax_group[i][1].set_visible(False)
else:
self.curr_ind[self.curr] = np.argmax(mpos / FSSR == self.offset[:, 0])
c = [0, 0, 0, 1]
c = np.array([0, 0, 0, 0.])
k = 0
for i, v in enumerate(self.curr_ind):
if v == self.curr_ind[self.curr]:
c[i] = 1
c += to_rgba('rgbkcmyrgbkcmy'[i])
k += 1
for i in range(self.num_view):
c[i] /= k
c /= k
self.scat._facecolors[self.curr_ind[self.curr]] = c
self.scat.set_color(self.scat._facecolors)
row = self.df[mpos / FSSR]
......@@ -412,8 +413,8 @@ class Callback(object):
def change_curr(self, label):
self.curr = int(label[-1])
self.f_cursor.linev.set_color('rgb'[self.curr])
self.reset_b.label.set_c('rgb'[self.curr])
self.f_cursor.linev.set_color('rgbkcmyrgbkcmy'[self.curr])
self.reset_b.label.set_c('rgbkcmyrgbkcmy'[self.curr])
plt.draw()
def toggle_ind(self, event):
......@@ -562,7 +563,7 @@ class Callback(object):
ax_group[2][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
ax_group[3][0].set_xdata(np.linspace(0, 10, int(10e-3*sr), False))
ax_group[1][0].set_clim(2000, 2100)
for i in range(3):
for i in range(self.num_view):
self._set_label(i, EMLN)
self._set_visible(i)
self.ind_select = False
......@@ -651,14 +652,14 @@ class AxesGroup:
return self.signal.cursors[1].get_xdata()[0] - self.signal.cursors[0].get_xdata()[0]
def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
def init(in_path, channel, low=2e3, high=20e3, num_cur=1, num_graph=3):
song, sr, song_resample = load_file(in_path, channel, low, high)
fig = plt.figure('IPI of ' + in_path.rsplit('/', 1)[-1], figsize=[16, 9], constrained_layout=True)
gs = fig.add_gridspec(12, 20)
full_sig: plt.Axes = plt.subplot(gs[:2, 1:-1]) # Type annotation needed with plt 3.1
callback = Callback(full_sig.plot(np.linspace(0, 20, FSSR * 20, False), song_resample[:FSSR * 20])[0],
song, song_resample, sr, full_sig)
song, song_resample, sr, full_sig, num_graph)
callback.fax = full_sig
full_sig.set_xlim(0, 20)
lim = np.abs(song_resample[:FSSR * 20]).max() * 1.2
......@@ -674,11 +675,11 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
b_left.on_clicked(callback.shift_left)
b_right.on_clicked(callback.shift_right)
r_button_ax = plt.subplot(gs[10, 1:-1])
r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(3)], orientation='horizontal',
size=128, activecolor=list('rgb')) # !Last char of labels is use as index in rest of code
r_button = MyRadioButtons(r_button_ax, [f'Change {i}' for i in range(num_graph)], orientation='horizontal',
size=128, activecolor=list('rgbkcmyrgbkcmy'[:num_graph])) # !Last char of labels is use as index in rest of code
r_button_ax.axis('off')
r_button.on_clicked(callback.change_curr)
for i, c in enumerate('rgb'):
for i, c in enumerate('rgbkcmyrgbkcmy'[:num_graph]):
r_button.labels[i].set_c(c)
callback.r_button = r_button
for v in "fullscreen,home,back,forward,pan,zoom,save,quit,grid,yscale,xscale,all_axes".split(','):
......@@ -689,20 +690,13 @@ def init(in_path, channel, low=2e3, high=20e3, num_cur=1):
vfs = 4
vs = 2
hfs = 3
hs = 1
ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs:hs + hfs]),
plt.subplot(gs[vs:vs + vfs, hs + hfs:hs + 2 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs:hs + hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + hfs:hs + 2 * hfs]), sr, num_cur),
AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2 * hfs:hs + 3 * hfs]),
plt.subplot(gs[vs:vs + vfs, hs + 3 * hfs:hs + 4 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2 * hfs:hs + 3 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 3 * hfs:hs + 4 * hfs]), sr, num_cur),
AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 4 * hfs:hs + 5 * hfs]),
plt.subplot(gs[vs:vs + vfs, hs + 5 * hfs:hs + 6 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 4 * hfs:hs + 5 * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 5 * hfs:hs + 6 * hfs]), sr, num_cur)]
hfs = (20-hs)//(2*num_graph)
ax_group = [AxesGroup(plt.subplot(gs[vs:vs + vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
plt.subplot(gs[vs:vs + vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + 2*i * hfs:hs + (2*i+1) * hfs]),
plt.subplot(gs[vs + vfs:vs + 2 * vfs, hs + (2*i+1) * hfs:hs + (2*i+2) * hfs]), sr, num_cur)
for i in range(num_graph)]
play_b_ax = plt.subplot(gs[-1, :2])
play_b = Button(play_b_ax, 'Play\ncurrent segment')
......@@ -766,12 +760,15 @@ def main(args):
elif args.resume:
print(f'Out file {outpath} does not already exist and resume option is set.')
return 1
if args.num < 1:
print(f'{args.num} is an invalid number of pulses.')
if args.pulse < 1:
print(f'{args.pulse} is an invalid number of pulses.')
return 1
elif args.click < 1:
print(f'{args.click} is an invalid number of clicks.')
return 1
EMLN['onaxis'] = -1
ref_dict = init(args.input, args.channel, args.low, args.up, args.num)
ref_dict = init(args.input, args.channel, args.low, args.up, args.pulse, args.click)
if args.resume:
df = pd.read_hdf(outpath)
if 'onaxis' not in df.columns:
......@@ -811,7 +808,8 @@ if __name__ == '__main__':
parser.add_argument("input", type=str, help="Input file")
parser.add_argument("--out", type=str, default='', help="Output file. Default to the input_path'.pred.h5'")
parser.add_argument("--channel", type=int, default=0, help="Sound channel to be analysed. Indices start from 0.")
parser.add_argument("--num", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
parser.add_argument("--pulse", type=int, default=1, help="Number of IPI cursor to display. Must be greater than 1.")
parser.add_argument("--click", type=int, default=3, help="Number of click to plot simultaneously.")
parser.add_argument('--low', type=float, default=2e3, help='Lower frequency of the bandpass filter')
parser.add_argument('--up', type=float, default=20e3, help='Upper frequency of the bandpass filter')
group = parser.add_mutually_exclusive_group()
......@@ -840,7 +838,8 @@ if __name__ == '__main__':
self.input = input("What is the input file path? ")
self.out = input("What is the out file path? (Leave empty for default)")
self.channel = int(input("Which channel do you want to use starting from 0? "))
self.num = int(input("How many pulse do you want to display (> 1) "))
self.pulse = int(input("How many pulses do you want to display (> 1) "))
self.click = int(input("How many clicks do you want to display (> 1) "))
self.low = float(input("What is the lower frequency of the bandpass? "))
self.up = float(input("What is the upper frequency of the bandpass? "))
self.erase = ask("Do you want to erase the out file if it already exist?")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment