Commit b4c4cd4c authored by Franck Dary

Making sure the optimizer has access to the embeddings' parameters

parent 93b2c58c
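The motivation: a libtorch optimizer snapshots the tensors returned by parameters() at construction time, so parameters registered afterwards (here, lazily created embeddings) are silently excluded from training. The commit therefore only records the optimizer definition at parse time and defers the actual construction to resetOptimizer(), called once the network is fully assembled. A minimal sketch of the pitfall; the module and names below are illustrative, not from this repository:

  #include <torch/torch.h>
  #include <iostream>

  // Toy module that registers an embedding only after construction,
  // mimicking a network whose embeddings are created lazily.
  struct LazyNetImpl : torch::nn::Module
  {
    torch::nn::Linear out{nullptr};
    torch::nn::Embedding emb{nullptr};

    LazyNetImpl()
    {
      out = register_module("out", torch::nn::Linear(8, 2));
    }

    void addEmbeddings()
    {
      emb = register_module("emb", torch::nn::Embedding(100, 8));
    }
  };
  TORCH_MODULE(LazyNet);

  int main()
  {
    LazyNet net;

    // Constructed too early: snapshots parameters() before the embedding exists.
    torch::optim::Adam early(net->parameters(), torch::optim::AdamOptions(1e-3));

    net->addEmbeddings();

    // Constructed after the network is complete: also sees the embedding weights.
    torch::optim::Adam late(net->parameters(), torch::optim::AdamOptions(1e-3));

    // early updates 2 tensors (linear weight + bias); late updates 3.
    std::cout << early.param_groups()[0].params().size() << " vs "
              << late.param_groups()[0].params().size() << '\n';
  }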
@@ -13,6 +13,7 @@ class Classifier
   std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Adam> optimizer;
+  std::string optimizerType, optimizerParameters;
   std::string state;

   private :
@@ -27,6 +28,7 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void resetOptimizer();
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
   torch::optim::Adam & getOptimizer();
@@ -92,17 +92,8 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
   {
-    std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
-    if (sm.str(1) == "Adam")
-    {
-      auto splited = util::split(sm.str(2), ' ');
-      if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
-        util::myThrow(expected);
-      optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
-    }
-    else
-      util::myThrow(expected);
+    optimizerType = sm.str(1);
+    optimizerParameters = sm.str(2);
   }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
@@ -147,3 +138,18 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
   this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
 }
+
+void Classifier::resetOptimizer()
+{
+  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+  if (optimizerType == "Adam")
+  {
+    auto splited = util::split(optimizerParameters, ' ');
+    if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+      util::myThrow(expected);
+    optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+  }
+  else
+    util::myThrow(expected);
+}
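resetOptimizer() re-parses the stored definition, so the configuration format is unchanged. For reference, an illustrative line (the values below are made up) and the AdamOptions chain it produces:

  // Illustrative configuration line:
  //   Optimizer : Adam {0.0001 0.9 0.999 1e-08 0.0 true}
  // which resetOptimizer() turns into:
  torch::optim::Adam optim(
      nn->parameters(),
      torch::optim::AdamOptions(0.0001)   // splited[0] = lr
        .amsgrad(true)                    // splited.back() = amsgrad ("true"/"false")
        .betas({0.9, 0.999})              // splited[1], splited[2] = beta1, beta2
        .eps(1e-08)                       // splited[3] = eps
        .weight_decay(0.0));              // splited[4] = decay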
@@ -167,6 +167,7 @@ int MacaonTrain::main()
   if (!computeDevScore)
     trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
+  machine.getClassifier()->resetOptimizer();
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
   if (std::filesystem::exists(trainInfos))
     machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
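Note the order in MacaonTrain::main(): resetOptimizer() always builds a fresh optimizer over the complete parameter list, and the saved state is restored on top of it only when a previous run left a checkpoint. A standalone sketch of the same resume pattern, assuming loadOptimizer() wraps torch::load (libtorch optimizers are serializable); resumeOrStart is a hypothetical helper, not part of the repository:

  #include <torch/torch.h>
  #include <filesystem>

  // The optimizer must already exist (via resetOptimizer) so that it holds
  // the full parameter list before any saved state is restored into it.
  void resumeOrStart(torch::optim::Adam & optimizer, const std::filesystem::path & checkpoint)
  {
    if (std::filesystem::exists(checkpoint))
      torch::load(optimizer, checkpoint.string());
  }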