From 431fab993d759be05c12b855d184fa473271aa20 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 29 Apr 2020 22:47:54 +0200
Subject: [PATCH] Corrected a bug where embeddings module was not sent to cuda

---
 reading_machine/src/Classifier.cpp     | 2 --
 reading_machine/src/ReadingMachine.cpp | 1 +
 trainer/src/MacaonTrain.cpp            | 1 +
 3 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 508d604..ee9e1e7 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -90,8 +90,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   else
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
 
-  this->nn->to(NeuralNetworkImpl::device);
-
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
         {
           std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 0ff5650..bbb05d4 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -28,6 +28,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
     maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
   }
   classifier->getNN()->registerEmbeddings(maxDictSize);
+  classifier->getNN()->to(NeuralNetworkImpl::device);
 
   torch::load(classifier->getNN(), models[0]);
 }
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index e2cfb32..601b176 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -130,6 +130,7 @@ int MacaonTrain::main()
     maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
   }
   machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
+  machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
   machine.saveDicts();
 
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
-- 
GitLab