Commit e350f836 authored by Franck Dary

Optimizer is now a member of Classifier, and its parameters can be given in the rm file

parent 979104ff
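The change below teaches Classifier::initNeuralNetwork to read an optional optimizer line from the rm file: an optional "Optimizer :" prefix, the optimizer name, and six space-separated values {lr beta1 beta2 eps decay amsgrad}. A plausible example line, reusing the learning rate, betas, and amsgrad flag that were hard-coded in Trainer before this commit (the eps and weight-decay values here are illustrative assumptions, not project defaults):

Optimizer : Adam {0.0005 0.9 0.999 1e-08 0.0 true}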
@@ -12,6 +12,7 @@ class Classifier
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
   std::shared_ptr<NeuralNetworkImpl> nn;
+  std::unique_ptr<torch::optim::Adam> optimizer;

   private :
@@ -25,6 +26,9 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void loadOptimizer(std::filesystem::path path);
+  void saveOptimizer(std::filesystem::path path);
+  torch::optim::Adam & getOptimizer();
 };

 #endif
@@ -65,6 +65,22 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));

   this->nn->to(NeuralNetworkImpl::device);
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
+    {
+      std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+      if (sm.str(1) == "Adam")
+      {
+        auto splited = util::split(sm.str(2), ' ');
+        if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+          util::myThrow(expected);
+        optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+      }
+      else
+        util::myThrow(expected);
+    }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }

 void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
@@ -272,3 +288,19 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
   this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
 }
+
+void Classifier::loadOptimizer(std::filesystem::path path)
+{
+  torch::load(*optimizer, path);
+}
+
+void Classifier::saveOptimizer(std::filesystem::path path)
+{
+  torch::save(*optimizer, path);
+}
+
+torch::optim::Adam & Classifier::getOptimizer()
+{
+  return *optimizer;
+}
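For reference, a minimal standalone sketch of what the code above wires together: constructing Adam from an options chain and round-tripping its state with torch::save/torch::load, which is all loadOptimizer and saveOptimizer do. Assumptions: a libtorch version contemporary with this commit (where AdamOptions still exposes beta1()/beta2() rather than betas()), and a toy Linear module standing in for the project's NeuralNetworkImpl.

#include <torch/torch.h>

int main()
{
  // Stand-in network; the real code passes getNN()->parameters().
  auto net = torch::nn::Linear(4, 2);
  torch::optim::Adam optimizer(net->parameters(),
      torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999));

  // One training step so the optimizer accumulates per-parameter state
  // (Adam's first and second moment estimates).
  auto loss = net(torch::randn({8, 4})).sum();
  optimizer.zero_grad();
  loss.backward();
  optimizer.step();

  // Serializing that state is what makes resuming from a checkpoint
  // behave like an uninterrupted run.
  torch::save(optimizer, "optimizer.pt");
  torch::load(optimizer, "optimizer.pt");
}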
@@ -19,7 +19,6 @@ class Trainer
   std::unique_ptr<Dataset> devDataset{nullptr};
   DataLoader dataLoader{nullptr};
   DataLoader devDataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
@@ -36,8 +35,6 @@ class Trainer
   void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
 };

 #endif
@@ -157,7 +157,7 @@ int MacaonTrain::main()
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
   if (std::filesystem::exists(trainInfos))
-    trainer.loadOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
@@ -204,7 +204,7 @@ int MacaonTrain::main()
     machine.saveBest();
   }
   machine.saveLast();
-  trainer.saveOptimizer(optimizerCheckpoint);
+  machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
   if (printAdvancement)
     fmt::print(stderr, "\r{:80}\r", "");
   std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
@@ -14,9 +14,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
   trainDataset.reset(new Dataset(dir));
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  if (optimizer.get() == nullptr)
-    optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
 }

 void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
@@ -184,7 +181,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement
   for (auto & batch : *loader)
   {
     if (train)
-      optimizer->zero_grad();
+      machine.getClassifier()->getOptimizer().zero_grad();
     auto data = batch.first;
     auto labels = batch.second;
@@ -205,7 +202,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement
     if (train)
     {
       loss.backward();
-      optimizer->step();
+      machine.getClassifier()->getOptimizer().step();
     }
     totalNbExamplesProcessed += torch::numel(labels);
@@ -245,13 +242,3 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
-
-void Trainer::loadOptimizer(std::filesystem::path path)
-{
-  torch::load(*optimizer, path);
-}
-
-void Trainer::saveOptimizer(std::filesystem::path path)
-{
-  torch::save(*optimizer, path);
-}