From b0bb4445adde77ead17053ad53950aa3bd4dc0ff Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 25 Feb 2020 20:36:21 +0100
Subject: [PATCH] Enable AMSGrad in the default Adam optimizer

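Enable the AMSGrad variant when constructing the default Adam optimizer.
AMSGrad (Reddi et al., "On the Convergence of Adam and Beyond", ICLR 2018)
normalizes updates by the running maximum of the second-moment estimate
rather than the estimate itself, which fixes a potential non-convergence
issue of plain Adam.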
---
 trainer/include/Trainer.hpp | 1 -
 trainer/src/Trainer.cpp     | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 177f5f2..6e88917 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -26,7 +26,6 @@ class Trainer
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
-
 };
 
 #endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index afbf965..590504e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,7 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).beta1(0.9).beta2(0.999)));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
 }
 
 float Trainer::epoch(bool printAdvancement)
-- 
GitLab
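
For reference, a minimal standalone sketch of the optimizer setup this patch
arrives at. It assumes the pre-1.5 libtorch API that Trainer.cpp already uses
(the beta1()/beta2() setters; from release 1.5 onward these become
.betas(std::make_tuple(0.9, 0.999))), and a toy torch::nn::Linear stands in
for machine.getClassifier()->getNN():

#include <torch/torch.h>

int main()
{
  // Toy model standing in for the machine's classifier network.
  torch::nn::Linear model(10, 2);

  // Adam with AMSGrad enabled, mirroring the options set in the patch.
  torch::optim::Adam optimizer(
      model->parameters(),
      torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999));

  // One dummy optimization step to show the optimizer in use.
  auto input = torch::randn({4, 10});
  auto target = torch::randint(0, 2, {4}, torch::kLong);

  optimizer.zero_grad();
  auto loss = torch::nll_loss(torch::log_softmax(model(input), 1), target);
  loss.backward();
  optimizer.step();

  return 0;
}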