From 4cd705146235fd7bb2c0c21fa5cede5b6462c7e1 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 16 Aug 2019 16:36:22 +0200
Subject: [PATCH] Corrected multimlp

---
 neural_network/src/MultiMLP.cpp | 31 +++++++++++--------------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp
index 03d4f73..6458022 100644
--- a/neural_network/src/MultiMLP.cpp
+++ b/neural_network/src/MultiMLP.cpp
@@ -90,24 +90,20 @@ std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd)
 
 float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold)
 {
-  try
-  {
-    for (auto & mlp : mlps)
+  float loss = 0.0;
+  for (auto & mlp : mlps)
+    try
     {
       int id = std::stoi(split(mlp.name, '_')[1]);
-      float loss = 0.0;
       mlp.setBatchSize(getBatchSize());
-      loss = mlp.update(fd, id == gold ? 1 : 0);
+      loss += mlp.update(fd, id == gold ? 1 : 0);
 
       trainer->update();
-      return loss;
+    } catch (BatchNotFull &)
+    {
     }
-  } catch (BatchNotFull &)
-  {
-    return 0.0;
-  }
 
-  return 0.0;
+  return loss;
 }
 
 float MultiMLP::update(FeatureModel::FeatureDescription &, const std::vector<float> &)
@@ -120,20 +116,15 @@ float MultiMLP::update(FeatureModel::FeatureDescription &, const std::vector<flo
 float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold)
 {
   float loss = 0.0;
-  try
-  {
-    for (auto & mlp : mlps)
+  for (auto & mlp : mlps)
+    try
     {
       int id = std::stoi(split(mlp.name, '_')[1]);
       mlp.setBatchSize(getBatchSize());
       loss += mlp.getLoss(fd, id == gold ? 1 : 0);
-
-      trainer->update();
+    } catch (BatchNotFull &)
+    {
     }
-  } catch (BatchNotFull &)
-  {
-    return 0.0;
-  }
 
   return loss;
 }
-- 
GitLab