From 6eb7d405c423778851b0939f65a30080e0e86931 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 19 Aug 2019 13:14:39 +0200 Subject: [PATCH] Improved multiMLP --- neural_network/src/MultiMLP.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp index 6458022..9c01f6b 100644 --- a/neural_network/src/MultiMLP.cpp +++ b/neural_network/src/MultiMLP.cpp @@ -96,7 +96,12 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold) { int id = std::stoi(split(mlp.name, '_')[1]); mlp.setBatchSize(getBatchSize()); - loss += mlp.update(fd, id == gold ? 1 : 0); + if (gold >= 0) + loss += mlp.update(fd, id == gold ? 1 : 0); + else if (id == (-1-gold)) + loss += mlp.update(fd, 0); + else + continue; trainer->update(); } catch (BatchNotFull &) @@ -121,7 +126,12 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold) { int id = std::stoi(split(mlp.name, '_')[1]); mlp.setBatchSize(getBatchSize()); - loss += mlp.getLoss(fd, id == gold ? 1 : 0); + if (gold >= 0) + loss += mlp.update(fd, id == gold ? 1 : 0); + else if (id == (-1-gold)) + loss += mlp.update(fd, 0); + else + continue; } catch (BatchNotFull &) { } -- GitLab