diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 177f5f28e4feb0ecdeafe3a64bfffa0b11a63c5b..6e889171c5f1fa461f333605446fb8545d287270 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -26,7 +26,6 @@ class Trainer Trainer(ReadingMachine & machine); void createDataset(SubConfig & goldConfig, bool debug); float epoch(bool printAdvancement); - }; #endif diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index afbf9652f1b0152d73741b4fa341cf725a8e7fc1..590504ef6fc8c75d31a6e597188c9648e099ee96 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -58,7 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).beta1(0.9).beta2(0.999))); + optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999))); } float Trainer::epoch(bool printAdvancement)