# run_vggish.py
from sklearn import metrics
import matplotlib.pyplot as plt
import umap, hdbscan
from tqdm import tqdm
import argparse, os
import models, utils as u
import pandas as pd, numpy as np, torch
torch.multiprocessing.set_sharing_strategy('file_system')

# Command line: one positional argument, used below both as the dataset
# folder name and as the key into models.meta.
parser = argparse.ArgumentParser()
parser.add_argument("specie", type=str)
args = parser.parse_args()

# The annotation table lives at <specie>/<specie>.csv.
annotation_csv = f'{args.specie}/{args.specie}.csv'
df = pd.read_csv(annotation_csv)

# Per-dataset settings declared in the project's models module.
meta = models.meta[args.specie]

# Compute the VGGish embeddings and their UMAP projection once, cache them on
# disk as a pickled dict, and reload that cache on later runs.
encodings_path = f'{args.specie}/encodings/encodings_vggish.npy'
if not os.path.isfile(encodings_path):
    VGGISH_SR = 16000  # sampling rate passed to both the dataset and VGGish

    vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')
    # Disable the hub model's post-processing flag; its preprocessing flag is
    # left at its default, so the model is fed raw waveforms below.
    vggish.postprocess = False
    vggish.eval()
    # NOTE(review): the model is called directly, without an explicit
    # `.to(device)`; torchvggish is expected to pick its own device and move
    # its inputs there -- confirm against the installed hub version.

    # NOTE(review): the last Dataset argument (1) is presumably the sample
    # duration in seconds -- confirm in utils.Dataset.
    # Sample order does not matter: each batch carries its dataframe index.
    loader = torch.utils.data.DataLoader(
        u.Dataset(df, f'{args.specie}/audio/', VGGISH_SR, 1),
        batch_size=1,
        shuffle=False,
        num_workers=8,
        collate_fn=u.collate_fn,
    )

    encodings, idxs = [], []
    with torch.no_grad():
        for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
            # batch_size is 1, so drop the batch axis before handing the
            # waveform to the model.
            encoding = vggish(x.numpy().squeeze(0), fs=VGGISH_SR)
            idxs.extend(idx.numpy())
            # NOTE(review): this appends one row per embedding returned; it
            # stays aligned with idxs only if each sample yields exactly one
            # embedding -- confirm for the configured sample duration.
            encodings.extend(encoding.cpu().numpy())

    idxs, encodings = np.array(idxs), np.stack(encodings)
    X = umap.UMAP(n_jobs=-1).fit_transform(encodings)

    # Create the cache folder first so the whole encoding pass is not lost to
    # a missing directory at save time.
    os.makedirs(os.path.dirname(encodings_path), exist_ok=True)
    np.save(encodings_path, {'idxs':idxs, 'encodings':encodings, 'umap':X})
else:
    # The cache is a pickled dict, hence allow_pickle; only load cache files
    # written by this script, since unpickling untrusted data is unsafe.
    dic = np.load(encodings_path, allow_pickle=True).item()
    idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']

# Cluster the 2-D projection X (not the raw encodings) with HDBSCAN.
clusterer = hdbscan.HDBSCAN(
    min_cluster_size=50,
    min_samples=5,
    cluster_selection_epsilon=0.05,
    core_dist_n_jobs=-1,
    cluster_selection_method='leaf',
)
clusters = clusterer.fit_predict(X)

# Write each encoded sample's cluster id back onto its dataframe row.
df.loc[idxs, 'cluster'] = clusters.astype(int)

# True for the encoded samples that carry a label, in the same order as idxs.
mask = df.loc[idxs].label.notna()

#print('Found clusters : \n', pd.Series(clusters).value_counts())

# Projection coloured by cluster id; points with cluster id -1 (HDBSCAN's
# noise label) are drawn in faint grey underneath.
is_noise = clusters == -1
plt.figure(figsize=(20, 10))
plt.scatter(X[is_noise, 0], X[is_noise, 1], s=2, alpha=.2, color='Grey')
plt.scatter(X[~is_noise, 0], X[~is_noise, 1], s=2, c=clusters[~is_noise], cmap='tab20')
plt.tight_layout()
# savefig does not create missing folders, so make sure the target exists.
os.makedirs(f'{args.specie}/projections', exist_ok=True)
plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png')

# Projection coloured by annotated label; unlabelled samples are drawn in
# faint grey underneath.
plt.figure(figsize=(20, 10))
plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
# Labels of the encoded samples, in the same order as the rows of X; looked up
# once instead of twice per label.
sample_labels = df.loc[idxs].label
for l, _ in df.groupby('label'):
    is_l = (sample_labels == l).to_numpy()
    plt.scatter(X[is_l, 0], X[is_l, 1], s=4, label=l)
plt.legend()
plt.tight_layout()
# savefig does not create missing folders, so make sure the target exists.
os.makedirs(f'{args.specie}/projections', exist_ok=True)
plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png')


# Score the clustering on the labelled samples only. From here on `clusters`
# holds just the cluster ids of those labelled samples.
labels = df.loc[idxs[mask]].label
clusters = clusters[mask]

# Silhouette is computed on the encodings rather than on the 2-D projection.
print('Silhouette', metrics.silhouette_score(encodings[mask], clusters))

# The remaining scores compare cluster ids with the annotated labels.
label_scores = (
    ('NMI', metrics.normalized_mutual_info_score),
    ('Homogeneity', metrics.homogeneity_score),
    ('Completeness', metrics.completeness_score),
    ('V-Measure', metrics.v_measure_score),
)
for score_name, score_fn in label_scores:
    print(score_name, score_fn(labels, clusters))

# For every label, report the cluster whose labelled members are most
# dominated by that label, with the matching precision and recall.
labelled = df[df.label.notna()]
# Labelled samples per cluster; the same for every label, so computed once.
labelled_per_cluster = labelled.groupby('cluster').fn.count()
for l, grp in labelled.groupby('label'):
    # Share of each cluster's labelled samples that carry label l.
    precision_per_cluster = grp.groupby('cluster').fn.count() / labelled_per_cluster
    best = precision_per_cluster.idxmax()
    in_best = labelled.cluster == best
    has_label = labelled.label == l
    true_pos = (in_best & has_label).sum()
    print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, '
          f'with precision {true_pos/in_best.sum():.2f} and recall {true_pos/has_label.sum():.2f}')