Skip to content
Snippets Groups Projects
Select Git revision
  • a45b2b750daeab5e21956e9b9f36116ca73e5a0a
  • master default
2 results

benchmark_classification.py

Blame
  • get_time_freq_detection.py 1.85 KiB
    import pandas as pd
    import os
    import ipdb
    from tqdm import tqdm
    import argparse
    from datetime import date
    
    def arg_directory(path):
        """Argparse ``type=`` callback that only accepts existing directories.

        Returns ``path`` unchanged when ``os.path.isdir(path)`` is true;
        otherwise raises ``argparse.ArgumentTypeError`` so argparse reports it
        as a usage error for the offending option.
        """
        if not os.path.isdir(path):
            raise argparse.ArgumentTypeError(f'`{path}` is not a valid path')
        return path
    
    # Command-line interface. Every option is required: the script cannot
    # compute positions without the label folder, the output folder, the
    # spectrogram duration and the sampling rate.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Convert YOLO detection .txt files into a CSV of time and frequency positions.')
    parser.add_argument('-p', '--path_to_data', type=arg_directory, help='Path of the folder that contain the .txt files', required=True)
    parser.add_argument('-d', '--directory', type=arg_directory, help='Directory to wich the dataframe will be stored', required=True)
    # float (not int) so non-integer spectrogram durations are accepted; the
    # value is only ever multiplied with a float column, so integer inputs
    # produce exactly the same results as before.
    parser.add_argument('-t', '--duration', type=float, help='Duration of the spectrogram', required=True)
    # Required: the frequency columns are computed from SR/2, so leaving it
    # out (None) could only end in a TypeError later on.
    parser.add_argument('-s', '--SR', type=int, help='Sampling Rate of the spectrogram', required=True)
    args = parser.parse_args()
    
    # Folder holding the YOLO label files to read.
    annots = args.path_to_data

    # Output base name stamped with the current day and month,
    # e.g. 'YOLO_detection_7_3' on 7 March.
    today = date.today()
    out_file = f'YOLO_detection_{today.day}_{today.month}'

    # Folder where the resulting CSV is written.
    outdir = args.directory
    
    # Read the label files (space-separated columns espece, x, y, w, h) into
    # one frame whose outer index level is the file name. Only *.txt files are
    # read, and in sorted order, so the row order depends neither on
    # os.listdir() nor on how the pandas version orders dict keys in concat.
    txt_files = sorted(f for f in os.listdir(annots) if f.lower().endswith('.txt'))
    if not txt_files:
        # pd.concat({}) would otherwise fail with "No objects to concatenate".
        raise SystemExit(f'No .txt files found in `{annots}`')

    df = pd.concat({f: pd.read_csv(os.path.join(annots, f), sep=' ',
                                   names=['espece', 'x', 'y', 'w', 'h'],
                                   # Never turn leading fields into the index:
                                   # rows with more fields than `names` (e.g. a
                                   # trailing confidence column, if present)
                                   # would otherwise shift every column silently.
                                   index_col=False)
                    for f in tqdm(txt_files)}, names=['file'])

    # Move the file-name level into a 'file' column and drop the per-file
    # row numbers, leaving a plain 0..n-1 index.
    df = df.reset_index(level=0).reset_index(drop=True)
    # Last '_'-separated token of the name, before the first '.',
    # e.g. 'rec_12.txt' -> '12'.
    df['idx'] = df.file.str.split('_').str[-1].str.split('.').str[0]
    # Replace the extension with '.wav'. `n` is passed by keyword because
    # pandas >= 2.0 rejects it as a positional argument.
    df.file = df.file.str.rsplit('.', n=1).str[0] + '.wav'
    
    DUREE_SPECTRO = args.duration
    SR = args.SR
    if SR is None:
        # SR/2 below would otherwise fail with an opaque TypeError on None.
        raise SystemExit('The sampling rate (-s/--SR) is required')

    # Put the class names here, ordered so that names[k] is the label of
    # class id k in the 'espece' column.
    names = []

    class_ids = df.espece.astype(int)
    unknown = class_ids[(class_ids < 0) | (class_ids >= len(names))]
    if not unknown.empty:
        # Fail with an actionable message instead of a bare IndexError.
        raise SystemExit(f'Class ids {sorted(unknown.unique().tolist())} have no entry in `names` '
                         f'({len(names)} names defined); fill in the list of classes.')
    # One pass over the ids instead of a row-by-row .loc assignment loop.
    df['annot'] = [names[i] for i in class_ids]

    print('Calculating the positions','\n')
    # Centre time: x scaled by the spectrogram duration, plus the offset parsed
    # from the file name (x presumably normalized to [0, 1] across the
    # spectrogram width, as in YOLO labels — confirm).
    df['midl'] = (df.x*DUREE_SPECTRO)+(df.idx.astype(int))
    # Frequencies span 0..SR/2 over the image height: y = 0 maps to SR/2 and
    # y = 1 to 0, and the box height h is scaled the same way.
    df['freq_center'] = (1-df.y)*(SR/2)
    df['freq_min'] = df.freq_center - (df.h*(SR/2))/2
    df['freq_max'] = df.freq_center + (df.h*(SR/2))/2
    
    # Write the table (without the pandas index) and report where it went.
    out_path = os.path.join(outdir, out_file + '.csv')
    df.to_csv(out_path, index=False)
    print('saved as ', out_path)