Skip to content
Snippets Groups Projects
Commit d0872e1e authored by Paul Best's avatar Paul Best
Browse files

update

parent df484cc9
No related branches found
No related tags found
No related merge requests found
File mode changed from 100644 to 100755
......@@ -5,4 +5,3 @@ cassin_vireo
black-headed_grosbeaks
humpback
dolphin
otter
......@@ -46,7 +46,7 @@ meta = {
'sampleDur': 2
},
'humpback2':{
'nfft': 2048,
'nfft': 1024,
'sr': 11025,
'sampleDur': 2
},
......@@ -98,7 +98,7 @@ frontend = {
STFT(nfft, int((sampleDur*sr - nfft)/128)),
Log1p(7, trainable=False),
nn.InstanceNorm2d(1),
nn.AdaptiveMaxPool2d((128, 128))
nn.AdaptiveMaxPool2d((n_mel, 128))
),
'pcenMel': lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
STFT(nfft, int((sampleDur*sr - nfft)/128)),
......
......@@ -9,19 +9,21 @@ info = {
'cassin_vireo': ['cassin vireo', 'arriaga2015bird', 'bird'],
'black-headed_grosbeaks': ['black-headed grosbeaks', 'arriaga2015bird', 'bird'],
'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'],
'otter': ['otter', '', ''],
'humpback': ['humpback whale', 'malige2021use', 'cetacean'],
'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean']
}
fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(10, 10))
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(15, 10))
for i, specie in enumerate(species):
df = pd.read_csv(f'{specie}/{specie}.csv')
ax[i//3, i%3].bar(range(df.label.nunique() + 1), list(df.label.value_counts()) + [df.label.isna().sum()], log=True)
ax[i//3, i%3].set_title(specie)
ax[i//4, i%4].bar(range(df.label.nunique() + 1), list(df.label.value_counts()) + [df.label.isna().sum()], log=True)
ax[i//4, i%4].set_title(specie)
plt.tight_layout()
plt.savefig('annot_distrib.pdf')
a = "Specie & \# Classes & \# Samples & Annotations \% \\\\ \hline \n"
a = "\\textbf{Specie and source} & \\textbf{\# Unit types} & \\textbf{\# Vocalisations} & \\textbf{\% Labelling} \\\\ \hline \n"
for specie in species:
df = pd.read_csv(f'{specie}/{specie}.csv')
a += f"{info[specie][0]} \cite{{{info[specie][1]}}} & {df.label.nunique()} & {len(df)} & {int(100*(~df.label.isna()).sum()/len(df))} \\\\ \hline \n"
......
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
import torch, hdbscan
import models, utils as u
# For every species, draw a grid of spectrograms: one row per randomly drawn
# HDBSCAN cluster and up to ten randomly drawn vocalisations per row.
# The figure is written to 'clusters.pdf'.
N_CLUSTERS_SHOWN = 5   # rows per species panel
N_SAMPLES_SHOWN = 10   # spectrograms per row

species = np.loadtxt('good_species.txt', dtype=str)
fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10))
for i, specie in enumerate(species):
    meta = models.meta[specie]
    frontend = models.frontend['pcenMel'](meta['sr'], meta['nfft'], meta['sampleDur'], 128)
    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
    idxs, X = dic['idxs'], dic['umap']
    df = pd.read_csv(f'{specie}/{specie}.csv')
    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
    df.loc[idxs, 'cluster'] = clusters.astype(int)

    # Draw among the real clusters only (HDBSCAN labels noise as -1).
    # The previous np.arange(df.cluster.max()) could never pick the highest
    # cluster id, and sampling with replacement could repeat a row.
    cluster_ids = np.unique(clusters[clusters >= 0])
    shown = np.random.choice(cluster_ids, min(N_CLUSTERS_SHOWN, len(cluster_ids)), replace=False)
    for j, cluster in enumerate(shown):
        members = df[df.cluster == cluster]
        members = members.sample(min(len(members), N_SAMPLES_SHOWN))
        loader = torch.utils.data.DataLoader(u.Dataset(members, f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)
        for k, (x, name) in enumerate(loader):
            # detach() so that .numpy() also works when the frontend carries
            # trainable parameters (same precaution as the annotation-png script).
            spectro = frontend(x).squeeze().detach().numpy()
            ax[i].imshow(spectro, extent=[k, k+1, j, j+1], origin='lower', aspect='auto')

    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].set_ylabel(specie.replace('_', ' '))
    ax[i].set_xlim(0, N_SAMPLES_SHOWN)
    ax[i].set_ylim(0, N_CLUSTERS_SHOWN)
plt.tight_layout()
plt.savefig('clusters.pdf')
\ No newline at end of file
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
# Parse the HDBSCAN hyper-parameter search log, export it as CSV, find the
# configuration that works best across all (specie, frontend) pairs, and plot
# the NMI per frontend with per-pair best settings vs. that shared setting.
frontends = ['biosound', 'vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '256_pcenMel128', '512_pcenMel128', '128_pcenMel128', '64_pcenMel128', '32_pcenMel128', '16_pcenMel128']

# The log alternates "<specie> <frontend>" header lines with
# "<nmi> <mcs> <ms> <eps> <al>" result lines that belong to the last header.
out = []
with open('hdbscan_HP.txt', 'r') as fn:
    for line in fn:
        # split() drops the trailing newline without eating the last character
        # of a final line that has no newline (which l[:-1] used to do).
        fields = line.split()
        if not fields:
            continue
        if len(fields) == 2:
            specie, frontend = fields
        else:
            out.append({'specie': specie, 'frontend': frontend, 'nmi': float(fields[0]), 'mcs': int(fields[1]), 'ms': fields[2], 'eps': float(fields[3]), 'al': fields[4]})

df = pd.DataFrame(out)
# min_samples=None is logged as the string 'None'; encode it as 0.
df.ms = df.ms.replace('None', 0).astype(int)
df.to_csv('hdbscan_HP3.csv', index=False)

# Best NMI reached by each (specie, frontend) pair over all configurations.
best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()]
best_nmi = best.set_index(['specie', 'frontend']).nmi
# Score each configuration by the sum of its NMIs relative to the per-pair best.
res = [(conf, (grp.set_index(['specie', 'frontend']).nmi / best_nmi).sum()) for conf, grp in df.groupby(['mcs', 'ms', 'eps', 'al'])]
conf = res[np.argmax([r[1] for r in res])]
print('best HP', conf)
conf = conf[0]

fig, ax = plt.subplots(ncols = 2, figsize=(10, 5), sharex=True, sharey=True)
y_pos = np.arange(len(frontends))
for s, grp in df.groupby('specie'):
    ax[0].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], y_pos)
ax[0].grid()
ax[1].grid()
ax[0].set_yticks(y_pos)
ax[0].set_yticklabels(frontends)

shared = df[(( df.mcs==conf[0] )&( df.ms==conf[1] )&( df.eps==conf[2] )&( df.al==conf[3] ))]
for s, grp in shared.groupby('specie'):
    ax[1].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], y_pos, label=s)
ax[1].legend(bbox_to_anchor=(1,1))

ax[0].set_title('Best HDBSCAN settings')
ax[1].set_title('Shared HDBSCAN settings')
ax[0].set_xlabel('NMI')
ax[1].set_xlabel('NMI')
plt.tight_layout()
plt.savefig('NMIs_hdbscan.pdf')
import hdbscan
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
# Build a LaTeX table comparing HDBSCAN clusters with expert labels for each
# species and write it to 'cluster_distrib.tex'. A cluster is "discriminative"
# for a label when more than 90% of its labelled points carry that label.
species = np.loadtxt('good_species.txt', dtype=str)
# species key -> [display name, bibtex key, taxonomic group]
info = {
    'bengalese_finch1': ['bengalese finch', 'nicholson2017bengalese', 'bird'],
    'bengalese_finch2': ['bengalese finch', 'koumura2016birdsongrecognition', 'bird'],
    'california_thrashers': ['california thrashers', 'arriaga2015bird', 'bird'],
    'cassin_vireo': ['cassin vireo', 'arriaga2015bird', 'bird'],
    'black-headed_grosbeaks': ['black-headed grosbeaks', 'arriaga2015bird', 'bird'],
    'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'],
    'otter': ['otter', '', ''],
    'humpback': ['humpback whale', 'malige2021use', 'cetacean'],
    'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean'],
}

# Backslashes are doubled throughout: "\#", "\%", "\h" and "\c" are invalid
# escape sequences in non-raw string literals. The emitted text is unchanged.
out = "\\textbf{Specie and source} & \\textbf{\\# labels} & \\textbf{\\# clusters} & \\textbf{\\# discr. clusters} & \\textbf{\\% clustered vocs} & \\textbf{\\# missed labels} \\\\ \\hline \n"
for specie in species:
    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
    idxs, X = dic['idxs'], dic['umap']
    df = pd.read_csv(f'{specie}/{specie}.csv')
    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
    df.loc[idxs, 'cluster'] = clusters.astype(int)
    print(specie)

    labelled = df[~df.label.isna()]
    # HDBSCAN marks noise with -1. Noise is not a cluster, so it is left out of
    # the precision computation and can never count as a "good" cluster.
    cluster_sizes = labelled[labelled.cluster >= 0].groupby('cluster').fn.count()
    goodClusters, missedLabels = [], []
    for l, grp in labelled.groupby('label'):
        label_counts = grp[grp.cluster >= 0].groupby('cluster').fn.count()
        precisions = label_counts / cluster_sizes
        pure = precisions[precisions > 0.9].index
        goodClusters.extend(pure)
        if len(pure) == 0:
            missedLabels.append(l)

    n_labels = df.label.nunique()
    # Count real clusters explicitly instead of nunique()-1, which is off by
    # one whenever no point was labelled as noise.
    n_clusters = df.cluster[df.cluster >= 0].nunique()
    sorted_pct = df.cluster.isin(goodClusters).sum()/len(df)*100
    print(f'{len(goodClusters)} clusters would sort {sorted_pct:.0f}% of samples')
    print(f'{len(goodClusters)/n_labels:.1f} cluster per label in avg)')
    print(f'{len(missedLabels)} over {n_labels} missed labels')
    out += f"{info[specie][0]} \\cite{{{info[specie][1]}}} & {n_labels} & {n_clusters} & {len(goodClusters)} & {sorted_pct:.0f} & {len(missedLabels)} \\\\ \\hline \n"

with open('cluster_distrib.tex', 'w') as f:
    f.write(out)
import matplotlib.pyplot as plt
import numpy as np, pandas as pd
# Scatter the 2D UMAP projection of every species (unlabelled points in grey,
# one colour per expert label) and annotate each panel with a Hopkins-style
# clustering-tendency statistic. Saved as 'projections.pdf' and '.png'.
species = np.loadtxt('good_species.txt', dtype=str)
fig, ax = plt.subplots(ncols=4, nrows=2, figsize=(20, 10))


def non_zero_min(arr):
    """Return the smallest non-zero value of *arr* (skips the distance of a point to itself)."""
    return np.min(arr[arr != 0])


for i, specie in enumerate(species):
    panel = ax[i//4, i%4]
    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
    df = pd.read_csv(f'{specie}/{specie}.csv')
    df.loc[dic['idxs'], 'umap_x'] = dic['umap'][:,0]
    df.loc[dic['idxs'], 'umap_y'] = dic['umap'][:,1]

    unlabelled = df[df.label.isna()]
    panel.scatter(unlabelled.umap_x, unlabelled.umap_y, s=2, color='grey')
    for label, grp in df[~df.label.isna()].groupby('label'):
        # Cap each label at 1000 points to keep the figure light.
        grp = grp.sample(min(len(grp), 1000))
        panel.scatter(grp.umap_x, grp.umap_y, s=2)
    panel.set_title(specie.replace('_', ' '))

    # Rows absent from dic['idxs'] have NaN coordinates; drop them so that a
    # single NaN cannot turn the whole statistic into NaN. The array is built
    # once rather than once per probe point.
    points = df[['umap_x', 'umap_y']].dropna().to_numpy()
    sampSize = min(100, len(points))
    # X: real points drawn without replacement.
    X = points[np.random.choice(len(points), sampSize, replace=False)]
    # Y: synthetic points drawn from a normal fit to X on each axis.
    Y = np.vstack([np.random.normal(np.mean(X[:,0]), np.std(X[:,0]), sampSize), np.random.normal(np.mean(X[:,1]), np.std(X[:,1]), sampSize)]).T
    # U / W: summed nearest-neighbour distances of synthetic / real probes.
    U = np.sum([np.min(np.sqrt(np.sum((y - points)**2, axis=1))) for y in Y])
    W = np.sum([non_zero_min(np.sqrt(np.sum((x - points)**2, axis=1))) for x in X])
    hopkins = U / (U + W)
    panel.text(panel.get_xlim()[0] + .5, panel.get_ylim()[0] + .5, f'{hopkins:.2f}', fontsize=15)
plt.tight_layout()
plt.savefig('projections.pdf')
plt.savefig('projections.png')
\ No newline at end of file
......@@ -3,26 +3,56 @@ import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from tqdm import tqdm
import os
species = np.loadtxt('good_species.txt', dtype=str)
frontends = ['16_logMel128', '16_logSTFT', '16_Mel128', '16_pcenMel128', '8_pcenMel64', '32_pcenMel128', '64_pcenMel128']
frontends = ['biosound'] #['vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '32_pcenMel128', '64_pcenMel128', '128_pcenMel128', '256_pcenMel128', '512_pcenMel128']
file = open('hdbscan_HP2.txt', 'w')
plt.figure()
for specie in species:
for specie in ['humpback', 'dolphin', 'black-headed_grosbeaks', 'california_thrashers']: #species:
df = pd.read_csv(f'{specie}/{specie}.csv')
nmis = []
for i, frontend in enumerate(frontends):
print(specie, frontend)
dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
for i, frontend in tqdm(enumerate(frontends), desc=specie, total=len(frontends)):
file.write(specie+' '+frontend+'\n')
fn = f'{specie}/encodings/encodings_' + (f'{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy' if not frontend in ['vggish', 'biosound'] else frontend+'.npy')
if not os.path.isfile(fn):
nmis.append(None)
print('not found')
continue
dic = np.load(fn, allow_pickle=True).item()
idxs, X = dic['idxs'], dic['umap']
clusters = hdbscan.HDBSCAN(min_cluster_size=100, min_samples=20, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
# clusters = hdbscan.HDBSCAN(min_cluster_size=max(10, len(df)//100), core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
# df.loc[idxs, 'cluster'] = clusters.astype(int)
# mask = ~df.loc[idxs].label.isna()
# clusters, labels = clusters[mask], df.loc[idxs[mask]].label
# nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
# df.drop('cluster', axis=1, inplace=True)
# continue
nmi = 0
for mcs in [10, 20, 50, 100, 150, 200]:
for ms in [None, 3, 5, 10, 20, 30]:
for eps in [0.0, 0.01, 0.02, 0.05, 0.1]:
for al in ['leaf', 'eom']:
clusters = hdbscan.HDBSCAN(min_cluster_size=mcs, min_samples=ms, cluster_selection_epsilon=eps, \
core_dist_n_jobs=-1, cluster_selection_method=al).fit_predict(X)
df.loc[idxs, 'cluster'] = clusters.astype(int)
mask = ~df.loc[idxs].label.isna()
clusters, labels = clusters[mask], df.loc[idxs[mask]].label
nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
tnmi = metrics.normalized_mutual_info_score(labels, clusters)
file.write(f'{tnmi} {mcs} {ms} {eps} {al}\n')
if tnmi > nmi:
    # Keep the best NMI seen so far in the grid search; the bare `nmi`
    # expression was a no-op, so the recorded value always stayed at 0.
    nmi = tnmi
df.drop('cluster', axis=1, inplace=True)
nmis.append(nmi)
plt.scatter(nmis, np.arange(len(frontends)), label=specie)
file.close()
plt.yticks(range(len(frontends)), frontends)
plt.ylabel('archi')
plt.xlabel('NMI with expert labels')
......
......@@ -17,14 +17,16 @@ frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sample
os.system(f'rm -R {args.specie}/annot_pngs/*')
for label, grp in df.groupby('label'):
os.system(f'mkdir -p "{args.specie}/annot_pngs/{label}"')
loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 100)), args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
batch_size=1, num_workers=4, pin_memory=True)
for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False):
x = frontend(x).squeeze().detach()
assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/"
plt.figure()
plt.imshow(x, origin='lower', aspect='auto')
plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}')
row = df.loc[idx.item()]
plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png')
# plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}')
plt.close()
......
......@@ -28,25 +28,28 @@ def process(idx):
if sig.ndim == 2:
sig = sig[:,0]
if len(sig) < meta['sampleDur'] * fs:
sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
sig = np.concatenate([sig, np.zeros(int(meta['sampleDur'] * fs) - len(sig))])
if fs != meta['sr']:
sig = resample(sig, int(len(sig)/fs*meta['sr']))
sound = BioSound(soundWave=norm(sig), fs=fs)
sound = BioSound(soundWave=norm(sig), fs=meta['sr'])
sound.spectroCalc(max_freq=meta['sr']//2, spec_sample_rate=128//meta['sampleDur'])
sound.rms = sound.sound.std()
sound.ampenv(cutoff_freq = 20, amp_sample_rate = 1000)
sound.spectrum(f_high=meta['sr']//2 - 1)
sound.fundest(maxFund = 6000, minFund = 200, lowFc = 200, highFc = 6000,
sound.fundest(maxFund = 5000, minFund = 200, lowFc = 200, highFc = 5000,
minSaliency = 0.5, debugFig = 0,
minFormantFreq = 500, maxFormantBW = 500, windowFormant = 0.1,
method='Stack')
return [sound.__dict__[f] for f in feats]
res = p_tqdm.p_map(process, df.index[:100], num_cpus=16)
res = p_tqdm.p_map(process, df.index, num_cpus=14)
for i, mr in zip(df.index[:100], res):
for i, mr in zip(df.index, res):
for f, r in zip(feats, mr):
try:
df.loc[i, f] = r
except:
print(i, f, r)
df.to_csv(f'{args.specie}/{args.specie}_biosound.csv', index=False)
......@@ -7,7 +7,6 @@ import models, utils as u
import pandas as pd, numpy as np, torch
torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser()
parser.add_argument("specie", type=str)
args = parser.parse_args()
......@@ -16,7 +15,7 @@ df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
meta = models.meta[args.specie]
if not os.path.isfile(f'{args.specie}/encodings_vggish.npy'):
if not os.path.isfile(f'{args.specie}/encodings/encodings_vggish.npy'):
gpu = torch.device('cuda')
frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64)
vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')
......@@ -30,14 +29,14 @@ if not os.path.isfile(f'{args.specie}/encodings_vggish.npy'):
for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
# encoding = model(x.to(gpu))
encoding = vggish(x.numpy().squeeze(0), fs=16000)
idxs.extend(idx)
encodings.extend(encoding.cpu().detach())
idxs.extend(idx.numpy())
encodings.extend(encoding.cpu().numpy())
idxs, encodings = np.array(idxs), np.stack(encodings)
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save(f'{args.specie}/encodings_vggish.npy', {'idxs':idxs, 'encodings':encodings, 'umap':X})
np.save(f'{args.specie}/encodings/encodings_vggish.npy', {'idxs':idxs, 'encodings':encodings, 'umap':X})
else:
dic = np.load(f'{args.specie}/encodings_vggish.npy', allow_pickle=True).item()
dic = np.load(f'{args.specie}/encodings/encodings_vggish.npy', allow_pickle=True).item()
idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
clusters = hdbscan.HDBSCAN(min_cluster_size=50, min_samples=5, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
......@@ -50,7 +49,7 @@ plt.figure(figsize=(20, 10))
plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
plt.tight_layout()
plt.savefig(f'{args.specie}/vggish_projection_clusters.png')
plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png')
plt.figure(figsize=(20, 10))
plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
......@@ -58,7 +57,7 @@ for l, grp in df.groupby('label'):
plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
plt.legend()
plt.tight_layout()
plt.savefig(f'{args.specie}/vggish_projection_labels.png')
plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png')
clusters, labels = clusters[mask], df.loc[idxs[mask]].label
......
......@@ -22,8 +22,8 @@ df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
print(f'Tests for model {modelname}')
print(f'{len(df)} available vocs')
if os.path.isfile(f'{args.specie}/encodings_{modelname[:-4]}npy'):
dic = np.load(f'{args.specie}/encodings_{modelname[:-4]}npy', allow_pickle=True).item()
if os.path.isfile(f'{args.specie}/encodings/encodings_{modelname[:-4]}npy'):
dic = np.load(f'{args.specie}/encodings/encodings_{modelname[:-4]}npy', allow_pickle=True).item()
idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
else:
gpu = torch.device('cuda')
......@@ -31,7 +31,7 @@ else:
encoder = models.__dict__[args.encoder](*((args.bottleneck // 16, (4, 4)) if args.nMel == 128 else (args.bottleneck // 8, (2, 4))))
decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(gpu)
model.load_state_dict(torch.load(f'{args.specie}/{modelname}'))
model.load_state_dict(torch.load(f'{args.specie}/weights/{modelname}'))
model.eval()
loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=64, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
with torch.no_grad():
......@@ -43,9 +43,11 @@ else:
idxs, encodings = np.array(idxs), np.stack(encodings)
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save(f'{args.specie}/encodings_{modelname[:-4]}npy', {'idxs':idxs, 'encodings':encodings, 'umap':X})
np.save(f'{args.specie}/encodings/encodings_{modelname[:-4]}npy', {'idxs':idxs, 'encodings':encodings, 'umap':X})
clusters = hdbscan.HDBSCAN(min_cluster_size=50, min_samples=5, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
clusters = hdbscan.HDBSCAN(min_cluster_size=len(df)//100, min_samples=5, core_dist_n_jobs=-1, cluster_selection_method='eom').fit_predict(X)
#clusters = hdbscan.HDBSCAN(min_cluster_size=20, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
# clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
df.loc[idxs, 'cluster'] = clusters.astype(int)
mask = ~df.loc[idxs].label.isna()
......@@ -55,7 +57,7 @@ plt.figure(figsize=(20, 10))
plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
plt.tight_layout()
plt.savefig(f'{args.specie}/{modelname[:-5]}_projection_clusters.png')
plt.savefig(f'{args.specie}/projections/{modelname[:-5]}_projection_clusters.png')
plt.figure(figsize=(20, 10))
plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
......@@ -63,7 +65,7 @@ for l, grp in df.groupby('label'):
plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
plt.legend()
plt.tight_layout()
plt.savefig(f'{args.specie}/{modelname[:-5]}_projection_labels.png')
plt.savefig(f'{args.specie}/projections/{modelname[:-5]}_projection_labels.png')
clusters, labels = clusters[mask], df.loc[idxs[mask]].label
......
......@@ -67,7 +67,7 @@ for epoch in range(100_000//len(loader)):
if len(loss) > 2000 and np.median(loss[-2000:-1000]) < np.median(loss[-1000:]):
print('Early stop')
torch.save(model.state_dict(), f'{args.specie}/{modelname}')
torch.save(model.state_dict(), f'{args.specie}/weights/{modelname}')
exit()
step += 1
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment