Commit e350f836 authored by Franck Dary

Optimizer is now a member of Classifier, and its parameters can be given in the rm file

parent 979104ff
@@ -12,6 +12,7 @@ class Classifier
std::string name;
std::unique_ptr<TransitionSet> transitionSet;
std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Adam> optimizer;
private :
@@ -25,6 +26,9 @@ class Classifier
NeuralNetwork & getNN();
const std::string & getName() const;
int getNbParameters() const;
void loadOptimizer(std::filesystem::path path);
void saveOptimizer(std::filesystem::path path);
torch::optim::Adam & getOptimizer();
};
#endif
@@ -65,6 +65,22 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
this->nn->to(NeuralNetworkImpl::device);
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
{
std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
if (sm.str(1) == "Adam")
{
auto splited = util::split(sm.str(2), ' ');
if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
util::myThrow(expected);
optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
}
else
util::myThrow(expected);
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
}
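For reference, the regex above accepts an optional "Optimizer :" prefix and requires exactly six whitespace-separated values inside the braces, in the order lr beta1 beta2 eps decay amsgrad, with the last one being "true" or "false". A hypothetical rm file line in that format, reusing the learning rate, betas and amsgrad flag previously hard-coded in Trainer (the eps and decay values here are illustrative only), could look like:

Optimizer : Adam {0.0005 0.9 0.999 1e-8 0.0 true}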
void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
@@ -272,3 +288,19 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
}
void Classifier::loadOptimizer(std::filesystem::path path)
{
torch::load(*optimizer, path);
}
void Classifier::saveOptimizer(std::filesystem::path path)
{
torch::save(*optimizer, path);
}
torch::optim::Adam & Classifier::getOptimizer()
{
return *optimizer;
}
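The new loadOptimizer and saveOptimizer members are thin wrappers around torch::load and torch::save. A minimal standalone sketch of that mechanism, independent of this code base (the toy model, file name and hyperparameters are illustrative, not taken from the commit):

#include <torch/torch.h>

int main()
{
  // Toy module whose parameters the optimizer will track.
  torch::nn::Linear model(4, 2);
  // Adam configured like the former Trainer default (lr 0.0005, amsgrad on).
  torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true));

  torch::save(optimizer, "optimizer.pt"); // what saveOptimizer does with its path argument
  torch::load(optimizer, "optimizer.pt"); // what loadOptimizer does to resume training
}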
@@ -19,7 +19,6 @@ class Trainer
std::unique_ptr<Dataset> devDataset{nullptr};
DataLoader dataLoader{nullptr};
DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0};
int batchSize;
@@ -36,8 +35,6 @@ class Trainer
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
void loadOptimizer(std::filesystem::path path);
void saveOptimizer(std::filesystem::path path);
};
#endif
@@ -157,7 +157,7 @@ int MacaonTrain::main()
auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
if (std::filesystem::exists(trainInfos))
trainer.loadOptimizer(optimizerCheckpoint);
machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
for (; currentEpoch < nbEpoch; currentEpoch++)
{
@@ -204,7 +204,7 @@ int MacaonTrain::main()
machine.saveBest();
}
machine.saveLast();
trainer.saveOptimizer(optimizerCheckpoint);
machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
if (printAdvancement)
fmt::print(stderr, "\r{:80}\r", "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
@@ -14,9 +14,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
trainDataset.reset(new Dataset(dir));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
if (optimizer.get() == nullptr)
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
}
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
@@ -184,7 +181,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
for (auto & batch : *loader)
{
if (train)
optimizer->zero_grad();
machine.getClassifier()->getOptimizer().zero_grad();
auto data = batch.first;
auto labels = batch.second;
@@ -205,7 +202,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
if (train)
{
loss.backward();
optimizer->step();
machine.getClassifier()->getOptimizer().step();
}
totalNbExamplesProcessed += torch::numel(labels);
@@ -245,13 +242,3 @@ float Trainer::evalOnDev(bool printAdvancement)
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
void Trainer::loadOptimizer(std::filesystem::path path)
{
torch::load(*optimizer, path);
}
void Trainer::saveOptimizer(std::filesystem::path path)
{
torch::save(*optimizer, path);
}