From a4844d76e92cd796dd2cad107b3984622aa7f3e2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 21 Feb 2020 10:46:12 +0100
Subject: [PATCH] Changed the loss to CrossEntropy, to avoid using log (which
 caused exploding gradients)

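Replacing the explicit softmax + log + nll_loss pipeline with torch::nn::CrossEntropyLoss
removes the numerically fragile log(softmax(x)) step: once the network becomes confident,
some softmax probabilities underflow to 0, log() returns -inf, and both the loss and its
gradient blow up. CrossEntropyLoss takes the raw logits and applies log_softmax internally,
which stays finite. A minimal, self-contained sketch of the difference (the tensor values
and the main() wrapper are illustrative only, not taken from this code base):

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
      // Two examples, three classes; the extreme logits push the softmax
      // probability of the first target class below float32 precision.
      auto logits = torch::tensor({200.0f, 0.0f, -200.0f,
                                     0.0f, 2.0f,    1.0f}).reshape({2, 3});
      auto labels = torch::tensor({2, 1}, torch::kLong);

      // Old scheme: softmax, then log, then nll_loss.
      // softmax(logits, 1)[0][2] underflows to 0, so log() yields -inf.
      auto unstable = torch::nll_loss(torch::log(torch::softmax(logits, 1)), labels);

      // New scheme: CrossEntropyLoss consumes raw logits and fuses
      // log_softmax with the negative log-likelihood, so the loss stays finite.
      auto lossFct = torch::nn::CrossEntropyLoss();
      auto stable = lossFct(logits, labels);

      std::cout << "log + nll_loss: " << unstable.item<float>() << "\n"
                << "cross entropy:  " << stable.item<float>() << std::endl;
    }

Because the loss now expects raw scores, the final softmax is removed from the forward()
of both networks; the argmax of the logits equals the argmax of the softmax output, so
predictions are unchanged. The Adam and SparseAdam learning rates are also raised from
2e-4 to 2e-3.
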
---
 torch_modules/src/ConcatWordsNetwork.cpp | 4 +---
 torch_modules/src/OneWordNetwork.cpp     | 2 +-
 trainer/src/Trainer.cpp                  | 8 +++++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index 4961915..4c1e366 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -35,8 +35,6 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
 
-  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
-
-  return res;
+  return linear2(torch::relu(linear1(reshaped)));
 }
 
diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
index 1d9b386..c054e6d 100644
--- a/torch_modules/src/OneWordNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -45,7 +45,7 @@ torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
   if (reshaped.dim() == 3)
     reshaped = wordsAsEmb.permute({1,0,2});
 
-  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
+  auto res = linear(reshaped[focusedIndex]);
 
   return res;
 }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 94d23d2..a439c88 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5))); 
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); 
 }
 
 float Trainer::epoch(bool printAdvancement)
@@ -70,6 +70,8 @@ float Trainer::epoch(bool printAdvancement)
   int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;
 
+  auto lossFct = torch::nn::CrossEntropyLoss();
+
   for (auto & batch : *dataLoader)
   {
     denseOptimizer->zero_grad();
@@ -80,7 +82,7 @@ float Trainer::epoch(bool printAdvancement)
 
     auto prediction = machine.getClassifier()->getNN()(data);
 
-    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    auto loss = lossFct(prediction, labels);
     try
     {
       totalLoss += loss.item<float>();
-- 
GitLab