From 431fab993d759be05c12b855d184fa473271aa20 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 29 Apr 2020 22:47:54 +0200 Subject: [PATCH] Corrected a bug where embeddings module was not sent to cuda --- reading_machine/src/Classifier.cpp | 2 -- reading_machine/src/ReadingMachine.cpp | 1 + trainer/src/MacaonTrain.cpp | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 508d604..ee9e1e7 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -90,8 +90,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) else util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType)); - this->nn->to(NeuralNetworkImpl::device); - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm) { std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'"; diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 0ff5650..bbb05d4 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -28,6 +28,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size()); } classifier->getNN()->registerEmbeddings(maxDictSize); + classifier->getNN()->to(NeuralNetworkImpl::device); torch::load(classifier->getNN(), models[0]); } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index e2cfb32..601b176 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -130,6 +130,7 @@ int MacaonTrain::main() maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); } machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); + machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); machine.saveDicts(); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); -- GitLab