From 6eb7d405c423778851b0939f65a30080e0e86931 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 19 Aug 2019 13:14:39 +0200
Subject: [PATCH] Improved multiMLP

---
 neural_network/src/MultiMLP.cpp | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp
index 6458022..9c01f6b 100644
--- a/neural_network/src/MultiMLP.cpp
+++ b/neural_network/src/MultiMLP.cpp
@@ -96,7 +96,12 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold)
     {
       int id = std::stoi(split(mlp.name, '_')[1]);
       mlp.setBatchSize(getBatchSize());
-      loss += mlp.update(fd, id == gold ? 1 : 0);
+      if (gold >= 0)
+        loss += mlp.update(fd, id == gold ? 1 : 0);
+      else if (id == (-1-gold))
+        loss += mlp.update(fd, 0);
+      else
+        continue;
 
       trainer->update();
     } catch (BatchNotFull &)
@@ -121,7 +126,12 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold)
     {
       int id = std::stoi(split(mlp.name, '_')[1]);
       mlp.setBatchSize(getBatchSize());
-      loss += mlp.getLoss(fd, id == gold ? 1 : 0);
+      if (gold >= 0)
+        loss += mlp.update(fd, id == gold ? 1 : 0);
+      else if (id == (-1-gold))
+        loss += mlp.update(fd, 0);
+      else
+        continue;
     } catch (BatchNotFull &)
     {
     }
-- 
GitLab