Commit 761ea87c authored by Franck Dary's avatar Franck Dary
Browse files

Making sure the nn are loaded to the correct device

parent e35487c9
......@@ -66,13 +66,17 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device);
getNN()->to(torch::kCPU);
if (!train)
{
torch::load(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
}
else if (std::filesystem::exists(getLastFilename()))
{
torch::load(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
resetOptimizer();
loadOptimizer();
}
......
......@@ -138,7 +138,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze();
auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
float bestScore = std::numeric_limits<float>::min();
std::vector<int> candidates;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment