from torchvision.utils import make_grid
import torch
import pandas as pd
import utils as u, models
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import argparse

parser = argparse.ArgumentParser(
    description="""This script trains an auto-encoder to compress and decompress vocalisation spectrograms.
    Reconstruction quality can be monitored via tensorboard ($ tensorboard --logdir=runs/ --bind_all)""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. A column 'filename' (path of the soundfile) and a column 'pos' (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Duration (in seconds) of the signal extracts surrounding detections to be encoded")
parser.add_argument("-bottleneck", type=int, default=16, help="Size of the auto-encoder's bottleneck")
args = parser.parse_args()

# init AE architecture
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a positive multiple of 32"
assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume's size (nMel//32 * 4)"
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)

# training / optimisation setup
lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: .99**epoch)
vgg16 = models.vgg16.eval().to(device)  # frozen VGG16 used as feature extractor for the perceptual loss
loss_fun = torch.nn.MSELoss()

# data loader
df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur),
                                     batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)

modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.weights'
step, writer = 0, SummaryWriter('runs/'+modelname)
print(f'Go for model {modelname} with {len(df)} vocalizations')

# training loop
for epoch in range(100_000//len(loader)):
    for x, name in tqdm(loader, desc=str(epoch), leave=False):
        optimizer.zero_grad()
        label = frontend(x.to(device))  # target spectrogram
        x = encoder(label)              # bottleneck embedding
        pred = decoder(x)               # reconstructed spectrogram
        # perceptual loss: MSE between VGG16 features of the reconstruction and of the target
        vgg_pred = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
        vgg_label = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
        score = loss_fun(vgg_pred, vgg_label)
        score.backward()
        optimizer.step()
        writer.add_scalar('loss', score.item(), step)

        if step % 50 == 0:
            # log min-max normalised target and reconstructed spectrograms, plus the embeddings
            images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]]
            grid = make_grid(images)
            writer.add_image('target', grid, step)
            writer.add_embedding(x.detach(), global_step=step, label_img=label)
            images = [(e-e.min())/(e.max()-e.min()) for e in pred[:8]]
            grid = make_grid(images)
            writer.add_image('reconstruct', grid, step)

        step += 1
        if step % 500 == 0:
            scheduler.step()
            torch.save(model.state_dict(), modelname)
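
# --- Usage sketch (assumption, not part of the original training procedure) ---
# Once training has finished, the saved weights can be reloaded and the encoder
# reused to compress each detection into a `bottleneck`-dimensional vector.
# This is a minimal sketch reusing the objects defined above; the output file
# name 'embeddings.npy' is hypothetical.
import numpy as np
model.load_state_dict(torch.load(modelname, map_location=device))
model.eval()
embeddings = []
with torch.no_grad():
    for x, name in tqdm(loader, desc='encoding'):
        embeddings.append(encoder(frontend(x.to(device))).cpu().numpy())
np.save('embeddings.npy', np.concatenate(embeddings))  # hypothetical output path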