From 65f9e9c6b772ec55585f7b85312336a0d0d46123 Mon Sep 17 00:00:00 2001 From: Paul Best <paul.best@lis-lab.fr> Date: Thu, 25 May 2023 17:34:29 +0200 Subject: [PATCH] compute_embeddings: use torch.inference_mode and file_system sharing strategy --- new_specie/compute_embeddings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py index 262f8f7..512554b 100755 --- a/new_specie/compute_embeddings.py +++ b/new_specie/compute_embeddings.py @@ -4,6 +4,7 @@ import numpy as np, pandas as pd, torch import umap from tqdm import tqdm import argparse +torch.multiprocessing.set_sharing_strategy('file_system') parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py") parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)') @@ -26,7 +27,7 @@ df = pd.read_csv(args.detections) print('Computing AE projections...') loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8) -with torch.no_grad(): +with torch.inference_mode(): encodings, idxs = [], [] for x, idx in tqdm(loader): encoding = model[:2](x.to(device)) -- GitLab