Skip to content
Snippets Groups Projects
Commit 3308c406 authored by Franck Dary's avatar Franck Dary
Browse files

Improved multiMLP

parent 6eb7d405
No related branches found
No related tags found
No related merge requests found
#include "MultiMLP.hpp" #include "MultiMLP.hpp"
#include <cmath>
MultiMLP::MultiMLP() MultiMLP::MultiMLP()
{ {
...@@ -77,14 +78,19 @@ dynet::Trainer * MultiMLP::createTrainer() ...@@ -77,14 +78,19 @@ dynet::Trainer * MultiMLP::createTrainer()
std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd) std::vector<float> MultiMLP::predict(FeatureModel::FeatureDescription & fd)
{ {
double totalSum = 0.0;
std::vector<float> prediction(mlps.size()); std::vector<float> prediction(mlps.size());
for (unsigned int i = 0; i < mlps.size(); i++) for (unsigned int i = 0; i < mlps.size(); i++)
{ {
int id = std::stoi(split(mlps[i].name, '_')[1]); int id = std::stoi(split(mlps[i].name, '_')[1]);
auto value = mlps[i].predict(fd); auto value = mlps[i].predict(fd);
prediction[id] = value[1]; prediction[id] = exp(value[1]);
totalSum += prediction[id];
} }
for (unsigned int i = 0; i < prediction.size(); i++)
prediction[i] /= totalSum;
return prediction; return prediction;
} }
...@@ -98,10 +104,8 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold) ...@@ -98,10 +104,8 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold)
mlp.setBatchSize(getBatchSize()); mlp.setBatchSize(getBatchSize());
if (gold >= 0) if (gold >= 0)
loss += mlp.update(fd, id == gold ? 1 : 0); loss += mlp.update(fd, id == gold ? 1 : 0);
else if (id == (-1-gold))
loss += mlp.update(fd, 0);
else else
continue; loss += mlp.update(fd, id == (-1-gold) ? 0 : 1);
trainer->update(); trainer->update();
} catch (BatchNotFull &) } catch (BatchNotFull &)
...@@ -128,10 +132,8 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold) ...@@ -128,10 +132,8 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold)
mlp.setBatchSize(getBatchSize()); mlp.setBatchSize(getBatchSize());
if (gold >= 0) if (gold >= 0)
loss += mlp.update(fd, id == gold ? 1 : 0); loss += mlp.update(fd, id == gold ? 1 : 0);
else if (id == (-1-gold))
loss += mlp.update(fd, 0);
else else
continue; loss += mlp.update(fd, id == (-1-gold) ? 0 : 1);
} catch (BatchNotFull &) } catch (BatchNotFull &)
{ {
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment