Commit b4c4cd4c authored by Franck Dary

Making sure the optimizer has access to embedding parameters

parent 93b2c58c
@@ -13,6 +13,7 @@ class Classifier
   std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Adam> optimizer;
+  std::string optimizerType, optimizerParameters;
   std::string state;
   private :
@@ -27,6 +28,7 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void resetOptimizer();
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
   torch::optim::Adam & getOptimizer();
...
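The header change above defers optimizer construction: instead of building the Adam instance while parsing the configuration, the Classifier now stores the raw optimizer type and parameter string, and resetOptimizer() rebuilds the optimizer on demand. The point, per the commit message, is that an optimizer only updates the parameters it was handed at construction, so one built before the embeddings exist silently skips them. A minimal libtorch sketch of that failure mode, with made-up module names and sizes:

#include <torch/torch.h>

int main()
{
  auto linear = torch::nn::Linear(4, 2);
  // Optimizer built early: it only ever sees linear's parameters.
  torch::optim::Adam early(linear->parameters(), torch::optim::AdamOptions(1e-3));

  // Parameters created after this point (e.g. embeddings allocated lazily,
  // on first use) are invisible to `early` and never receive updates.
  auto embeddings = torch::nn::Embedding(100, 8);

  // Rebuilding the optimizer over the full, current parameter list fixes
  // this, which is what Classifier::resetOptimizer() does for the network.
  std::vector<torch::Tensor> params = linear->parameters();
  auto embParams = embeddings->parameters();
  params.insert(params.end(), embParams.begin(), embParams.end());
  torch::optim::Adam rebuilt(params, torch::optim::AdamOptions(1e-3));
}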
@@ -92,17 +92,8 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
   {
-    std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
-    if (sm.str(1) == "Adam")
-    {
-      auto splited = util::split(sm.str(2), ' ');
-      if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
-        util::myThrow(expected);
-      optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
-    }
-    else
-      util::myThrow(expected);
+    optimizerType = sm.str(1);
+    optimizerParameters = sm.str(2);
   }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
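For reference, the regex above accepts an optional "Optimizer :" prefix and captures the optimizer name in group 1 and the brace-enclosed parameter string in group 2, which the new code simply stashes away for later. A standalone sketch with a made-up configuration line, using std::regex directly (the repository goes through util::doIfNameMatch, whose exact matching semantics I am assuming here):

#include <iostream>
#include <regex>
#include <string>

int main()
{
  // Hypothetical configuration line; the values are illustrative only.
  std::string line = "Optimizer : Adam {0.001 0.9 0.999 0.00000001 0.0 true}";
  std::regex re("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}");
  std::smatch sm;
  if (std::regex_match(line, sm, re))
  {
    std::cout << sm.str(1) << "\n";  // "Adam" -> optimizerType
    std::cout << sm.str(2) << "\n";  // "0.001 0.9 0.999 0.00000001 0.0 true" -> optimizerParameters
  }
}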
@@ -147,3 +138,18 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
   this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
 }
+
+void Classifier::resetOptimizer()
+{
+  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+  if (optimizerType == "Adam")
+  {
+    auto splited = util::split(optimizerParameters, ' ');
+    if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+      util::myThrow(expected);
+    optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+  }
+  else
+    util::myThrow(expected);
+}
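resetOptimizer() keeps the exact construction logic that used to live in the parsing lambda: the six whitespace-separated tokens map, in order, onto learning rate, beta1, beta2, eps, weight decay, and the amsgrad flag. A self-contained sketch of that mapping with made-up values, using std::istringstream in place of the repository's util::split:

#include <torch/torch.h>
#include <sstream>
#include <string>
#include <vector>

int main()
{
  // Illustrative parameter string: "{lr beta1 beta2 eps decay amsgrad}".
  std::string optimizerParameters = "0.001 0.9 0.999 0.00000001 0.0 false";
  std::vector<std::string> tok;
  std::istringstream iss(optimizerParameters);
  for (std::string t; iss >> t;)
    tok.push_back(t);

  auto options = torch::optim::AdamOptions(std::stof(tok[0]))  // lr
    .betas({std::stof(tok[1]), std::stof(tok[2])})             // beta1, beta2
    .eps(std::stof(tok[3]))
    .weight_decay(std::stof(tok[4]))
    .amsgrad(tok[5] == "true");
  (void)options;
}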
@@ -167,6 +167,7 @@ int MacaonTrain::main()
   if (!computeDevScore)
     trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
+  machine.getClassifier()->resetOptimizer();
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
   if (std::filesystem::exists(trainInfos))
     machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
...
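The call site clarifies the ordering: resetOptimizer() first rebuilds a fresh Adam over the network's current parameter list, embeddings included, and only then, when resuming training, does loadOptimizer() restore the saved state into it. A plausible sketch of the load/save pair using libtorch serialization; these implementations are not part of this diff, so treat the bodies as an assumption:

// Assumed implementations, consistent with the calls above.
void Classifier::loadOptimizer(std::filesystem::path path)
{
  // Requires `optimizer` to already exist, hence resetOptimizer() runs first.
  torch::load(*optimizer, path.string());
}

void Classifier::saveOptimizer(std::filesystem::path path)
{
  torch::save(*optimizer, path.string());
}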