Skip to content
Snippets Groups Projects
Select Git revision
  • bf6f37d6f0a892ff44de3fc9bc7f49e36f455e7d
  • develop default protected
  • master
3 results

ci-settings.xml

Blame
  • train_AE.py 3.84 KiB
    from torchvision.utils import make_grid
    import torch
    import pandas as pd
    import utils as u, models
    from tqdm import tqdm
    from torch.utils.tensorboard import SummaryWriter
    import argparse
    
    # Command-line interface: one required positional argument (the detections
    # .csv) plus spectrogram / auto-encoder hyper-parameters.
    # Fixes vs. previous revision: "depcompress" -> "decompress" and the garbled
    # column name \'pos\{ -> 'pos' in the help texts (user-facing strings only).
    parser = argparse.ArgumentParser(description="""This script trains an auto-encoder to compress and decompress vocalisation spectrograms.
                                     Reconstruction quality can be monitored via tensorboard ($tensorboard --logdir=runs/ --bind_all)""",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. A column 'filename' (path of the soundfile) and a column 'pos' (center of the detection in seconds) are needed")
    parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
    parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
    parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
    parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
    parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
    parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
    # Parsed once at import time; the rest of the script reads `args.*`.
    args = parser.parse_args()
    
    # init AE architecture
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Validate hyper-parameters with explicit raises rather than `assert`:
    # asserts are stripped under `python -O`, silently skipping the checks.
    if not (args.nMel % 32 == 0 and args.nMel > 0):
        raise ValueError("nMel argument should be a multiple of 32")
    if args.bottleneck % (args.nMel//32 * 4) != 0:
        raise ValueError("Bottleneck size must be a multiple of the last volume's size (nMel//32 * 4)")
    # Pipeline: waveform -> mel spectrogram (frontend) -> latent (encoder)
    # -> reconstructed spectrogram (decoder). The encoder's channel count is
    # derived so that channels * (nMel//32 * 4) == bottleneck.
    frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
    encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
    decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
    model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
    
    # --- optimisation setup ------------------------------------------------
    lr = 0.003                                             # initial learning rate
    wdL2 = 0.0                                             # L2 weight decay (disabled)
    batch_size = 64 if torch.cuda.is_available() else 16   # smaller batches on CPU
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=lr,
                                  weight_decay=wdL2,
                                  betas=(0.8, 0.999))
    # Exponential decay: the lr is multiplied by 0.99 each time scheduler.step()
    # is called (the training loop calls it every 500 optimisation steps).
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.99 ** epoch)
    vgg16 = models.vgg16.eval().to(device)                 # fixed network for the perceptual loss
    loss_fun = torch.nn.MSELoss()
    
    # --- data ---------------------------------------------------------------
    df = pd.read_csv(args.detections)
    dataset = u.Dataset(df, args.audio_folder, args.SR, args.sampleDur)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=8,
                                         collate_fn=u.collate_fn)

    # Weights file name derived from the detections file; `[:-4]` drops the
    # last four characters (assumes a ".csv" suffix, as documented in --help).
    modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.weights'
    step = 0
    writer = SummaryWriter(f'runs/{modelname}')
    print(f'Go for model {modelname} with {len(df)} vocalizations')
    # Main training loop: the epoch count is chosen so the total number of
    # optimisation steps is ~100,000 regardless of dataset size.
    # NOTE(review): if len(loader) > 100_000 this range is empty and no
    # training happens — confirm datasets are always smaller than that.
    for epoch in range(100_000//len(loader)):
        for x, name in tqdm(loader, desc=str(epoch), leave=False):
            optimizer.zero_grad()
            # `label` is the target spectrogram produced by the frontend;
            # `x` is then REUSED: first the raw waveform batch, then the latent code.
            label = frontend(x.to(device))
            x = encoder(label)
            pred = decoder(x)
            # Perceptual loss: compare vgg16 activations of prediction vs target.
            # The single spectrogram channel is expanded to 3 channels for vgg16.
            # NOTE(review): gradients also flow through vgg16 for `label`;
            # vgg16 was set to eval() but its params are presumably frozen
            # elsewhere — confirm in models.py.
            vgg_pred = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
            vgg_label = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))

            # MSE between the two vgg16 outputs drives the auto-encoder.
            score = loss_fun(vgg_pred, vgg_label)
            score.backward()
            optimizer.step()
            writer.add_scalar('loss', score.item(), step)

            # Every 50 steps (including step 0): log target / reconstruction
            # image grids and the latent embeddings to tensorboard.
            if step%50==0 :
                # Min-max normalise each of the first 8 spectrograms to [0,1] for display.
                images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]]
                grid = make_grid(images)
                writer.add_image('target', grid, step)
                # `x` here is the latent code (detached so no graph is kept).
                writer.add_embedding(x.detach(), global_step=step, label_img=label)
                images = [(e-e.min())/(e.max()-e.min()) for e in pred[:8]]
                grid = make_grid(images)
                writer.add_image('reconstruct', grid, step)

            step += 1
            # Learning-rate decay every 500 optimisation steps (not per epoch):
            # scheduler multiplies the lr by 0.99 at each call.
            if step % 500 == 0:
                scheduler.step()
        # Checkpoint (overwrite) the full model weights at the end of every epoch.
        torch.save(model.state_dict(), modelname)