# train_AE.py
import argparse
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm

import models
import utils as u
parser = argparse.ArgumentParser(
    description="""This script trains an auto-encoder to compress and decompress vocalisation spectrograms.
Reconstruction quality can be monitored via tensorboard ($ tensorboard --logdir=runs/ --bind_all)""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. A column 'filename' (path of the soundfile) and a column 'pos' (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
args = parser.parse_args()
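# Example invocation (hypothetical file names, adjust to your data):
#   python train_AE.py detections.csv -audio_folder ./audio/ -NFFT 1024 -nMel 128 -sampleDur 1 -bottleneck 16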
# init AE architecture
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32"
assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume's size (nMel//32 * 4)"
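# e.g. with the default nMel=128, the last volume's size is 128//32 * 4 = 16,
# so valid bottleneck sizes are 16, 32, 48, ...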
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
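# Optional shape sanity check (a sketch; assumes the frontend takes raw waveforms of
# shape (batch, SR*sampleDur) and that the decoder mirrors the frontend's output shape):
#   waveform = torch.zeros(2, int(args.SR * args.sampleDur), device=device)
#   spec = frontend(waveform)          # spectrogram target, e.g. (2, 1, nMel, time)
#   recon = decoder(encoder(spec))
#   assert recon.shape == spec.shape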
# training / optimisation setup
lr, wdL2 = 0.003, 0.0
batch_size = 64 if torch.cuda.is_available() else 16
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.99 ** epoch)
vgg16 = models.vgg16.eval().to(device)
loss_fun = torch.nn.MSELoss()
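# Note: the MSE is not taken on spectrograms directly; in the loop below, both the
# target and the reconstruction are passed through VGG16 (in eval mode, and excluded
# from the optimizer) so the loss is computed in feature space (a perceptual loss).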
# data loader
df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.weights'
step, writer = 0, SummaryWriter('runs/'+modelname)
print(f'Go for model {modelname} with {len(df)} vocalizations')
# train for roughly 100,000 batches in total
for epoch in range(100_000 // len(loader)):
    for x, name in tqdm(loader, desc=str(epoch), leave=False):
        optimizer.zero_grad()
        # compute the spectrogram target, encode it to the bottleneck, and reconstruct
        label = frontend(x.to(device))
        x = encoder(label)
        pred = decoder(x)
        # expand the single-channel spectrograms to 3 channels and compare
        # VGG16 features of the target and the reconstruction (perceptual loss)
        vgg_pred = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
        vgg_label = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
        score = loss_fun(vgg_pred, vgg_label)
        score.backward()
        optimizer.step()
        writer.add_scalar('loss', score.item(), step)

        if step % 50 == 0:
            # log min-max normalised target and reconstructed spectrograms,
            # plus bottleneck embeddings, to tensorboard
            images = [(e - e.min()) / (e.max() - e.min()) for e in label[:8]]
            grid = make_grid(images)
            writer.add_image('target', grid, step)
            writer.add_embedding(x.detach(), global_step=step, label_img=label)
            images = [(e - e.min()) / (e.max() - e.min()) for e in pred[:8]]
            grid = make_grid(images)
            writer.add_image('reconstruct', grid, step)

        step += 1
        if step % 500 == 0:
            # decay the learning rate and checkpoint the model weights
            scheduler.step()
            torch.save(model.state_dict(), modelname)
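# To reuse the trained weights later (a sketch, assuming the model is rebuilt with the
# same architecture arguments):
#   model.load_state_dict(torch.load(modelname, map_location=device))
#   model.eval()  # e.g. to project detections into the bottleneck space via the encoder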