Commit 761ea87c authored by Franck Dary's avatar Franck Dary
Browse files

Making sure the nn are loaded to the correct device

parent e35487c9
...@@ -66,13 +66,17 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -66,13 +66,17 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->loadDicts(path); getNN()->loadDicts(path);
getNN()->registerEmbeddings(); getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device); getNN()->to(torch::kCPU);
if (!train) if (!train)
{
torch::load(getNN(), getBestFilename()); torch::load(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
}
else if (std::filesystem::exists(getLastFilename())) else if (std::filesystem::exists(getLastFilename()))
{ {
torch::load(getNN(), getLastFilename()); torch::load(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
resetOptimizer(); resetOptimizer();
loadOptimizer(); loadOptimizer();
} }
......
...@@ -138,7 +138,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -138,7 +138,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{ {
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(); auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
float bestScore = std::numeric_limits<float>::min(); float bestScore = std::numeric_limits<float>::min();
std::vector<int> candidates; std::vector<int> candidates;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment