Skip to content
Snippets Groups Projects
Commit 197e07b2 authored by Paul Best's avatar Paul Best
Browse files

medfilt in compute embeddings

parent 36a96885
No related branches found
No related tags found
No related merge requests found
......@@ -15,10 +15,12 @@ parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands fo
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.set_defaults(feature=False)
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
......@@ -26,7 +28,7 @@ model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
df = pd.read_csv(args.detections)
print('Computing AE projections...')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
with torch.inference_mode():
encodings, idxs = [], []
for x, idx in tqdm(loader):
......@@ -38,4 +40,4 @@ encodings = np.stack(encodings)
print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split('.')[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment