import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.special import expit as sigmoid
import scipy.signal as signal
import soundfile
import parselmouth
from matplotlib.transforms import Bbox
from tqdm import tqdm
import os
import argparse

# Praat channel index of each hydrophone in the multichannel flac files.
# NOTE: parselmouth channels start at 1, not 0!
HYDRO_TO_CHAN = {'h1': 3, 'h2': 4, 'h3': 5, 'h4': 8, 'h5': 7}

# Detector timing (in samples at 22050 Hz):
#   fft window length: 1024, fft window hop size: 315
#   pred window length: 103 fft frames, pred window hop size: 9 fft frames
SAMPLE_RATE = 22050
PRED_OFFSET = ((103 - 1) * 315 + 1024) / 2  # half a pred window, centers each pred
PRED_HOP_SIZE = 9 * 315                     # pred hop, in samples
N_PREDS_PER_FILE = 922                      # for OrcaLab we have 922 preds per file


def _preds_to_dataframe(preds):
    """Turn a .preds mapping ('filename': logits array) into a DataFrame.

    One row per recording, indexed and sorted by its timestamp (parsed from the
    filename), with one sigmoid-probability column per hydrophone
    ('prob_h1' .. 'prob_h5') plus the recording's 'filename'.
    """
    rows = []
    for key, value in preds.items():
        # every recording appears once per hydrophone; use the h1 entry as anchor
        if 'h1' in key:
            date = key.split('_')[3] + key.split('_')[4]
            entry = {'filename': key[:-8] + '.flac',
                     'date': pd.Timestamp(pd.to_datetime(date)),
                     'prob_h1': sigmoid(value)}
            for h in ('h2', 'h3', 'h4', 'h5'):
                # sibling key: same recording, other hydrophone
                entry['prob_' + h] = sigmoid(preds[key[:-7] + h + '.flac'])
            rows.append(entry)
    df = pd.DataFrame(rows)
    df.set_index('date', inplace=True)
    df.sort_values('date', inplace=True)
    return df


def _extend_right(binary_probs, extend):
    """Propagate each positive prediction up to `extend` steps to the right.

    Bridges short negative gaps inside a call. Returns `binary_probs`
    unchanged when extend <= 0.

    Fix: the original iterated range(len(binary_probs[extend:])) so the window
    `binary_probs[i-extend:i+1]` used negative starts (empty slices) and landed
    on the wrong output positions; iterate from `extend` instead so output[i]
    reflects binary_probs[i-extend : i+1].
    """
    if extend <= 0:
        return binary_probs
    tail = [np.clip(binary_probs[i - extend:i + 1].sum(), a_min=0, a_max=1)
            for i in range(extend, len(binary_probs))]
    return np.append(binary_probs[:extend], tail)


def _clean_call(call, calltime):
    """Filter outliers from a pitch track.

    Keep each pitch only if it is within 200 Hz of the previously kept one.
    If the kept track stalls for more than ~40% of the call length, assume the
    call actually starts after that gap and re-run the filter from there.

    Returns (cleancall, cleancalltime) as python lists. `call` must be
    non-empty.
    """
    cleancall = [call[0]]
    cleancalltime = [calltime[0]]
    for k, v in enumerate(call[1:]):
        # call[1:][k] is call[k+1], so its timestamp is calltime[k+1]
        if abs(v - cleancall[-1]) < 200:
            cleancall.append(v)
            cleancalltime.append(calltime[k + 1])
        # else if we did not save any pitch for more than 40% of the call:
        elif (cleancalltime[-1] == calltime).argmax() < k - len(call) * .40:
            # start from scratch right after the big gap
            start = (calltime == cleancalltime[-1]).argmax() + 1
            cleancall = [call[start]]
            cleancalltime = [calltime[start]]
            for j, w in enumerate(call[start + 1:]):
                if abs(w - cleancall[-1]) < 200:
                    cleancall.append(w)
                    # fixed off-by-one: call[start+1:][j] is call[start+1+j],
                    # so the matching time is calltime[start+1+j]
                    cleancalltime.append(calltime[start + 1 + j])
            break
    return cleancall, cleancalltime


def _show_call(soundpath, title, chan, call, calltime, cleancall, cleancalltime,
               t1, t2, display, saveimg, call_count):
    """Plot the call's spectrogram with raw and filtered pitch tracks.

    Shows the figure when `display` is truthy and/or saves it as
    <saveimg>/<call_count>.png when `saveimg` is set.
    """
    sig, sr = soundfile.read(soundpath)
    plt.close()
    fig, ax = plt.subplots(1)
    ax.set_title(title)
    ax.specgram(sig.T[chan - 1], window=signal.windows.hamming(1024), Fs=sr,
                NFFT=1024, noverlap=890, cmap='viridis')
    ax.scatter(calltime, call, c='r', s=1, label='orginal pitch')
    ax.scatter(cleancalltime, cleancall, c='b', s=1, label='filtered pitch')
    ax.set_ylim((np.clip(min(call) - 2000, a_min=0, a_max=20000),
                 np.clip(max(call) + 2000, a_min=0, a_max=20000)))
    ax.set_xlim((t1 - 0.5, t2 + 0.5))
    ax.plot([t1, t2], [min(call) - 800, min(call) - 800], c='k',
            label='model\'s prediction')
    ax.legend()
    if display:
        plt.show()
    if saveimg:
        plt.savefig(saveimg + '/' + str(call_count) + '.png')


def main(inputFolder, output, extend, pred_threshold, display, saveimg, specific_time):
    """Extract the pitch of every detected call in each .preds file.

    For every .preds file in `inputFolder`, load the per-hydrophone detector
    probabilities, locate runs of predictions above `pred_threshold`, run
    Praat's autocorrelation pitch tracker (via parselmouth) on the matching
    channel of the recording, clean the pitch track of outliers, and save a
    dict of calls as <output><basename>.npy.

    Parameters mirror the CLI arguments; `specific_time` restricts processing
    to a pandas date-string slice of the recordings.
    """
    # time (s) of the center of each prediction window within a recording
    orcatime = np.arange(N_PREDS_PER_FILE) * PRED_HOP_SIZE / SAMPLE_RATE \
        + PRED_OFFSET / SAMPLE_RATE

    for filename in os.listdir(inputFolder):
        # guard clauses: already done / empty / wrong extension
        if os.path.isfile(output + filename[:-6] + '.npy'):
            print(filename + ' Hey I already treated this file ! find the output at ' + output + filename[:-6] + '.npy')
            continue
        if os.path.getsize(os.path.join(inputFolder, filename)) < 5:
            print(os.path.join(inputFolder, filename) + ' This file is empty duh... ')
            continue
        if not filename.endswith('.preds'):
            print(filename + ' I\'m so sorry but I only work with .preds files !')
            continue

        preds = np.load(os.path.join(inputFolder, filename), allow_pickle=True)
        df = _preds_to_dataframe(preds)
        if specific_time:
            # .loc supports partial date-string slicing on the datetime index;
            # plain df[specific_time] row slicing was removed in pandas 2.0
            df = df.loc[specific_time]

        call_count = 0
        call_missed = 0
        out_dic = {}
        # iterate through the dates
        for index, row in tqdm(df.iterrows(), total=len(df), desc=filename):
            # check throughout the hydros if there is any high prediction
            # before loading the soundfile!
            probs_sum = sum(np.asarray(row['prob_' + hydro] > pred_threshold).sum()
                            for hydro in HYDRO_TO_CHAN)
            if probs_sum <= 1:
                continue
            soundpath = "/nfs/NAS3/SABIOD/SITE/OrcaLab/YEAR_MONTH_DAY/{}/{:0>2}/{:0>2}".format(
                row.name.year, row.name.month, row.name.day) + "/" + row['filename']
            try:
                sound = parselmouth.Sound(soundpath)
            except Exception:  # narrowed from bare except; still best-effort skip
                print('error opening ' + soundpath)
                continue

            for hydro in HYDRO_TO_CHAN:
                binary_probs = np.atleast_1d(row['prob_' + hydro] > pred_threshold)
                if binary_probs.sum() <= 1:
                    continue
                monosound = sound.extract_channel(HYDRO_TO_CHAN[hydro])
                sndpitches0 = monosound.to_pitch_ac(pitch_floor=1000,
                                                    pitch_ceiling=2500,
                                                    voicing_threshold=0.20)
                pitchtime = sndpitches0.xs()
                pitch = sndpitches0.selected_array['frequency']
                orcas = _extend_right(binary_probs, extend)

                # go through all predictions, jumping over each detected call
                predid = 0
                while predid < len(orcas):
                    isOrca = orcas[predid]
                    if isOrca:
                        # measure the run of consecutive positive predictions
                        lencall = 1
                        while (predid + lencall) < len(orcas) and orcas[predid + lencall]:
                            lencall += 1
                        t1 = orcatime[predid]
                        t2 = orcatime[predid + lencall - 1]
                        # voiced pitch samples that fall inside the call
                        callmask = ((pitchtime > t1) & (pitchtime < t2) & (pitch > 0))
                        call = pitch[callmask]
                        calltime = pitchtime[callmask]
                        # if Parselmouth didn't find any pitch within the call
                        # or the call is less than 3 predictions long
                        if len(call) == 0 or lencall < 3:
                            call_missed += 1
                        else:
                            cleancall, cleancalltime = _clean_call(call, calltime)
                            if display or saveimg:
                                _show_call(soundpath, row['filename'] + ' ' + hydro,
                                           HYDRO_TO_CHAN[hydro], call, calltime,
                                           cleancall, cleancalltime, t1, t2,
                                           display, saveimg, call_count)
                            out_dic[call_count] = {'filename': row['filename'],
                                                   'hydro': hydro,
                                                   'call': cleancall,
                                                   'calltime': cleancalltime}
                            call_count += 1
                    predid += lencall if isOrca else 1

        np.save(output + str(filename[:-6]), out_dic)
        print('saved ' + str(len(out_dic)) + ' calls into ' + output + filename[:-6])


def _str2bool(value):
    """argparse-friendly boolean converter.

    The original used type=bool, which makes '--display False' truthy
    (bool('False') is True). This keeps '--display True/False' working as
    intended.
    """
    return str(value).lower() in ('1', 'true', 'yes', 'y')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description="Once your best detector has outlined your favorite animal's calls, this script allows you to extract their pitch (using the Praat technology https://parselmouth.readthedocs.io/en/stable/ , based on the Paul Boersma (1993): \"Accurate short-term analysis of the fundamental frequency and the harmonics-to-noise ratio of a sampled sound.\" Proceedings of the Institute of Phonetic Sciences 17: 97–110. University of Amsterdam.) The dictionnary storing the pitch of the given calls will be saved in a .npy in the folder of your choice")

    parser.add_argument("file", type=str, help="Folder of .preds (dictionnaries 'filename': model's prediction array)")
    parser.add_argument("--output", type=str, default='./', help="Output folder, ")
    parser.add_argument("--extend", type=int, default=0, help="Number of preds to extend the calls on the right side")
    parser.add_argument("--pred_threshold", type=float, default=.9, help="Threshold for considering an prediction true (between 0 and 1)")
    parser.add_argument("--display", type=_str2bool, default=False, help="Boolean for displaying calls and their pitch")
    parser.add_argument("--saveimg", type=str, default=None, help="Path for saving vocalisation images")
    parser.add_argument("--specific_time", type=str, default=None, help="Date for running at specific time using pandas date indexing, format : YYYY-MM-DD HH:MM:SS")
    args = parser.parse_args()

    main(args.file, args.output, args.extend, args.pred_threshold, args.display, args.saveimg, args.specific_time)