diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 704b3ebe13c0ca83f7636e19d3f867f87d584996..2ea76708c9209ed217b127e40b4bef4e0677b00a 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -66,13 +66,17 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std getNN()->loadDicts(path); getNN()->registerEmbeddings(); - getNN()->to(NeuralNetworkImpl::device); + getNN()->to(torch::kCPU); if (!train) + { torch::load(getNN(), getBestFilename()); + getNN()->to(NeuralNetworkImpl::device); + } else if (std::filesystem::exists(getLastFilename())) { torch::load(getNN(), getLastFilename()); + getNN()->to(NeuralNetworkImpl::device); resetOptimizer(); loadOptimizer(); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index e732e7916a14af036f22e6cfaa2f4f38c148a425..40a7c77fd2d828a427ddd0e882cfbac85c90bc8e 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -138,7 +138,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(); + auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); float bestScore = std::numeric_limits<float>::min(); std::vector<int> candidates;