Commit 3776e702 authored by Paul Best

fix print_annot

parent c953aec8
@@ -30,7 +30,7 @@ model.load_state_dict(torch.load(args.modelname))
 df = pd.read_csv(args.detections)
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
...
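Note: the only functional change in this hunk is dropping the hard-coded channel=1 argument from u.Dataset, so the loader falls back to the dataset's default channel handling. The sampleDur + (2 if args.medfilt else 0) term loads two extra seconds of signal when -medfilt is set, purely to give the frequency-wise median a longer window to estimate from. A minimal sketch of what such a median-filter stage could look like (the actual filterbank.MedFilt implementation is not part of this commit, so this is an assumption):

import torch

class MedFilt(torch.nn.Module):
    # Hypothetical stand-in for filterbank.MedFilt: subtract each frequency
    # band's median over time to suppress stationary background noise.
    def forward(self, x):
        # x: (batch, n_mel, time) log-mel spectrogram
        return x - x.median(dim=-1, keepdim=True).values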
-import soundfile as sf
-import os
+import os, argparse
 from tqdm import tqdm
 import matplotlib.pyplot as plt
+import soundfile as sf
 import pandas as pd, numpy as np
 import models, utils as u
 import torch
-from filterbank import MelFilter, STFT, Log1p, MedFilt
-fs = 44_100
-sampleDur = 2.5
-nMel = 96
-hop = 260
-width = int((sampleDur * fs - 2048)/hop) + 1
-pad_width = int((1 * fs)/hop) + 1
-frontend = torch.nn.Sequential(
-    STFT(2048, hop),
-    MelFilter(fs, 2048, 96, 500, 4000),
-    Log1p(4),
-    MedFilt()
-)
-#df = pd.read_csv('S4A09100_20220527$045924.Table.1.selections.txt', delimiter='\t')
-#df['filename'] = 'S4A09100_20220527$045924.wav'
-#df['pos'] = df['Begin Time (s)'] + 1.25
-#df['dur'] = df['End Time (s)'] - df['Begin Time (s)']
-#df.drop(df[((df.dur < 2.5)|(df.dur > 8))].index, inplace=True)
-df = pd.read_csv('all_annot.csv')
-df = df['Cao'] = 'C'
-df['pos'] = df['Begin Time (s)'] + 1.25
-loader = torch.utils.data.DataLoader(u.Dataset(df, './', fs, 2.5 + 2), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Plot vocalisation spectrograms into annot_pngs/")
+parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
+parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
+parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
+parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
+parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
+parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
+parser.set_defaults(feature=False)
+args = parser.parse_args()
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+df = pd.read_csv(args.detections)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
 for x, idx in tqdm(loader):
-    x = frontend(x).squeeze().detach()[:,pad_width:-pad_width]
+    x = frontend(x).squeeze().detach()
     plt.figure()
-    plt.imshow(x, origin='lower', aspect='auto')
+    plt.imshow(x, origin='lower', aspect='auto', vmin=torch.quantile(x, .25), cmap='Greys', vmax=torch.quantile(x, .98))
     plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-    plt.savefig(f'gibbon_calls/cao_vit_DL/{idx.item()}')
+    plt.savefig(f'annot_pngs/{idx.item()}')
     plt.close()
-exit()
-sig, fs = sf.read('S4A09100_20220527$045924.wav')
-for i, r in tqdm(df.iterrows(), total=len(df)):
-#    x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*r['End Time (s)']), 1]).unsqueeze(0)).squeeze().detach()
-    x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*(r['Begin Time (s)']+2.5)), 1]).unsqueeze(0)).squeeze().detach()
-    plt.imshow(x, origin='lower', aspect='auto', extent=[0, r.dur, 500, 5000])
-    plt.tight_layout()
-    plt.savefig(f'gibbon_calls/cao_vit_cut/{i}')
-    plt.close()
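The rewrite replaces the inline torch.nn.Sequential frontend with calls to models.frontend / models.frontend_medfilt, and swaps the hard-coded gibbon-annotation paths for argparse options, writing spectrograms to annot_pngs/ as the description promises. Judging from the deleted inline block, frontend_medfilt presumably chains the same four filterbank stages; a sketch under that assumption (the signature is guessed from the call site, while hop=260 and the 500-4000 Hz mel range come from the removed code):

import torch
from filterbank import MelFilter, STFT, Log1p, MedFilt

def frontend_medfilt(sr, nfft, sampleDur, n_mel, hop=260):
    # Same pipeline as the deleted inline frontend: STFT, mel projection
    # (500-4000 Hz), log compression, then frequency-wise median filtering.
    return torch.nn.Sequential(
        STFT(nfft, hop),
        MelFilter(sr, nfft, n_mel, 500, 4000),
        Log1p(4),
        MedFilt(),
    )

With this in place the script would be run along the lines of python print_annot.py detections.csv -audio_folder ./audio -medfilt (the script name is inferred from the commit message).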
@@ -29,6 +29,7 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
 parser.add_argument('-channel', type=int, default=0)
+parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
 parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
 parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
@@ -36,7 +37,9 @@ args = parser.parse_args()
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+args.sampleDur += (2 if args.medfilt else 0)
 if args.umap_ndim == 2:
     df.loc[idxs, 'umap_x'] = umap_[:,0]
@@ -76,7 +79,7 @@ if args.umap_ndim == 2:
     dur, fs = info.duration, info.samplerate
     start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
     sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-    sig = sig[:, args.channel]
+    sig = sig[:, row.Channel - 1 if 'Channel' in row else args.channel]
     if fs != args.SR:
         sig = signal.resample(sig, int(len(sig)/fs*args.SR))
     spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
@@ -107,11 +110,16 @@ else :
 print(df.cluster.value_counts().describe())
+if df.cluster.isna().sum() == 0:
+    df.cluster = df.cluster.astype(int)
 if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
     exit()
 os.system('rm -R cluster_pngs/*')
+df.to_csv(f'cluster_pngs/{args.detections}', index=False)
 for c, grp in df.groupby('cluster'):
     if c == -1 or len(grp) > 10_000:
         continue
...
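The new guard before the integer cast is needed because df.cluster becomes a float column whenever some detections received no cluster (NaN), and pandas refuses to cast non-finite values to int. An illustrative snippet, not taken from the repository:

import numpy as np
import pandas as pd

df = pd.DataFrame({'cluster': [0.0, 1.0, -1.0, np.nan]})
# df.cluster.astype(int) would raise here: NaN cannot become an integer
if df.cluster.isna().sum() == 0:
    df.cluster = df.cluster.astype(int)  # safe once every row has a cluster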