diff --git a/run_CNN.py b/run_CNN.py
index 07daf2bd919eeaf1531497b1b00bab4b9cd8e9f4..1fc6e4b9ec903ed1e8a92ebaea393bdcacbb8b20 100644
--- a/run_CNN.py
+++ b/run_CNN.py
@@ -60,7 +60,7 @@ class Dataset(torch.utils.data.Dataset):
 model = models.get[args.specie]['archi']
 model.load_state_dict(torch.load(f"{os.path.dirname(__file__)}/weights/{models.get[args.specie]['weights']}"))
 model.eval()
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
 model.to(device)
 
 # prepare data loader and output storage for predictions