diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index 49619151bac810cdd90de910312caf6aede85166..4c1e3661d36451c82fe7cbb6f77d3ad397bafdc3 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -35,8 +35,7 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
 
-  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
-
-  return res;
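+  // return raw logits, the softmax is now handled by the loss function at training time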
+  return linear2(torch::relu(linear1(reshaped)));
 }
 
diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
index 1d9b3869304f48196a610ac085af0c603d69c764..c054e6dd8c5d9a5810a57a448ad318030164efa2 100644
--- a/torch_modules/src/OneWordNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -45,7 +45,8 @@ torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
   if (reshaped.dim() == 3)
     reshaped = wordsAsEmb.permute({1,0,2});
 
-  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
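+  // raw logits for the focused word, softmax is left to the loss function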
+  auto res = linear(reshaped[focusedIndex]);
 
   return res;
 }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 94d23d2aa9f49f9eba40694e133bf41ba48bd185..a439c885351936359718d5b96ed0e67c9f6fb201 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5))); 
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
 
 float Trainer::epoch(bool printAdvancement)
@@ -70,6 +70,8 @@ float Trainer::epoch(bool printAdvancement)
   int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;
 
+  auto lossFct = torch::nn::CrossEntropyLoss();
+
   for (auto & batch : *dataLoader)
   {
     denseOptimizer->zero_grad();
@@ -80,7 +82,8 @@ float Trainer::epoch(bool printAdvancement)
 
     auto prediction = machine.getClassifier()->getNN()(data);
 
-    auto loss = torch::nll_loss(torch::log(prediction), labels);
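+    // prediction holds raw logits, CrossEntropyLoss applies log_softmax and nll_loss internally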
+    auto loss = lossFct(prediction, labels);
     try
     {
       totalLoss += loss.item<float>();