From a6f7fa15474f2cf759f79e2d1d4df67777ed74a2 Mon Sep 17 00:00:00 2001
From: lamipaul <paulobest25@gmail.com>
Date: Tue, 18 Jul 2023 15:25:59 +0200
Subject: [PATCH] small fix

---
 new_specie/print_detections.py | 3 ++-
 new_specie/sort_cluster.py     | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/new_specie/print_detections.py b/new_specie/print_detections.py
index 0e1d86e..bd0b8b1 100755
--- a/new_specie/print_detections.py
+++ b/new_specie/print_detections.py
@@ -21,12 +21,13 @@ args = parser.parse_args()
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 df = pd.read_csv(args.detections)
 
+os.system('rm detections_pngs/*')
 loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
 
 for x, idx in tqdm(loader):
     x = frontend(x).squeeze().detach()
     plt.imshow(x, origin='lower', aspect='auto', vmin=torch.quantile(x, .25), cmap='Greys', vmax=torch.quantile(x, .98))
     plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-    plt.savefig(f'annot_pngs/{idx.item()}')
+    plt.savefig(f'detections_pngs/{idx.item()}')
     plt.close()
 
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index a5d7d33..b97b4ce 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -84,7 +84,7 @@ if args.umap_ndim == 2:
                 sig = signal.resample(sig, int(len(sig)/fs*args.SR))
             spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
             axSpec.clear()
-            axSpec.imshow(spec, origin='lower', aspect='auto')
+            axSpec.imshow(spec, origin='lower', aspect='auto', vmin=torch.quantile(spec, .25), cmap='Greys', vmax=torch.quantile(spec, .98))
             # Display and metadata management
             axSpec.set_title(f'{closest}, cluster {row.cluster:.0f} ({(df.cluster==row.cluster).sum()} points)')
             axScat.scatter(row.umap_x, row.umap_y, c='r')
-- 
GitLab