Skip to content
Snippets Groups Projects
Commit 3776e702 authored by Paul Best
Browse files

fix print_annot

parent c953aec8
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ model.load_state_dict(torch.load(args.modelname))
df = pd.read_csv(args.detections)
print('Computing AE projections...')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
with torch.inference_mode():
encodings, idxs = [], []
for x, idx in tqdm(loader):
......
import soundfile as sf
import os
import os, argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
import soundfile as sf
import pandas as pd, numpy as np
import models, utils as u
import torch
from filterbank import MelFilter, STFT, Log1p, MedFilt
fs = 44_100
sampleDur = 2.5
nMel = 96
hop = 260
width = int((sampleDur * fs - 2048)/hop) + 1
pad_width = int((1 * fs)/hop) + 1
frontend = torch.nn.Sequential(
STFT(2048, hop),
MelFilter(fs, 2048, 96, 500, 4000),
Log1p(4),
MedFilt()
)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Plot vocalisation spectrograms into annot_pngs/")
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.set_defaults(feature=False)
args = parser.parse_args()
#df = pd.read_csv('S4A09100_20220527$045924.Table.1.selections.txt', delimiter='\t')
#df['filename'] = 'S4A09100_20220527$045924.wav'
#df['pos'] = df['Begin Time (s)'] + 1.25
#df['dur'] = df['End Time (s)'] - df['Begin Time (s)']
#df.drop(df[((df.dur < 2.5)|(df.dur > 8))].index, inplace=True)
df = pd.read_csv('all_annot.csv')
df = df['Cao'] = 'C'
df['pos'] = df['Begin Time (s)'] + 1.25
loader = torch.utils.data.DataLoader(u.Dataset(df, './', fs, 2.5 + 2), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
for x, idx in tqdm(loader):
x = frontend(x).squeeze().detach()[:,pad_width:-pad_width]
x = frontend(x).squeeze().detach()
plt.figure()
plt.imshow(x, origin='lower', aspect='auto')
plt.imshow(x, origin='lower', aspect='auto', vmin=torch.quantile(x, .25), cmap='Greys', vmax=torch.quantile(x, .98))
plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
plt.savefig(f'gibbon_calls/cao_vit_DL/{idx.item()}')
plt.savefig(f'annot_pngs/{idx.item()}')
plt.close()
exit()
sig, fs = sf.read('S4A09100_20220527$045924.wav')
for i, r in tqdm(df.iterrows(), total=len(df)):
# x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*r['End Time (s)']), 1]).unsqueeze(0)).squeeze().detach()
x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*(r['Begin Time (s)']+2.5)), 1]).unsqueeze(0)).squeeze().detach()
plt.imshow(x, origin='lower', aspect='auto', extent=[0, r.dur, 500, 5000])
plt.tight_layout()
plt.savefig(f'gibbon_calls/cao_vit_cut/{i}')
plt.close()
......@@ -29,6 +29,7 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
parser.add_argument('-channel', type=int, default=0)
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
args = parser.parse_args()
......@@ -36,7 +37,9 @@ args = parser.parse_args()
df = pd.read_csv(args.detections)
encodings = np.load(args.encodings, allow_pickle=True).item()
idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
args.sampleDur += (2 if args.medfilt else 0)
if args.umap_ndim == 2:
df.loc[idxs, 'umap_x'] = umap_[:,0]
......@@ -76,7 +79,7 @@ if args.umap_ndim == 2:
dur, fs = info.duration, info.samplerate
start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
sig = sig[:, args.channel]
sig = sig[:, row.Channel -1 if 'Channel' in row else args.channel]
if fs != args.SR:
sig = signal.resample(sig, int(len(sig)/fs*args.SR))
spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
......@@ -107,11 +110,16 @@ else :
print(df.cluster.value_counts().describe())
if df.cluster.isna().sum() == 0:
df.cluster = df.cluster.astype(int)
# Ask for explicit confirmation before wiping previous outputs: any answer
# other than a literal 'y' stops the script here.
# The backslashes of the '/!\' markers are written as '\\' so the prompt text
# is unchanged while no longer relying on the invalid escape sequence '\ '
# (SyntaxWarning on recent Python versions).
if input('\nType y to print cluster pngs.\n/!\\ the cluster_pngs folder will be reset, backup if needed /!\\ ') != 'y':
    exit()
# Empty the cluster_pngs folder through the shell; the glob removes its
# contents and leaves the folder itself in place.
os.system('rm -R cluster_pngs/*')
# Save the table next to the pngs, under the same name as the input file.
# NOTE(review): this presumably breaks if args.detections contains a directory
# component that does not exist under cluster_pngs/ - confirm with callers.
df.to_csv(f'cluster_pngs/{args.detections}', index=False)
for c, grp in df.groupby('cluster'):
if c == -1 or len(grp) > 10_000:
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment