Skip to content
Snippets Groups Projects
Commit c014eaee authored by Franck Dary's avatar Franck Dary
Browse files

Changed score computation during training

parent c0935734
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@
#include <string>
#include <vector>
#include "ProgramParameters.hpp"
#include "Config.hpp"
class TrainInfos
{
......@@ -23,10 +24,8 @@ class TrainInfos
std::map< std::string, std::vector<float> > devScoresPerClassifierPerEpoch;
std::map< std::string, std::vector<bool> > mustSavePerClassifierPerEpoch;
std::map< std::string, std::pair<int,int> > trainCounter;
std::map<std::string, float> trainLossCounter;
std::map<std::string, float> devLossCounter;
std::map< std::string, std::pair<int,int> > devCounter;
std::map<std::string, bool> topologyPrinted;
......@@ -34,8 +33,6 @@ class TrainInfos
void readFromFilename();
void saveToFilename();
float computeTrainScore(const std::string & classifier);
float computeDevScore(const std::string & classifier);
void addTrainScore(const std::string & classifier, float score);
void addDevScore(const std::string & classifier, float score);
......@@ -44,14 +41,8 @@ class TrainInfos
TrainInfos();
void addTrainLoss(const std::string & classifier, float loss);
void addDevLoss(const std::string & classifier, float loss);
void addTrainExample(const std::string & classifier, float loss);
void addDevExample(const std::string & classifier);
void addDevExample(const std::string & classifier, float loss);
void addTrainSuccess(const std::string & classifier);
void addDevSuccess(const std::string & classifier);
void resetCounters();
void computeTrainScores();
void computeDevScores();
void computeTrainScores(Config & c);
void computeDevScores(Config & c);
void computeMustSaves();
int getEpoch();
bool isTopologyPrinted(const std::string & classifier);
......
......@@ -9,10 +9,8 @@ TrainInfos::TrainInfos()
lastSaved = 0;
if (fileExists(filename))
{
readFromFilename();
}
}
void TrainInfos::readFromFilename()
{
......@@ -127,7 +125,7 @@ void TrainInfos::saveToFilename()
/// @brief Accumulate a training loss for the given classifier into the
/// current epoch's slot.
///
/// The per-epoch slot is appended by setTopologyPrinted() (first epoch) and
/// nextEpoch() (subsequent epochs), so this function must only add into the
/// existing back() entry. The original text both appended a new entry AND
/// added into it (a diff-overlay artifact), which would record each loss
/// twice; only the accumulate form is kept.
///
/// @param classifier The name of the classifier the loss belongs to.
/// @param loss The loss value to add to the current epoch's total.
void TrainInfos::addTrainLoss(const std::string & classifier, float loss)
{
  // NOTE(review): assumes setTopologyPrinted() ran for this classifier first,
  // so the vector is non-empty — back() on an empty vector is UB. Confirm
  // against callers in Trainer::doStepTrain.
  trainLossesPerClassifierPerEpoch[classifier].back() += loss;
}
void TrainInfos::addDevLoss(const std::string & classifier, float loss)
......@@ -145,68 +143,30 @@ void TrainInfos::addDevScore(const std::string & classifier, float score)
devScoresPerClassifierPerEpoch[classifier].emplace_back(score);
}
float TrainInfos::computeTrainScore(const std::string & classifier)
void TrainInfos::computeTrainScores(Config & c)
{
return 100.0*trainCounter[classifier].first / trainCounter[classifier].second;
}
float TrainInfos::computeDevScore(const std::string & classifier)
for (auto & it : topologyPrinted)
{
return 100.0*devCounter[classifier].first / devCounter[classifier].second;
}
void TrainInfos::addTrainExample(const std::string & classifier, float loss)
if (it.first == "Parser")
{
trainCounter[classifier].second++;
trainLossCounter[classifier] += loss;
float govScore = c.getTape("GOV").getScore();
float labelScore = c.getTape("LABEL").getScore();
float score = (govScore + labelScore) / 2;
addTrainScore(it.first, score);
}
void TrainInfos::addDevExample(const std::string & classifier)
{
devCounter[classifier].second++;
}
void TrainInfos::addDevExample(const std::string & classifier, float loss)
{
devCounter[classifier].second++;
devLossCounter[classifier] += loss;
}
void TrainInfos::addTrainSuccess(const std::string & classifier)
void TrainInfos::computeDevScores(Config & c)
{
trainCounter[classifier].first++;
}
void TrainInfos::addDevSuccess(const std::string & classifier)
for (auto & it : topologyPrinted)
{
devCounter[classifier].first++;
}
void TrainInfos::resetCounters()
{
trainCounter.clear();
devCounter.clear();
}
void TrainInfos::computeTrainScores()
{
for (auto & it : trainCounter)
{
addTrainScore(it.first, computeTrainScore(it.first));
addTrainLoss(it.first, trainLossCounter[it.first]);
trainLossCounter[it.first] = 0.0;
}
}
void TrainInfos::computeDevScores()
if (it.first == "Parser")
{
for (auto & it : devCounter)
{
addDevScore(it.first, computeDevScore(it.first));
if (devLossCounter.count(it.first))
{
addDevLoss(it.first, devLossCounter[it.first]);
devLossCounter[it.first] = 0.0;
float govScore = c.getTape("GOV").getScore();
float labelScore = c.getTape("LABEL").getScore();
float score = (govScore + labelScore) / 2;
addDevScore(it.first, score);
}
}
}
......@@ -223,14 +183,18 @@ bool TrainInfos::isTopologyPrinted(const std::string & classifier)
/// @brief Record that the topology of this classifier has been printed, and
/// open its first per-epoch loss slot(s).
///
/// A zero-initialized entry is pushed so that addTrainLoss()/addDevLoss() can
/// accumulate into back(); the dev slot is only created when dev-loss tracking
/// is enabled.
///
/// @param classifier The name of the classifier concerned.
void TrainInfos::setTopologyPrinted(const std::string & classifier)
{
  topologyPrinted[classifier] = true;
  trainLossesPerClassifierPerEpoch[classifier].push_back(0.0f);

  if (!ProgramParameters::devLoss)
    return;

  devLossesPerClassifierPerEpoch[classifier].push_back(0.0f);
}
/// @brief Advance to the next training epoch.
///
/// Persists the current state to disk, then opens a fresh zero-valued
/// training-loss slot for every classifier already registered via
/// setTopologyPrinted(), so later addTrainLoss() calls accumulate into it.
void TrainInfos::nextEpoch()
{
  ++lastEpoch;
  saveToFilename();

  for (const auto & entry : topologyPrinted)
    trainLossesPerClassifierPerEpoch[entry.first].push_back(0.0f);
}
void TrainInfos::computeMustSaves()
......
......@@ -111,20 +111,11 @@ void Trainer::computeScoreOnDev()
}
}
bool pActionIsZeroCost = tm.getCurrentClassifier()->getActionCost(*devConfig, pAction) == 0;
if (ProgramParameters::devLoss)
{
float loss = tm.getCurrentClassifier()->getLoss(*devConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
TI.addDevExample(tm.getCurrentClassifier()->name, loss);
TI.addDevLoss(tm.getCurrentClassifier()->name, loss);
}
else
{
TI.addDevExample(tm.getCurrentClassifier()->name);
}
if (((!ProgramParameters::devEvalOnGold) && pActionIsZeroCost) || (pAction == oAction))
TI.addDevSuccess(tm.getCurrentClassifier()->name);
std::string actionName;
......@@ -189,7 +180,7 @@ void Trainer::computeScoreOnDev()
if (ProgramParameters::debug)
fprintf(stderr, "Dev Config is final\n");
TI.computeDevScores();
TI.computeDevScores(*devConfig);
if (ProgramParameters::debug)
fprintf(stderr, "End of %s\n", __func__);
......@@ -202,8 +193,6 @@ void Trainer::resetAndShuffle()
if(ProgramParameters::shuffleExamples)
trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
TI.resetCounters();
}
void Trainer::doStepNoTrain()
......@@ -302,9 +291,7 @@ void Trainer::doStepTrain()
if (!ProgramParameters::featureExtraction)
loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
if (pActionIsZeroCost)
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
int k = ProgramParameters::dynamicEpoch;
......@@ -400,17 +387,13 @@ void Trainer::doStepTrain()
if (newCost >= lastCost)
{
loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
if (pActionIsZeroCost)
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
}
else
{
loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
TI.addTrainSuccess(tm.getCurrentClassifier()->name);
}
TI.addTrainExample(tm.getCurrentClassifier()->name, loss);
TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
}
if (ProgramParameters::debug)
......@@ -509,7 +492,7 @@ void Trainer::train()
void Trainer::printScoresAndSave(FILE * output)
{
TI.computeTrainScores();
TI.computeTrainScores(trainConfig);
computeScoreOnDev();
TI.computeMustSaves();
......
......@@ -130,6 +130,10 @@ class Config
///
/// @param index the index to mask
void maskIndex(int index);
/// @brief Compare hyp and ref to give a matching score.
///
/// @return The score as a percentage.
float getScore();
};
private :
......
......@@ -625,3 +625,16 @@ std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::stri
return actionsHistory[classifier+"_"+std::to_string(head)];
}
/// @brief Compare hyp and ref on this tape to give a matching score.
///
/// Counts the positions where the hypothesis equals the reference and returns
/// the proportion as a percentage.
///
/// @return The score as a percentage in [0,100]; 0.0 for an empty reference
/// (the original divided by refSize() unconditionally, yielding NaN/UB when
/// the tape was empty).
float Config::Tape::getScore()
{
  if (refSize() == 0)
    return 0.0;

  // NOTE(review): getRef/getHyp appear to take indexes relative to the tape
  // head (hence i-head while i spans the absolute range) — confirm against
  // their definitions.
  int matches = 0;
  for (int i = 0; i < refSize(); i++)
  {
    if (getRef(i-head) == getHyp(i-head))
      matches++;
  }

  return 100.0*matches / refSize();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment