# run_openl3.py
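"""Compute OpenL3 embeddings for the vocalisations of a given species, reduce them to 8
dimensions with UMAP, cluster the projection with HDBSCAN, plot the projections, and report
clustering metrics against the available labels.

Example usage (the species name below is only an illustration; it must match a folder
containing <specie>.csv and an audio/ subfolder, and a key of models.meta):

    python run_openl3.py humpback -cuda 0
"""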
from sklearn import metrics
import matplotlib.pyplot as plt
import umap, hdbscan
from tqdm import tqdm
import argparse, os
import models, utils as u
import pandas as pd, numpy as np, torch
import torchopenl3 as openl3

parser = argparse.ArgumentParser()
parser.add_argument("specie", type=str, help="species name; must match a folder containing <specie>.csv, an audio/ subfolder, and an entry in models.meta")
parser.add_argument("-cuda", type=int, default=0, help="index of the CUDA device to use")
args = parser.parse_args()

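# Per-vocalisation metadata: one row per sample, with at least a filename column ('fn')
# and a 'label' column (NaN for unlabelled samples) used for evaluation.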
df = pd.read_csv(f'{args.specie}/{args.specie}.csv')

meta = models.meta[args.specie]
batch_size = 64

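# Compute the OpenL3 embeddings unless a cached file already exists, in which case reload it.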
if not os.path.isfile(f'{args.specie}/encodings/encodings_openl3.npy'):
    gpu = torch.device(f'cuda:{args.cuda}')
    # NOTE: this log-Mel frontend is never used below; leftover from the VGGish variant of this script.
    # frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64)
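    # Batched loader over the raw audio clips; u.Dataset presumably yields (waveform, dataframe index)
    # pairs at meta['sr'] with a fixed duration of meta['sampleDur'] seconds.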
    loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=batch_size, num_workers=8, collate_fn=u.collate_fn)
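    # Pre-trained OpenL3 audio embedding model: 128-band mel input, 'music' content type, 512-d output.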
    model = openl3.models.load_audio_embedding_model(input_repr="mel128", content_type="music", embedding_size=512).to(gpu)
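    # Embed each batch without gradients, average the frame-level embeddings over time to get one
    # 512-d vector per vocalisation, and keep the matching dataframe indices.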
    with torch.no_grad():
        encodings, idxs = [], []
        for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
            encoding = openl3.get_audio_embedding(x.to(gpu), meta['sr'], model=model, center=False, batch_size=batch_size, verbose=False)[0]
            idxs.extend(idx.numpy())
            encodings.extend(encoding.mean(axis=1).cpu().numpy())

    idxs, encodings = np.array(idxs), np.stack(encodings)
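    # Non-linear reduction of the 512-d embeddings to 8 dimensions before clustering.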
    X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings)
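    # Defensive addition: make sure the cache directory exists before saving (it may already be there).
    os.makedirs(f'{args.specie}/encodings', exist_ok=True)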
    np.save(f'{args.specie}/encodings/encodings_openl3.npy', {'idxs':idxs, 'encodings':encodings, 'umap8':X})
else:
    dic = np.load(f'{args.specie}/encodings/encodings_openl3.npy', allow_pickle=True).item()
    idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap8']

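# Cluster the 8-d UMAP projection with HDBSCAN; samples assigned to no cluster get the label -1 (noise).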
clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
df.loc[idxs, 'cluster'] = clusters.astype(int)
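# Restrict evaluation to the samples that have a ground-truth label.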
mask = ~df.loc[idxs].label.isna()

print('Found clusters : \n', pd.Series(clusters).value_counts())

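# Scatter the first two UMAP dimensions: noise points (cluster -1) in grey, the rest coloured by cluster.
# Defensive addition: make sure the projections/ output directory exists (it may already be there).
os.makedirs(f'{args.specie}/projections', exist_ok=True)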
plt.figure(figsize=(20, 10))
plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
plt.tight_layout()
plt.savefig(f'{args.specie}/projections/openl3_projection_clusters.png')

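# Same projection, this time coloured by ground-truth label, with unlabelled points in grey.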
plt.figure(figsize=(20, 10))
plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
for l in df.loc[idxs].label.dropna().unique():
    plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
plt.legend()
plt.tight_layout()
plt.savefig(f'{args.specie}/projections/openl3_projection_labels.png')


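# Clustering quality metrics, computed on the labelled subset only.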
clusters, labels = clusters[mask], df.loc[idxs[mask]].label
print('Silhouette', metrics.silhouette_score(encodings[mask], clusters))
print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
print('Homogeneity', metrics.homogeneity_score(labels, clusters))
print('Completeness', metrics.completeness_score(labels, clusters))
print('V-Measure', metrics.v_measure_score(labels, clusters))

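# For each label, find the cluster that captures it with the highest precision among labelled points,
# and report that cluster's size, precision, and recall.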
labelled = df[~df.label.isna()]
for l, grp in labelled.groupby('label'):
    best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
    precision = ((labelled.cluster==best) & (labelled.label==l)).sum() / (labelled.cluster==best).sum()
    recall = ((labelled.cluster==best) & (labelled.label==l)).sum() / (labelled.label==l).sum()
    print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, '
          f'with precision {precision:.2f} and recall {recall:.2f}')