diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 508d6041bcaae17442c0fd74f6261988b313a89c..ee9e1e77801f3db44e72f4e49005d08a9cb7fa18 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -90,8 +90,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   else
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
 
-  this->nn->to(NeuralNetworkImpl::device);
-
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
         {
           std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 0ff56500fc6f890654eb24d2c73a8bb89e2ee6ce..bbb05d47a0e71458fea0a2c739614b93fde22715 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -28,6 +28,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
     maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
   }
   classifier->getNN()->registerEmbeddings(maxDictSize);
+  classifier->getNN()->to(NeuralNetworkImpl::device);
 
   torch::load(classifier->getNN(), models[0]);
 }
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index e2cfb32bb80683d6bde3a2e5adb65b67c82ef8be..601b17670e366281e01f6e12cf493643b70f68b0 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -130,6 +130,7 @@ int MacaonTrain::main()
     maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
   }
   machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
+  machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
   machine.saveDicts();
 
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();