From a3413da01a30cbe4b935c59db6f92fd6a091f241 Mon Sep 17 00:00:00 2001
From: lamipaul <paulobest25@gmail.com>
Date: Mon, 20 Nov 2023 19:33:56 +0100
Subject: [PATCH] Add semi-supervised UMAP option (-semi_sup_umap)

---
 new_specie/sort_cluster.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index b97b4ce..c867a58 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -32,6 +32,7 @@ parser.add_argument('-channel', type=int, default=0)
 parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
 parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
 parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
+parser.add_argument('-semi_sup_umap', type=str, default=None, help='Name of a column in the detections file with labels to guide the umap compression (see target_metric in umap documentation)')
 args = parser.parse_args()
 
 df = pd.read_csv(args.detections)
@@ -42,6 +43,8 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
 args.sampleDur += (2 if args.medfilt else 0)
 
 if args.umap_ndim == 2:
+    if args.semi_sup_umap:
+        umap_ = umap.UMAP(n_jobs=-1).fit_transform(encodings['encodings'], y=df.loc[idxs, args.semi_sup_umap])
     df.loc[idxs, 'umap_x'] = umap_[:,0]
     df.loc[idxs, 'umap_y'] = umap_[:,1]
 
@@ -101,7 +104,7 @@ if args.umap_ndim == 2:
 
     plt.show()
 else :
-    umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings)
+    umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings, y=df.loc[idxs, args.semi_sup_umap] if args.semi_sup_umap else None)  # labels are selected at idxs (as in the 2-D branch) rather than the whole column; assumes embeddings has one row per idxs entry -- TODO confirm
     df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                              min_samples=args.min_sample,
                                              core_dist_n_jobs=-1,
-- 
GitLab