Skip to content
Snippets Groups Projects
Commit 93c4b08c authored by Paul Best's avatar Paul Best
Browse files

Update file train_AE.py

parent f0f14169
No related branches found
No related tags found
No related merge requests found
...@@ -38,7 +38,7 @@ loss_fun = torch.nn.MSELoss() ...@@ -38,7 +38,7 @@ loss_fun = torch.nn.MSELoss()
df = pd.read_csv(args.detections) df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn) loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc' modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.weights'
step, writer = 0, SummaryWriter('runs/'+modelname) step, writer = 0, SummaryWriter('runs/'+modelname)
print(f'Go for model {modelname} with {len(df)} vocalizations') print(f'Go for model {modelname} with {len(df)} vocalizations')
for epoch in range(100_000//len(loader)): for epoch in range(100_000//len(loader)):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment