Skip to content
Snippets Groups Projects
Commit 0e8c9e13 authored by Paul Best
Browse files

fix plot result

parent 8f621043
Branches
No related tags found
No related merge requests found
...@@ -5,25 +5,28 @@ import numpy as np ...@@ -5,25 +5,28 @@ import numpy as np
from sklearn import metrics

# For every species, cluster the UMAP projection stored alongside each
# frontend's encodings with HDBSCAN, then score the clusters against the
# annotated `label` column with normalized mutual information (NMI).
# One scatter series per species is drawn (one row per frontend) and the
# figure is written to 'NMIs_hdbscan.pdf'.
# NOTE(review): hdbscan, plt, pd and np are expected to be imported above this
# excerpt (only `import numpy as np` is visible as context) - confirm.

# atleast_1d: np.loadtxt returns a 0-d array for a one-line file, which
# cannot be iterated.
species = np.atleast_1d(np.loadtxt('good_species.txt', dtype=str))
frontends = ['16_pcenMel128', '16_logMel128', '16_logSTFT', '16_Mel128', '8_pcenMel64', '32_pcenMel128']

plt.figure()
for specie in species:
    df = pd.read_csv(f'{specie}/{specie}.csv')
    nmis = []  # one NMI per frontend, in the same order as `frontends`
    for frontend in frontends:
        print(specie, frontend)
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # encoding files that come from a trusted source.
        dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
        idxs, X = dic['idxs'], dic['umap']
        clusters = hdbscan.HDBSCAN(min_cluster_size=100, min_samples=20, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
        df.loc[idxs, 'cluster'] = clusters.astype(int)
        # Keep only the rows that carry a label. The mask is converted to a
        # plain boolean array so it indexes `clusters` and `idxs` by position
        # rather than by DataFrame index label.
        # presumably `idxs` is a NumPy array of row labels of `df` - confirm.
        mask = df.loc[idxs, 'label'].notna().to_numpy()
        clusters, labels = clusters[mask], df.loc[idxs[mask], 'label']
        nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
        # axis=1 drops the *column*; without it pandas looks for a row
        # labelled 'cluster' and raises KeyError.
        # NOTE(review): indentation was lost in the source; this reset is
        # assumed to happen once per frontend - confirm.
        df.drop('cluster', axis=1, inplace=True)
    plt.scatter(nmis, np.arange(len(frontends)), label=specie)
plt.yticks(range(len(frontends)), frontends)
plt.ylabel('archi')
plt.xlabel('NMI with expert labels')
plt.grid()
plt.tight_layout()
plt.legend()
plt.savefig('NMIs_hdbscan.pdf')
...@@ -6,10 +6,11 @@ from sklearn import metrics, cluster ...@@ -6,10 +6,11 @@ from sklearn import metrics, cluster
from scipy.stats import linregress from scipy.stats import linregress
species = np.loadtxt('good_species.txt', dtype=str) species = np.loadtxt('good_species.txt', dtype=str)
frontends = ['16_pcenMel128', '16_logMel128', '16_logSTFT', '16_Mel128', '8_pcen64', '32_pcenMel128'] frontends = ['16_pcenMel128', '16_logMel128', '16_logSTFT', '16_Mel128', '8_pcenMel64', '32_pcenMel128']
plt.figure() plt.figure()
for specie in species: for specie in species:
df = pd.read_csv(f'{specie}/{specie}.csv') df = pd.read_csv(f'{specie}/{specie}.csv')
nmis = []
for i, frontend in enumerate(frontends): for i, frontend in enumerate(frontends):
print(specie, frontend) print(specie, frontend)
dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item() dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
...@@ -24,12 +25,14 @@ for specie in species: ...@@ -24,12 +25,14 @@ for specie in species:
mask = ~df.loc[idxs].label.isna() mask = ~df.loc[idxs].label.isna()
clusters, labels = clusters[mask], df.loc[idxs[mask]].label clusters, labels = clusters[mask], df.loc[idxs[mask]].label
plt.scatter(metrics.normalized_mutual_info_score(labels, clusters), i, label=specie) nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
df.drop('cluster', inplace=True) df.drop('cluster', axis=1, inplace=True)
plt.scatter(nmis, np.arange(len(frontends)), label=specie)
plt.ytick_labels(range(len(frontends)), frontends) plt.yticks(range(len(frontends)), frontends)
plt.ylabel('archi') plt.ylabel('archi')
plt.xlabel('NMI with expert labels') plt.xlabel('NMI with expert labels')
plt.grid() plt.grid()
plt.legend() plt.legend()
plt.tight_layout()
plt.savefig('NMIs_kmeans.pdf') plt.savefig('NMIs_kmeans.pdf')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment