From a4844d76e92cd796dd2cad107b3984622aa7f3e2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 21 Feb 2020 10:46:12 +0100
Subject: [PATCH] Changed the loss to CrossEntropy, to avoid using log (caused
 exploding gradient)

---
 torch_modules/src/ConcatWordsNetwork.cpp | 4 +---
 torch_modules/src/OneWordNetwork.cpp     | 2 +-
 trainer/src/Trainer.cpp                  | 8 +++++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index 4961915..4c1e366 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -35,8 +35,6 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});

-  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
-
-  return res;
+  return linear2(torch::relu(linear1(reshaped)));
 }

diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
index 1d9b386..c054e6d 100644
--- a/torch_modules/src/OneWordNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -45,7 +45,7 @@ torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
   if (reshaped.dim() == 3)
     reshaped = wordsAsEmb.permute({1,0,2});

-  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
+  auto res = linear(reshaped[focusedIndex]);

   return res;
 }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 94d23d2..a439c88 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)

   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5)));
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }

 float Trainer::epoch(bool printAdvancement)
@@ -70,6 +70,8 @@ float Trainer::epoch(bool printAdvancement)
   int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;

+  auto lossFct = torch::nn::CrossEntropyLoss();
+
   for (auto & batch : *dataLoader)
   {
     denseOptimizer->zero_grad();
@@ -80,7 +82,7 @@ float Trainer::epoch(bool printAdvancement)

     auto prediction = machine.getClassifier()->getNN()(data);

-    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    auto loss = lossFct(prediction, labels);
     try
     {
       totalLoss += loss.item<float>();
--
GitLab
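
Note (illustrative, not part of the patch): the standalone sketch below shows why the replaced torch::nll_loss(torch::log(prediction), labels) formulation was fragile. When the network is very confident and wrong, the softmax probability of the correct class underflows to 0, log(0) is -inf, and the loss and its gradient blow up; torch::nn::CrossEntropyLoss fuses log_softmax and nll_loss (log-sum-exp trick), so the same scores give a finite loss. It assumes a libtorch version that provides torch::nn::CrossEntropyLoss, which the patch itself already relies on; the tensor values are made up for the demonstration.

#include <torch/torch.h>
#include <iostream>

int main()
{
  // One example whose raw scores are extremely confident for class 1,
  // while the correct label is class 0.
  auto logits = torch::tensor({0.0f, 200.0f}).unsqueeze(0); // shape {1, 2}
  auto labels = torch::tensor({0}, torch::kLong);

  // Old formulation: softmax followed by an explicit log.
  // softmax(logits)[0][0] underflows to 0 in float, so log(0) = -inf
  // and the resulting loss is infinite (the gradient is not finite either).
  auto oldLoss = torch::nll_loss(torch::log(torch::softmax(logits, 1)), labels);

  // New formulation: CrossEntropyLoss works directly on the raw scores
  // and stays finite (about 200 for this example).
  auto lossFct = torch::nn::CrossEntropyLoss();
  auto newLoss = lossFct(logits, labels);

  std::cout << "old: " << oldLoss.item<float>()
            << "  new: " << newLoss.item<float>() << std::endl;
}

The same reasoning is why the softmax calls were dropped from the forward() of ConcatWordsNetwork and OneWordNetwork: CrossEntropyLoss applies log_softmax internally, so the networks now return raw scores and are not normalized twice.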