From 3308c40622e7877026c085bdc05a8a5804643103 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 21 Aug 2019 11:12:55 +0200 Subject: [PATCH] Improved multiMLP --- neural_network/src/MultiMLP.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp index 9c01f6b..de188aa 100644 --- a/neural_network/src/MultiMLP.cpp +++ b/neural_network/src/MultiMLP.cpp @@ -1,4 +1,5 @@ #include "MultiMLP.hpp" +#include <cmath> MultiMLP::MultiMLP() { @@ -77,14 +78,19 @@ dynet::Trainer * MultiMLP::createTrainer() std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd) { + double totalSum = 0.0; std::vector<float> prediction(mlps.size()); for (unsigned int i = 0; i < mlps.size(); i++) { int id = std::stoi(split(mlps[i].name, '_')[1]); auto value = mlps[i].predict(fd); - prediction[id] = value[1]; + prediction[id] = exp(value[1]); + totalSum += prediction[id]; } + for (unsigned int i = 0; i < prediction.size(); i++) + prediction[i] /= totalSum; + return prediction; } @@ -98,10 +104,8 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold) mlp.setBatchSize(getBatchSize()); if (gold >= 0) loss += mlp.update(fd, id == gold ? 1 : 0); - else if (id == (-1-gold)) - loss += mlp.update(fd, 0); else - continue; + loss += mlp.update(fd, id == (-1-gold) ? 0 : 1); trainer->update(); } catch (BatchNotFull &) @@ -128,10 +132,8 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold) mlp.setBatchSize(getBatchSize()); if (gold >= 0) loss += mlp.update(fd, id == gold ? 1 : 0); - else if (id == (-1-gold)) - loss += mlp.update(fd, 0); else - continue; + loss += mlp.update(fd, id == (-1-gold) ? 0 : 1); } catch (BatchNotFull &) { } -- GitLab