diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 702303007b572d68e577c02e737d5a41a84dfce7..c685fa200e6185810b524e7698f1baa6113868f4 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -66,6 +66,8 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std getNN()->loadDicts(path); getNN()->registerEmbeddings(); + getNN()->to(NeuralNetworkImpl::device); + if (!train) torch::load(getNN(), getBestFilename()); else if (std::filesystem::exists(getLastFilename()))