Skip to content
Snippets Groups Projects
Commit 431fab99 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a bug where embeddings module was not sent to cuda

parent d217cf58
No related branches found
No related tags found
No related merge requests found
......@@ -90,8 +90,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
else
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
this->nn->to(NeuralNetworkImpl::device);
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
{
std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
......
......@@ -28,6 +28,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size());
}
classifier->getNN()->registerEmbeddings(maxDictSize);
classifier->getNN()->to(NeuralNetworkImpl::device);
torch::load(classifier->getNN(), models[0]);
}
......
......@@ -130,6 +130,7 @@ int MacaonTrain::main()
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
}
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
machine.saveDicts();
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment