From 761ea87c1cc5052e4e97b3e15ec7b98aa6291020 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 9 Jul 2020 08:57:56 +0200
Subject: [PATCH] Making sure the nn are loaded to the correct device

---
 reading_machine/src/Classifier.cpp | 6 +++++-
 trainer/src/Trainer.cpp            | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 704b3eb..2ea7670 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -66,13 +66,17 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   getNN()->loadDicts(path);
   getNN()->registerEmbeddings();
 
-  getNN()->to(NeuralNetworkImpl::device);
+  getNN()->to(torch::kCPU);
 
   if (!train)
+  {
     torch::load(getNN(), getBestFilename());
+    getNN()->to(NeuralNetworkImpl::device);
+  }
   else if (std::filesystem::exists(getLastFilename()))
   {
     torch::load(getNN(), getLastFilename());
+    getNN()->to(NeuralNetworkImpl::device);
     resetOptimizer();
     loadOptimizer();
   }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index e732e79..40a7c77 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -138,7 +138,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
-      auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze();
+      auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
   
       float bestScore = std::numeric_limits<float>::min();
       std::vector<int> candidates;
-- 
GitLab