Skip to content
Snippets Groups Projects
Commit ec9790c1 authored by paul.best's avatar paul.best
Browse files

fix

parent c2511aff
No related branches found
No related tags found
No related merge requests found
...@@ -65,7 +65,7 @@ for epoch in range(100_000//len(loader)): ...@@ -65,7 +65,7 @@ for epoch in range(100_000//len(loader)):
writer.add_scalar('loss', score.item(), step) writer.add_scalar('loss', score.item(), step)
loss.append(score.item()) loss.append(score.item())
if len(loss) > 1000 and min(loss) - 1e-2 < min(loss[-1000:]): if len(loss) > 1000 and min(loss[:-1000]) - 1e-2 < min(loss[-1000:]):
print('Early stop') print('Early stop')
exit() exit()
...@@ -106,8 +106,7 @@ for epoch in range(100_000//len(loader)): ...@@ -106,8 +106,7 @@ for epoch in range(100_000//len(loader)):
# df.loc[idxs, 'cluster'] = clusters.astype(int) # df.loc[idxs, 'cluster'] = clusters.astype(int)
mask = ~df.loc[idxs].label.isna() mask = ~df.loc[idxs].label.isna()
clusters, labels = clusters[mask], df.loc[idxs[mask]].label clusters, labels = clusters[mask], df.loc[idxs[mask]].label
NMIs.append(metrics.normalized_mutual_info_score(labels, clusters)) writer.add_scalar('NMI HDBSCAN', metrics.normalized_mutual_info_score(labels, clusters), step)
writer.add_scalar('NMI HDBSCAN', NMIs[-1], step)
try: try:
writer.add_scalar('ARI HDBSCAN', metrics.adjusted_rand_score(labels, clusters), step) writer.add_scalar('ARI HDBSCAN', metrics.adjusted_rand_score(labels, clusters), step)
except: except:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment