diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp index 9c01f6be2deffeec4df8c3554e9f47bb83dd2d6b..de188aa40935f79a3ca9f2ed2ed16b4c9755cae0 100644 --- a/neural_network/src/MultiMLP.cpp +++ b/neural_network/src/MultiMLP.cpp @@ -1,4 +1,5 @@ #include "MultiMLP.hpp" +#include <cmath> MultiMLP::MultiMLP() { @@ -77,14 +78,19 @@ dynet::Trainer * MultiMLP::createTrainer() std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd) { + double totalSum = 0.0; std::vector<float> prediction(mlps.size()); for (unsigned int i = 0; i < mlps.size(); i++) { int id = std::stoi(split(mlps[i].name, '_')[1]); auto value = mlps[i].predict(fd); - prediction[id] = value[1]; + prediction[id] = exp(value[1]); + totalSum += prediction[id]; } + for (unsigned int i = 0; i < prediction.size(); i++) + prediction[i] /= totalSum; + return prediction; } @@ -98,10 +104,8 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold) mlp.setBatchSize(getBatchSize()); if (gold >= 0) loss += mlp.update(fd, id == gold ? 1 : 0); - else if (id == (-1-gold)) - loss += mlp.update(fd, 0); else - continue; + loss += mlp.update(fd, id == (-1-gold) ? 0 : 1); trainer->update(); } catch (BatchNotFull &) @@ -128,10 +132,8 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold) mlp.setBatchSize(getBatchSize()); if (gold >= 0) loss += mlp.update(fd, id == gold ? 1 : 0); - else if (id == (-1-gold)) - loss += mlp.update(fd, 0); else - continue; + loss += mlp.update(fd, id == (-1-gold) ? 0 : 1); } catch (BatchNotFull &) { }