Skip to content
Snippets Groups Projects
Commit 3308c406 authored by Franck Dary's avatar Franck Dary
Browse files

Improved multiMLP

parent 6eb7d405
Branches
No related tags found
No related merge requests found
#include "MultiMLP.hpp"
#include <cmath>
MultiMLP::MultiMLP()
{
......@@ -77,14 +78,19 @@ dynet::Trainer * MultiMLP::createTrainer()
std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd)
{
double totalSum = 0.0;
std::vector<float> prediction(mlps.size());
for (unsigned int i = 0; i < mlps.size(); i++)
{
int id = std::stoi(split(mlps[i].name, '_')[1]);
auto value = mlps[i].predict(fd);
prediction[id] = value[1];
prediction[id] = exp(value[1]);
totalSum += prediction[id];
}
for (unsigned int i = 0; i < prediction.size(); i++)
prediction[i] /= totalSum;
return prediction;
}
......@@ -98,10 +104,8 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold)
mlp.setBatchSize(getBatchSize());
if (gold >= 0)
loss += mlp.update(fd, id == gold ? 1 : 0);
else if (id == (-1-gold))
loss += mlp.update(fd, 0);
else
continue;
loss += mlp.update(fd, id == (-1-gold) ? 0 : 1);
trainer->update();
} catch (BatchNotFull &)
......@@ -128,10 +132,8 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold)
mlp.setBatchSize(getBatchSize());
if (gold >= 0)
loss += mlp.update(fd, id == gold ? 1 : 0);
else if (id == (-1-gold))
loss += mlp.update(fd, 0);
else
continue;
loss += mlp.update(fd, id == (-1-gold) ? 0 : 1);
} catch (BatchNotFull &)
{
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment