Skip to content
Snippets Groups Projects
Commit 5bfbe745 authored by paul.best's avatar paul.best
Browse files

loss early stop

parent 67a61790
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ args = parser.parse_args()
df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
print(f'{len(df)} available vocs')
modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel}_{args.encoder}_decod2_BN_nomaxPool.stdc'
modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if 'Mel' in args.frontend else ''}_{args.encoder}_decod2_BN_nomaxPool.stdc'
gpu = torch.device(f'cuda:{args.cuda}')
writer = SummaryWriter(f'runs2/{modelname}')
os.system(f'cp *.py runs2/{modelname}')
......@@ -47,7 +47,7 @@ loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', meta
batch_size=args.batch_size, shuffle=True, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
MSE = torch.nn.MSELoss()
step, NMIs = 0, []
step, loss = 0, []
for epoch in range(100_000//len(loader)):
for x, name in tqdm(loader, desc=str(epoch), leave=False):
optimizer.zero_grad()
......@@ -63,6 +63,11 @@ for epoch in range(100_000//len(loader)):
score.backward()
optimizer.step()
writer.add_scalar('loss', score.item(), step)
loss.append(score.item())
if min(loss) - 1e-2 < min(loss[-1000:]):
print('Early stop')
exit()
# TEST ROUTINE
if step % 500 == 0:
......@@ -93,6 +98,8 @@ for epoch in range(100_000//len(loader)):
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
except:
print('\rUMAP failed :s')
step += 1
model[1:].train()
continue
print('\rRunning HDBSCAN...', end='')
clusters = hdbscan.HDBSCAN(min_cluster_size=len(df)//100, min_samples=5, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
......@@ -151,10 +158,6 @@ for epoch in range(100_000//len(loader)):
# writer.add_histogram('K-Means Recalls ', np.array(recs), step)
# df.drop('cluster', axis=1, inplace=True)
if len(NMIs) > 10 and max(NMIs) > max(NMIs[-10:]):
print('\rEarly stop')
exit()
print('\r', end='')
model[1:].train()
step += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment