diff --git a/.gitignore b/.gitignore index d4acccc40c48fb2183901bcb00260a82f5810f71..76ccee7b16358c65b6c72e2d84db0c85d666f015 100755 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,8 @@ *__pycache__ *log humpback2/annots - +gibbon +new_specie/*/ otter/pone.0112562.s003.xlsx zebra_finch/Library_notes.pdf annot_distrib.pdf diff --git a/plot_annot_distrib.py b/plot_annot_distrib.py index fc0433b3742c14ed2b6295b980b60dcea319a94e..c79e709de411ccc9a322552db712c0bba07f0f25 100755 --- a/plot_annot_distrib.py +++ b/plot_annot_distrib.py @@ -11,6 +11,7 @@ info = { 'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'], 'otter': ['otter', '', ''], 'humpback': ['humpback whale', 'malige2021use', 'cetacean'], + 'humpback2':['humpback whale', 'malige2021use', 'cetacean'], 'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean'] } diff --git a/plot_clusters.py b/plot_clusters.py old mode 100644 new mode 100755 index 91d494f4548bb74d4157409bc43c0b2ff56039a4..e71515f124b9a60d94eefea46c14738fa542e5e2 --- a/plot_clusters.py +++ b/plot_clusters.py @@ -5,7 +5,7 @@ import models, utils as u species = np.loadtxt('good_species.txt', dtype=str) -fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10)) +fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10), dpi=200) for i, specie in enumerate(species): meta = models.meta[specie] frontend = models.frontend['pcenMel'](meta['sr'], meta['nfft'], meta['sampleDur'], 128) @@ -14,15 +14,18 @@ for i, specie in enumerate(species): df = pd.read_csv(f'{specie}/{specie}.csv') clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) df.loc[idxs, 'cluster'] = clusters.astype(int) - for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 5)): - for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(10), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)): - ax[i].imshow(frontend(x).squeeze().numpy(), extent=[k, k+1, j, j+1], origin='lower', aspect='auto') + for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 4)): + for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)): + spec = frontend(x).squeeze().numpy() + ax[i].imshow(spec, extent=[k, k+1, j, j+1], origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(spec, .2), vmax=np.quantile(spec, .98)) ax[i].set_xticks([]) ax[i].set_yticks([]) # ax[i].grid(color='w', xdata=np.arange(1, 10), ydata=np.arange(1, 5)) ax[i].set_ylabel(specie.replace('_', ' ')) - ax[i].set_xlim(0, 10) - ax[i].set_ylim(0, 5) - + ax[i].set_xlim(0, 8) + ax[i].set_ylim(0, 4) + ax[i].vlines(np.arange(1, 8), 0, 4, linewidths=1, color='black') + ax[i].hlines(np.arange(1, 4), 0, 8, linewidths=1, color='black') +plt.subplots_adjust(wspace=0.1) plt.tight_layout() plt.savefig('clusters.pdf') \ No newline at end of file diff --git a/plot_hdbscan_HP.py b/plot_hdbscan_HP.py deleted file mode 100755 index ff0ac01dae8e9d335105d75d2696a4c9152dcf70..0000000000000000000000000000000000000000 --- a/plot_hdbscan_HP.py +++ /dev/null @@ -1,44 +0,0 @@ -import matplotlib.pyplot as plt -import pandas as pd, numpy as np - - -frontends = ['biosound', 'vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '256_pcenMel128', '512_pcenMel128', '128_pcenMel128', '64_pcenMel128', '32_pcenMel128', '16_pcenMel128'] - -fn = open('hdbscan_HP.txt', 'r') - -out = [] -for l in fn.readlines(): - l = l[:-1].split(' ') - if len(l)==2: - specie, frontend = l[0], l[1] - else: - out.append({'specie':specie, 'frontend':frontend, 'nmi':float(l[0]), 'mcs':int(l[1]), 'ms': l[2], 'eps': float(l[3]), 'al': l[4]}) -df = pd.DataFrame().from_dict(out) - -df.ms = df.ms.replace('None', 0).astype(int) - -df.to_csv('hdbscan_HP3.csv', index=False) - -best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()] -res = [(conf, (grp.set_index(['specie', 'frontend']).nmi / best.set_index(['specie', 'frontend']).nmi).sum()) for conf, grp in df.groupby(['mcs', 'ms', 'eps', 'al'])] -conf = res[np.argmax([r[1] for r in res])] -print('best HP', conf) -conf = conf[0] - -fig, ax = plt.subplots(ncols = 2, figsize=(10, 5), sharex=True, sharey=True) -for s, grp in df.groupby('specie'): - ax[0].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], np.arange(len(frontends))) -ax[0].grid() -ax[1].grid() -ax[0].set_yticks(np.arange(len(frontends))) -ax[0].set_yticklabels(frontends) -for s, grp in df[(( df.mcs==conf[0] )&( df.ms==conf[1] )&( df.eps==conf[2] )&( df.al==conf[3] ))].groupby('specie'): - ax[1].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], np.arange(len(frontends)), label=s) -ax[1].legend(bbox_to_anchor=(1,1)) -plt.tight_layout() -ax[0].set_title('Best HDBSCAN settings') -ax[1].set_title('Shared HDBSCAN settings') -ax[0].set_xlabel('NMI') -ax[1].set_xlabel('NMI') -plt.tight_layout() -plt.savefig('NMIs_hdbscan.pdf') diff --git a/plot_main_results.py b/plot_main_results.py new file mode 100755 index 0000000000000000000000000000000000000000..75de2a853666619347d35ae92f30b20c939fd857 --- /dev/null +++ b/plot_main_results.py @@ -0,0 +1,31 @@ +import matplotlib.pyplot as plt +import pandas as pd, numpy as np + +all_frontends = [['256_logMel128', '256_logMel128_noprcptl', 'spec32', 'biosound'][::-1], + ['256_logMel128', 'allbut_AE_256_logMel128', 'openl3', 'wav2vec2', 'crepe'][::-1], + ['256_logMel128', '256_Mel128', '256_pcenMel128', '256_logSTFT']] + +all_frontendNames = [['AE prcptl', 'AE MSE', 'Spectro.', 'PAFs'][::-1], + ['AE', 'gen. AE', 'OpenL3', 'Wav2Vec2', 'CREPE'][::-1], + ['log-Mel', 'Mel', 'PCEN-Mel', 'log-STFT']] + +all_plotNames = ['handcraft', 'deepembed', 'frontends'] +for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames, all_plotNames): + df = pd.read_csv('hdbscan_HP.csv') + df.loc[df.ms.isna(), 'ms'] = 0 + best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()] + res = [(conf, (grp.set_index(['specie', 'frontend']).nmi / best.set_index(['specie', 'frontend']).nmi).sum()) for conf, grp in df.groupby(['ncomp', 'mcs', 'ms', 'eps', 'al'])] + conf = res[np.argmax([r[1] for r in res])] + print('best HP', conf) + conf = conf[0] + + df = df[((df.ncomp == conf[0])&(df.mcs == conf[1] )&( df.ms==conf[2] )&( df.eps==conf[3] )&(df.al==conf[4] ))] + df.specie = df.specie.replace('humpback2', 'humpback\n(small)') + species = df.specie.unique() + out = pd.DataFrame({fn:[df[((df.specie==s)&(df.frontend==f))].iloc[0].nmi for s in species] for f, fn in zip(frontends, frontendNames)}, index=[s.replace('_','\n') for s in species]) + out.plot.bar(figsize=(9, 3), rot=0) + plt.legend(bbox_to_anchor=(1,1)) + plt.ylabel('NMI') + plt.grid(axis='y') + plt.tight_layout() + plt.savefig(f'NMIs_hdbscan_barplot_{plotName}.pdf') diff --git a/plot_prec_rec.py b/plot_prec_rec.py old mode 100644 new mode 100755 index 4f5a7923a2cb28ee403d35ad8386c198e59b42c2..22f38f797ae156331408c01fbb871f29c3e9d6f9 --- a/plot_prec_rec.py +++ b/plot_prec_rec.py @@ -10,30 +10,43 @@ info = { 'cassin_vireo': ['cassin vireo', 'arriaga2015bird', 'bird'], 'black-headed_grosbeaks': ['black-headed grosbeaks', 'arriaga2015bird', 'bird'], 'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'], - 'otter': ['otter', '', ''], 'humpback': ['humpback whale', 'malige2021use', 'cetacean'], + 'humpback2': ['humpback whale (small)', 'malige2021use', 'cetacean'], 'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean'] } -out = "\\textbf{Specie and source} & \\textbf{\# labels} & \\textbf{\# clusters} & \\textbf{\# discr. clusters} & \\textbf{\% clustered vocs} & \\textbf{\# missed labels} \\\\ \hline \n" +#out = "\\textbf{Specie and source} & \\textbf{\# labels} & \\textbf{\# clusters} & \\textbf{\% discr. clusters} & \\textbf{\% clustered vocs} & \\textbf{\# missed labels} \\\\ \hline \n" +out = "" for specie in species: - dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item() - idxs, X = dic['idxs'], dic['umap'] + dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item() + idxs, X = dic['idxs'], dic['umap8'] df = pd.read_csv(f'{specie}/{specie}.csv') - clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) + clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X) df.loc[idxs, 'cluster'] = clusters.astype(int) + + dic = np.load(f'{specie}/encodings/encodings_spec32.npy', allow_pickle=True).item() + idxs, X = dic['idxs'], dic['umap8'] + clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X) + df.loc[idxs, 'cluster2'] = clusters.astype(int) + mask = ~df.loc[idxs].label.isna() print(specie) labelled = df[~df.label.isna()] - goodClusters, missedLabels = [], [] + goodClusters, goodClusters2, missedLabels, missedLabels2 = [], [], [], [] for l, grp in labelled.groupby('label'): precisions = grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count() best = precisions.idxmax() goodClusters.extend(precisions[precisions > 0.9].index) if not (precisions > .9).any(): missedLabels.append(l) + + precisions = grp.groupby('cluster2').fn.count() / labelled.groupby('cluster2').fn.count() + goodClusters2.extend(precisions[precisions > 0.9].index) + if not (precisions > .9).any(): + missedLabels2.append(l) + # print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \ # with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f}\ # and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}') @@ -43,7 +56,9 @@ for specie in species: print(f'{len(goodClusters)/df.label.nunique():.1f} cluster per label in avg)') print(f'{len(missedLabels)} over {df.label.nunique()} missed labels') - out += f"{info[specie][0]} \cite{{{info[specie][1]}}} & {df.label.nunique()} & {df.cluster.nunique()-1} & {len(goodClusters)} & {df.cluster.isin(goodClusters).sum()/len(df)*100:.0f} & {len(missedLabels)} \\\\ \hline \n" + out += f"{info[specie][0]} \cite{{{info[specie][1]}}} & {df.label.nunique()} & {df.cluster.nunique()-1} & {df.cluster2.nunique()-1} & {len(goodClusters)/labelled.cluster.nunique()*100:.0f} &" + out += f"{len(goodClusters2)/labelled.cluster2.nunique()*100:.0f} & {df.cluster.isin(goodClusters).sum()/len(df[df.cluster.isin(labelled.cluster.unique())])*100:.0f} & " + out += f"{df.cluster2.isin(goodClusters2).sum()/len(df[df.cluster2.isin(labelled.cluster2.unique())])*100:.0f} & {len(missedLabels)} & {len(missedLabels2)} \\\\ \hline \n" f = open('cluster_distrib.tex', 'w') f.write(out) diff --git a/plot_projections.py b/plot_projections.py old mode 100644 new mode 100755 index 5ed7d55d3505103b2940d29489080ffd2cd6373c..1ad4cc9633b80a539cee7ecea81b36da749f9f0e --- a/plot_projections.py +++ b/plot_projections.py @@ -1,33 +1,38 @@ +from tqdm import tqdm import matplotlib.pyplot as plt import numpy as np, pandas as pd -species = np.loadtxt('good_species.txt', dtype=str) - -fig, ax = plt.subplots(ncols=4, nrows=2, figsize=(20, 10)) +fig, ax = plt.subplots(ncols=4, nrows=2, figsize=(13, 7)) non_zero_min = lambda arr: np.min(arr[arr!=0]) -for i, specie in enumerate(species): - dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item() +for i, specie in tqdm(enumerate(['bengalese_finch1', 'bengalese_finch2', 'california_thrashers', 'cassin_vireo', 'black-headed_grosbeaks', 'humpback', 'humpback2', 'dolphin'])): + dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item() df = pd.read_csv(f'{specie}/{specie}.csv') df.loc[dic['idxs'], 'umap_x'] = dic['umap'][:,0] df.loc[dic['idxs'], 'umap_y'] = dic['umap'][:,1] - ax[i//4,i%4].scatter(df[df.label.isna()].umap_x, df[df.label.isna()].umap_y, s=2, color='grey') + ax[i//4,i%4].scatter(df[df.label.isna()].umap_x, df[df.label.isna()].umap_y, s=1, color='grey') for label, grp in df[~df.label.isna()].groupby('label'): grp = grp.sample(min(len(grp), 1000)) - ax[i//4,i%4].scatter(grp.umap_x, grp.umap_y, s=2) - ax[i//4,i%4].set_title(specie.replace('_', ' ')) - + ax[i//4,i%4].scatter(grp.umap_x, grp.umap_y, s=1) + ax[i//4,i%4].set_title(specie.replace('_', ' ') if specie != 'humpback2' else 'humpback (small)') - sampSize = 100 + # Hopkins statistic + sampSize = len(df)//15 X = df.sample(sampSize)[['umap_x', 'umap_y']].to_numpy() Y = np.vstack([np.random.normal(np.mean(X[:,0]), np.std(X[:,0]), sampSize), np.random.normal(np.mean(X[:,1]), np.std(X[:,1]), sampSize)]).T U = np.sum([min(np.sqrt(np.sum((y - df[['umap_x', 'umap_y']].to_numpy())**2, axis=1))) for y in Y]) W = np.sum([non_zero_min(np.sqrt(np.sum((x - df[['umap_x', 'umap_y']].to_numpy())**2, axis=1))) for x in X]) hopkins = U / (U + W) - ax[i//4, i%4].text(ax[i//4, i%4].get_xlim()[0] + .5, ax[i//4, i%4].get_ylim()[0] + .5, f'{hopkins:.2f}', fontsize=15) + ax[i//4, i%4].text(ax[i//4, i%4].get_xlim()[0] + .5, ax[i//4, i%4].get_ylim()[0] + .5, f'{hopkins:.2f}', fontsize=10) + ax[i//4, i%4].set_xticks([]) + ax[i//4, i%4].set_yticks([]) + +#ax[1,3].set_xticks([]) +#ax[1,3].set_yticks([]) +#ax[1,3].set_frame_on(False) plt.tight_layout() plt.savefig('projections.pdf') -plt.savefig('projections.png') \ No newline at end of file +plt.savefig('projections.png') diff --git a/plot_results_hdbcsan.py b/plot_results_hdbcsan.py deleted file mode 100755 index 09cd3417896f539a4f5937a06536967b4a999265..0000000000000000000000000000000000000000 --- a/plot_results_hdbcsan.py +++ /dev/null @@ -1,63 +0,0 @@ -import hdbscan -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -from sklearn import metrics -from tqdm import tqdm -import os - -species = np.loadtxt('good_species.txt', dtype=str) -frontends = ['biosound'] #['vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '32_pcenMel128', '64_pcenMel128', '128_pcenMel128', '256_pcenMel128', '512_pcenMel128'] - -file = open('hdbscan_HP2.txt', 'w') - -plt.figure() -for specie in ['humpback', 'dolphin', 'black-headed_grosbeaks', 'california_thrashers']: #species: - df = pd.read_csv(f'{specie}/{specie}.csv') - nmis = [] - for i, frontend in tqdm(enumerate(frontends), desc=specie, total=len(frontends)): - file.write(specie+' '+frontend+'\n') - fn = f'{specie}/encodings/encodings_' + (f'{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy' if not frontend in ['vggish', 'biosound'] else frontend+'.npy') - if not os.path.isfile(fn): - nmis.append(None) - print('not found') - continue - dic = np.load(fn, allow_pickle=True).item() - idxs, X = dic['idxs'], dic['umap'] - - # clusters = hdbscan.HDBSCAN(min_cluster_size=max(10, len(df)//100), core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) - # df.loc[idxs, 'cluster'] = clusters.astype(int) - # mask = ~df.loc[idxs].label.isna() - # clusters, labels = clusters[mask], df.loc[idxs[mask]].label - # nmis.append(metrics.normalized_mutual_info_score(labels, clusters)) - # df.drop('cluster', axis=1, inplace=True) - # continue - - nmi = 0 - for mcs in [10, 20, 50, 100, 150, 200]: - for ms in [None, 3, 5, 10, 20, 30]: - for eps in [0.0, 0.01, 0.02, 0.05, 0.1]: - for al in ['leaf', 'eom']: - clusters = hdbscan.HDBSCAN(min_cluster_size=mcs, min_samples=ms, cluster_selection_epsilon=eps, \ - core_dist_n_jobs=-1, cluster_selection_method=al).fit_predict(X) - df.loc[idxs, 'cluster'] = clusters.astype(int) - mask = ~df.loc[idxs].label.isna() - clusters, labels = clusters[mask], df.loc[idxs[mask]].label - tnmi = metrics.normalized_mutual_info_score(labels, clusters) - file.write(f'{tnmi} {mcs} {ms} {eps} {al}\n') - if tnmi > nmi : - nmi - df.drop('cluster', axis=1, inplace=True) - nmis.append(nmi) - plt.scatter(nmis, np.arange(len(frontends)), label=specie) - -file.close() - -plt.yticks(range(len(frontends)), frontends) -plt.ylabel('archi') -plt.xlabel('NMI with expert labels') -plt.grid() -plt.tight_layout() -plt.legend() -plt.savefig('NMIs_hdbscan.pdf') -plt.close() diff --git a/print_annot.py b/print_annot.py index ac568d4c98416b606cc30055887ac0c510382f41..da77b06dc33611633b0de7e461b9dac729effc71 100755 --- a/print_annot.py +++ b/print_annot.py @@ -1,7 +1,7 @@ -import os, argparse +import os, shutil, argparse from tqdm import tqdm import matplotlib.pyplot as plt -import pandas as pd +import pandas as pd, numpy as np import models, utils as u import torch @@ -14,19 +14,21 @@ args = parser.parse_args() meta = models.meta[args.specie] df = pd.read_csv(f'{args.specie}/{args.specie}.csv') frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel) -os.system(f'rm -R {args.specie}/annot_pngs/*') +shutil.rmtree(f'{args.specie}/annot_pngs', ignore_errors=True) + for label, grp in df.groupby('label'): - os.system(f'mkdir -p "{args.specie}/annot_pngs/{label}"') + os.makedirs(f'{args.specie}/annot_pngs/{label}', exist_ok=True) loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\ batch_size=1, num_workers=4, pin_memory=True) for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False): x = frontend(x).squeeze().detach() assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/" plt.figure() - plt.imshow(x, origin='lower', aspect='auto') + plt.imshow(x, origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(x, .7)) + plt.subplots_adjust(top=1, bottom=0, left=0, right=1) row = df.loc[idx.item()] - plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png') - # plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}') + #plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png') + plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}') plt.close() diff --git a/print_reconstr.py b/print_reconstr.py new file mode 100755 index 0000000000000000000000000000000000000000..66efcbf29cf6e88038362fd63b9ed3efd7fb1068 --- /dev/null +++ b/print_reconstr.py @@ -0,0 +1,48 @@ +import os, shutil, argparse +from tqdm import tqdm +import matplotlib.pyplot as plt +import pandas as pd, numpy as np +import models, utils as u +import torch + +parser = argparse.ArgumentParser() +parser.add_argument("specie", type=str) +parser.add_argument("-bottleneck", type=int, default=16) +parser.add_argument("-frontend", type=str, default='logMel') +parser.add_argument("-encoder", type=str, default='sparrow_encoder') +parser.add_argument("-prcptl", type=int, default=1) +parser.add_argument("-nMel", type=int, default=128) +args = parser.parse_args() + +modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc' +print(modelname) +gpu = torch.device(f'cuda') + +meta = models.meta[args.specie] +frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel) +encoder = models.__dict__[args.encoder](*((args.bottleneck // 16, (4, 4)) if args.nMel == 128 else (args.bottleneck // 8, (2, 4)))) +decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4)) +model = torch.nn.Sequential(frontend, encoder, decoder).to(gpu) +model.load_state_dict(torch.load(f'{args.specie}/weights/{modelname}')) +model.eval() + +df = pd.read_csv(f'{args.specie}/{args.specie}.csv') + +shutil.rmtree(f'{args.specie}/reconstruct_pngs', ignore_errors=True) + +for label, grp in df.groupby('label'): + os.makedirs(f'{args.specie}/reconstruct_pngs/{label}', exist_ok=True) + loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\ + batch_size=1, num_workers=4, pin_memory=True) + with torch.no_grad(): + for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False): + x = model(x.to(gpu)).squeeze().detach().cpu() + assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/" + plt.imshow(x, origin='lower', aspect='auto') #, cmap='Greys', vmin=np.quantile(x, .7)) + plt.subplots_adjust(top=1, bottom=0, left=0, right=1) + row = df.loc[idx.item()] + plt.savefig(f'{args.specie}/reconstruct_pngs/{label}/{idx.item()}') + plt.close() + + + diff --git a/run_baseline.py b/run_baseline.py index ca30c1c7d7e39304ae29ec739c81da15f3a7de48..dcea284b6418d546cb510c6c99c18046cc66a5d1 100755 --- a/run_baseline.py +++ b/run_baseline.py @@ -43,7 +43,7 @@ def process(idx): return [sound.__dict__[f] for f in feats] -res = p_tqdm.p_map(process, df.index, num_cpus=14) +res = p_tqdm.p_map(process, df.index, num_cpus=10) for i, mr in zip(df.index, res): for f, r in zip(feats, mr): diff --git a/run_hearbaseline.py b/run_hearbaseline.py index e5b9a3f2a8711941af8dec04f57259638f6fa897..52ba7e6af44ca3d962cec2b99646b814aa753dcb 100755 --- a/run_hearbaseline.py +++ b/run_hearbaseline.py @@ -1,11 +1,9 @@ from sklearn import metrics -import matplotlib.pyplot as plt -#import umap, hdbscan from tqdm import tqdm import argparse, os import models, utils as u import pandas as pd, numpy as np, torch -from hearbaseline import wav2vec2 as hear +from hearbaseline import wav2vec2, torchopenl3, torchcrepe, vggish parser = argparse.ArgumentParser() parser.add_argument("specie", type=str) @@ -17,57 +15,17 @@ df = pd.read_csv(f'{args.specie}/{args.specie}.csv') meta = models.meta[args.specie] batch_size = 32 -if True : #not os.path.isfile(f'{args.specie}/encodings/encodings_wave2vec2.npy'): - gpu = torch.device(f'cuda:{args.cuda}') - model = hear.load_model().to(gpu) + +gpu = torch.device(f'cuda:{args.cuda}') +for module, name in zip([torchopenl3, wav2vec2, torchcrepe, vggish], ['openl3', 'wav2vec2', 'crepe', 'vggish']): + model = module.load_model().to(gpu) loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', model.sample_rate, meta['sampleDur']), batch_size=batch_size, num_workers=8, collate_fn=u.collate_fn) with torch.inference_mode(): encodings, idxs = [], [] for x, idx in tqdm(loader, desc='test '+args.specie, leave=False): - encoding = hear.get_scene_embeddings(x.to(gpu), model=model) + encoding = module.get_scene_embeddings(x.to(gpu), model=model) idxs.extend(idx.numpy()) encodings.extend(encoding.view(len(x), -1).cpu().numpy()) idxs, encodings = np.array(idxs), np.stack(encodings) -# X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings) - np.save(f'{args.specie}/encodings/encodings_wav2vec2.npy', {'idxs':idxs, 'encodings':encodings}) #, 'umap8':X}) - exit() -else: - dic = np.load(f'{args.specie}/encodings/encodings_wave2vec2.npy', allow_pickle=True).item() - idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap8'] - -clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) -df.loc[idxs, 'cluster'] = clusters.astype(int) -mask = ~df.loc[idxs].label.isna() - -clusters, labels = clusters[mask], df.loc[idxs[mask]].label -print('NMI', metrics.normalized_mutual_info_score(labels, clusters)) -exit() -#print('Found clusters : \n', pd.Series(clusters).value_counts()) - -plt.figure(figsize=(20, 10)) -plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey') -plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20') -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png') - -plt.figure(figsize=(20, 10)) -plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey') -for l, grp in df.groupby('label'): - plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l) -plt.legend() -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png') - - -clusters, labels = clusters[mask], df.loc[idxs[mask]].label -print('Silhouette', metrics.silhouette_score(encodings[mask], clusters)) -print('NMI', metrics.normalized_mutual_info_score(labels, clusters)) -print('Homogeneity', metrics.homogeneity_score(labels, clusters)) -print('Completeness', metrics.completeness_score(labels, clusters)) -print('V-Measure', metrics.v_measure_score(labels, clusters)) - -labelled = df[~df.label.isna()] -for l, grp in labelled.groupby('label'): - best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax() - print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \ -with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}') + # X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings) + np.save(f'{args.specie}/encodings/encodings_{name}.npy', {'idxs':idxs, 'encodings':encodings}) #, 'umap8':X}) diff --git a/run_openl3.py b/run_openl3.py deleted file mode 100755 index aefb97664c6f9093c1c250a24439361f10937e7f..0000000000000000000000000000000000000000 --- a/run_openl3.py +++ /dev/null @@ -1,74 +0,0 @@ -from sklearn import metrics -import matplotlib.pyplot as plt -import umap, hdbscan -from tqdm import tqdm -import argparse, os -import models, utils as u -import pandas as pd, numpy as np, torch -import torchopenl3 as openl3 - -parser = argparse.ArgumentParser() -parser.add_argument("specie", type=str) -parser.add_argument("-cuda", type=int, default=0) -args = parser.parse_args() - -df = pd.read_csv(f'{args.specie}/{args.specie}.csv') - -meta = models.meta[args.specie] -batch_size = 64 - -if True : #not os.path.isfile(f'{args.specie}/encodings/encodings_openl3.npy'): - gpu = torch.device(f'cuda:{args.cuda}') - frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64) - loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=batch_size, num_workers=8, collate_fn=u.collate_fn) - model = openl3.models.load_audio_embedding_model(input_repr="mel128", content_type="music", embedding_size=512).to(gpu) - with torch.no_grad(): - encodings, idxs = [], [] - for x, idx in tqdm(loader, desc='test '+args.specie, leave=False): - encoding = openl3.get_audio_embedding(x.to(gpu), meta['sr'], model=model, center=False, batch_size=batch_size, verbose=False)[0] - idxs.extend(idx.numpy()) - encodings.extend(encoding.mean(axis=1).cpu().numpy()) - - idxs, encodings = np.array(idxs), np.stack(encodings) - X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings) - np.save(f'{args.specie}/encodings/encodings_openl3.npy', {'idxs':idxs, 'encodings':encodings, 'umap8':X}) -else: - dic = np.load(f'{args.specie}/encodings/encodings_openl3.npy', allow_pickle=True).item() - idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap8'] - -clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) -df.loc[idxs, 'cluster'] = clusters.astype(int) -mask = ~df.loc[idxs].label.isna() - -clusters, labels = clusters[mask], df.loc[idxs[mask]].label -print('NMI', metrics.normalized_mutual_info_score(labels, clusters)) -exit() -#print('Found clusters : \n', pd.Series(clusters).value_counts()) - -plt.figure(figsize=(20, 10)) -plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey') -plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20') -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png') - -plt.figure(figsize=(20, 10)) -plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey') -for l, grp in df.groupby('label'): - plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l) -plt.legend() -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png') - - -clusters, labels = clusters[mask], df.loc[idxs[mask]].label -print('Silhouette', metrics.silhouette_score(encodings[mask], clusters)) -print('NMI', metrics.normalized_mutual_info_score(labels, clusters)) -print('Homogeneity', metrics.homogeneity_score(labels, clusters)) -print('Completeness', metrics.completeness_score(labels, clusters)) -print('V-Measure', metrics.v_measure_score(labels, clusters)) - -labelled = df[~df.label.isna()] -for l, grp in labelled.groupby('label'): - best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax() - print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \ -with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}') diff --git a/run_vggish.py b/run_vggish.py deleted file mode 100755 index c7bcacc8c5dac05650c2288616b3914e174e74f8..0000000000000000000000000000000000000000 --- a/run_vggish.py +++ /dev/null @@ -1,74 +0,0 @@ -from sklearn import metrics -import matplotlib.pyplot as plt -import umap, hdbscan -from tqdm import tqdm -import argparse, os -import models, utils as u -import pandas as pd, numpy as np, torch -torch.multiprocessing.set_sharing_strategy('file_system') - -parser = argparse.ArgumentParser() -parser.add_argument("specie", type=str) -args = parser.parse_args() - -df = pd.read_csv(f'{args.specie}/{args.specie}.csv') - -meta = models.meta[args.specie] - -if not os.path.isfile(f'{args.specie}/encodings/encodings_vggish.npy'): - gpu = torch.device('cuda') - frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64) - vggish = torch.hub.load('harritaylor/torchvggish', 'vggish') - # vggish.preprocess = False - vggish.postprocess = False - model = torch.nn.Sequential(frontend, vggish).to(gpu) - model.eval() - loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', 16000, 1), batch_size=1, shuffle=True, num_workers=8, collate_fn=u.collate_fn) - with torch.no_grad(): - encodings, idxs = [], [] - for x, idx in tqdm(loader, desc='test '+args.specie, leave=False): - # encoding = model(x.to(gpu)) - encoding = vggish(x.numpy().squeeze(0), fs=16000) - idxs.extend(idx.numpy()) - encodings.extend(encoding.cpu().numpy()) - - idxs, encodings = np.array(idxs), np.stack(encodings) - X = umap.UMAP(n_jobs=-1).fit_transform(encodings) - np.save(f'{args.specie}/encodings/encodings_vggish.npy', {'idxs':idxs, 'encodings':encodings, 'umap':X}) -else: - dic = np.load(f'{args.specie}/encodings/encodings_vggish.npy', allow_pickle=True).item() - idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap'] - -clusters = hdbscan.HDBSCAN(min_cluster_size=50, min_samples=5, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) -df.loc[idxs, 'cluster'] = clusters.astype(int) -mask = ~df.loc[idxs].label.isna() - -#print('Found clusters : \n', pd.Series(clusters).value_counts()) - -plt.figure(figsize=(20, 10)) -plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey') -plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20') -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png') - -plt.figure(figsize=(20, 10)) -plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey') -for l, grp in df.groupby('label'): - plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l) -plt.legend() -plt.tight_layout() -plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png') - - -clusters, labels = clusters[mask], df.loc[idxs[mask]].label -print('Silhouette', metrics.silhouette_score(encodings[mask], clusters)) -print('NMI', metrics.normalized_mutual_info_score(labels, clusters)) -print('Homogeneity', metrics.homogeneity_score(labels, clusters)) -print('Completeness', metrics.completeness_score(labels, clusters)) -print('V-Measure', metrics.v_measure_score(labels, clusters)) - -labelled = df[~df.label.isna()] -for l, grp in labelled.groupby('label'): - best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax() - print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \ -with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}') diff --git a/test_AE.py b/test_AE.py index 5fcc0ae21a012e2964529e1b24c9dee80d341dde..92f8100885ee4e7c624cf6109de493a2f7446349 100755 --- a/test_AE.py +++ b/test_AE.py @@ -56,7 +56,7 @@ else: if X is None: X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings) print('/!\ no UMAP') -clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=.1, cluster_selection_method='leaf', core_dist_n_jobs=-1).fit_predict(X) +clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=.1, cluster_selection_method='leaf' if not 'humpback' in args.specie else 'eom', core_dist_n_jobs=-1).fit_predict(X) df.loc[idxs, 'cluster'] = clusters.astype(int) mask = ~df.loc[idxs].label.isna() diff --git a/train_AE.py b/train_AE.py index b13dca3027529a4a9ade28cc6a007fe492cefcd2..d02907e84b44197240b7173284d8dc6ff5eb86a4 100755 --- a/train_AE.py +++ b/train_AE.py @@ -1,10 +1,7 @@ -import umap, hdbscan from torchvision.utils import make_grid from torch.utils.tensorboard import SummaryWriter import torch -from sklearn import metrics, cluster import numpy as np, pandas as pd -from scipy.stats import linregress import utils as u, models from tqdm import tqdm import os, argparse, warnings @@ -13,9 +10,10 @@ warnings.filterwarnings("error") parser = argparse.ArgumentParser() parser.add_argument("specie", type=str) -parser.add_argument("-bottleneck", type=int, default=16) +parser.add_argument("-bottleneck", type=int, default=256) parser.add_argument("-frontend", type=str, default='logMel') parser.add_argument("-encoder", type=str, default='sparrow_encoder') +parser.add_argument("-prcptl", type=int, default=1) parser.add_argument("-nMel", type=int, default=128) parser.add_argument("-lr", type=float, default=3e-3) parser.add_argument("-lr_decay", type=float, default=1e-2) @@ -26,10 +24,10 @@ args = parser.parse_args() df = pd.read_csv(f'{args.specie}/{args.specie}.csv') print(f'{len(df)} available vocs') -modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool.stdc' +modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc' gpu = torch.device(f'cuda:{args.cuda}') -writer = SummaryWriter(f'runs2/{modelname}') -os.system(f'cp *.py runs2/{modelname}') +writer = SummaryWriter(f'runs/{modelname}') +os.system(f'cp *.py runs/{modelname}') vgg16 = models.vgg16 vgg16.eval().to(gpu) meta = models.meta[args.specie] @@ -56,8 +54,11 @@ for epoch in range(100_000//len(loader)): x = encoder(label) pred = decoder(x) assert not torch.isnan(pred).any(), "found a NaN :'(" - predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:])) - labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:])) + if args.prcptl == 1: + predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:])) + labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:])) + else: + predd, labell = pred, label score = MSE(predd, labell) score.backward() @@ -69,10 +70,14 @@ for epoch in range(100_000//len(loader)): print('Early stop') torch.save(model.state_dict(), f'{args.specie}/weights/{modelname}') exit() + step += 1 - continue - # TEST ROUTINE - if step % 500 == 0: + + if step % 500 == 0: # scheduler + scheduler.step() + + # Plot images + if step % 100 == 0 : # Plot reconstructions images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]] grid = make_grid(images) @@ -80,81 +85,4 @@ for epoch in range(100_000//len(loader)): # writer.add_embedding(x.detach(), global_step=step, label_img=label) images = [(e-e.min())/(e.max()-e.min()) for e in pred[:8]] grid = make_grid(images) - writer.add_image('reconstruct', grid, step) - - torch.save(model.state_dict(), f'{args.specie}/{modelname}') - scheduler.step() - - # Actual test - model[1:].eval() - with torch.no_grad(): - encodings, idxs = [], [] - for x, idx in tqdm(loader, desc='test '+str(step), leave=False): - encoding = model[:2](x.to(gpu)) - idxs.extend(idx) - encodings.extend(encoding.cpu().detach()) - idxs, encodings = np.array(idxs), np.stack(encodings) - print('Computing UMAP...', end='') - try: - X = umap.UMAP(n_jobs=-1).fit_transform(encodings) - except UserWarning: - pass - print('\rRunning HDBSCAN...', end='') - clusters = hdbscan.HDBSCAN(min_cluster_size=len(df)//100, min_samples=5, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X) - # df.loc[idxs, 'cluster'] = clusters.astype(int) - mask = ~df.loc[idxs].label.isna() - clusters, labels = clusters[mask], df.loc[idxs[mask]].label - writer.add_scalar('NMI HDBSCAN', metrics.normalized_mutual_info_score(labels, clusters), step) - try: - writer.add_scalar('ARI HDBSCAN', metrics.adjusted_rand_score(labels, clusters), step) - except: - pass - writer.add_scalar('Homogeneity HDBSCAN', metrics.homogeneity_score(labels, clusters), step) - writer.add_scalar('Completeness HDBSCAN', metrics.completeness_score(labels, clusters), step) - writer.add_scalar('V-Measure HDBSCAN', metrics.v_measure_score(labels, clusters), step) - - # print('\rComputing HDBSCAN precision and recall distributions', end='') - # labelled = df[~df.label.isna()] - # precs, recs = [], [] - # for l, grp in labelled.groupby('label'): - # best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax() - # precs.append((grp.cluster==best).sum()/(labelled.cluster==best).sum()) - # recs.append((grp.cluster==best).sum()/len(grp)) - # writer.add_histogram('HDBSCAN Precisions ', np.array(precs), step) - # writer.add_histogram('HDBSCAN Recalls ', np.array(recs), step) - # df.drop('cluster', axis=1, inplace=True) - # print('\rRunning elbow method for K-Means...', end='') - # ks = (5*1.2**np.arange(20)).astype(int) - # distorsions = [cluster.KMeans(n_clusters=k).fit(encodings).inertia_ for k in ks] - # print('\rEstimating elbow...', end='') - # errors = [linregress(ks[:i], distorsions[:i]).stderr + linregress(ks[i+1:], distorsions[i+1:]).stderr for i in range(2, len(ks)-2)] - # k = ks[np.argmin(errors)] - # writer.add_scalar('Chosen K', k, step) - # clusters = cluster.KMeans(n_clusters=k).fit_predict(encodings) - # df.loc[idxs, 'cluster'] = clusters.astype(int) - - # writer.add_scalar('Silhouette', metrics.silhouette_score(encodings, clusters), step) - # clusters, labels = clusters[mask], df.loc[idxs[mask]].label - # writer.add_scalar('NMI K-Means', metrics.normalized_mutual_info_score(labels, clusters), step) - # try: - # writer.add_scalar('ARI K-Means', metrics.adjusted_rand_score(labels, clusters), step) - # except: - # pass - # writer.add_scalar('Homogeneity K-Means', metrics.homogeneity_score(labels, clusters), step) - # writer.add_scalar('Completeness K-Means', metrics.completeness_score(labels, clusters), step) - # writer.add_scalar('V-Measure K-Means', metrics.v_measure_score(labels, clusters), step) - - # print('\rComputing K-Means precision and recall distributions', end='') - # labelled = df[~df.label.isna()] - # precs, recs = [], [] - # for l, grp in labelled.groupby('label'): - # best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax() - # precs.append((grp.cluster==best).sum()/(labelled.cluster==best).sum()) - # recs.append((grp.cluster==best).sum()/len(grp)) - # writer.add_histogram('K-Means Precisions ', np.array(precs), step) - # writer.add_histogram('K-Means Recalls ', np.array(recs), step) - # df.drop('cluster', axis=1, inplace=True) - - print('\r', end='') - model[1:].train() - step += 1 + writer.add_image('reconstruct', grid, step) \ No newline at end of file