From 3aafb6ac5c6d414fde5b92afdb2cdc88752421bb Mon Sep 17 00:00:00 2001 From: "paul.best" <paul.best@lis-lab.fr> Date: Thu, 13 Jul 2023 11:38:01 +0200 Subject: [PATCH] use cuda from argparse --- run_CNN.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_CNN.py b/run_CNN.py index 07daf2b..1fc6e4b 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -60,7 +60,7 @@ class Dataset(torch.utils.data.Dataset): model = models.get[args.specie]['archi'] model.load_state_dict(torch.load(f"{os.path.dirname(__file__)}/weights/{models.get[args.specie]['weights']}")) model.eval() -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu') model.to(device) # prepare data loader and output storage for predictions -- GitLab