Skip to content
Snippets Groups Projects
Commit 65f9e9c6 authored by Paul Best's avatar Paul Best
Browse files

Update file compute_embeddings.py

parent 362b4aa4
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import numpy as np, pandas as pd, torch
import umap
from tqdm import tqdm
import argparse
torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)')
......@@ -26,7 +27,7 @@ df = pd.read_csv(args.detections)
print('Computing AE projections...')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
with torch.no_grad():
with torch.inference_mode():
encodings, idxs = [], []
for x, idx in tqdm(loader):
encoding = model[:2](x.to(device))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment