Skip to content
Snippets Groups Projects
Commit 3a0f7d4f authored by Franck Dary's avatar Franck Dary
Browse files

Fixed accuracy computing

parent fa69635b
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,6 @@ class TrainInfos
std::string filename;
int lastEpoch;
int lastSaved;
int lastIndexTreated;
std::map< std::string, std::vector<float> > trainLossesPerClassifierPerEpoch;
std::map< std::string, std::vector<float> > devLossesPerClassifierPerEpoch;
std::map< std::string, std::vector<float> > trainScoresPerClassifierPerEpoch;
......@@ -36,7 +35,7 @@ class TrainInfos
void saveToFilename();
void addTrainScore(const std::string & classifier, float score);
void addDevScore(const std::string & classifier, float score);
float computeScoreOnTapes(Config & c, std::vector<std::string> tapes);
float computeScoreOnTapes(Config & c, std::vector<std::string> tapes, int from, int to);
public :
......
......@@ -143,12 +143,12 @@ void TrainInfos::addDevScore(const std::string & classifier, float score)
devScoresPerClassifierPerEpoch[classifier].emplace_back(score);
}
float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes)
float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes, int from, int to)
{
float res = 0.0;
for (auto & tape : tapes)
res += c.getTape(tape).getScore(0, lastIndexTreated);
res += c.getTape(tape).getScore(from, to);
return res / tapes.size();
}
......@@ -158,13 +158,13 @@ void TrainInfos::computeTrainScores(Config & c)
for (auto & it : topologyPrinted)
{
if (it.first == "Parser")
addTrainScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}));
addTrainScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
else if (it.first == "Tagger")
addTrainScore(it.first, computeScoreOnTapes(c, {"POS"}));
addTrainScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead()));
else if (it.first == "Morpho")
addTrainScore(it.first, computeScoreOnTapes(c, {"MORPHO"}));
addTrainScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead()));
else if (it.first == "Lemmatizer_Rules")
addTrainScore(it.first, computeScoreOnTapes(c, {"LEMMA"}));
addTrainScore(it.first, computeScoreOnTapes(c, {"LEMMA"}, 0, c.getHead()));
else if (split(it.first, '_')[0] == "Error")
addTrainScore(it.first, 100.0);
else
......@@ -180,15 +180,15 @@ void TrainInfos::computeDevScores(Config & c)
for (auto & it : topologyPrinted)
{
if (it.first == "Parser")
addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}));
addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
else if (it.first == "Parser")
addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}));
addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
else if (it.first == "Tagger")
addDevScore(it.first, computeScoreOnTapes(c, {"POS"}));
addDevScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead()));
else if (it.first == "Morpho")
addDevScore(it.first, computeScoreOnTapes(c, {"MORPHO"}));
addDevScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead()));
else if (it.first == "Lemmatizer_Rules")
addDevScore(it.first, computeScoreOnTapes(c, {"LEMMA"}));
addDevScore(it.first, computeScoreOnTapes(c, {"LEMMA"}, 0, c.getHead()));
else if (split(it.first, '_')[0] == "Error")
addDevScore(it.first, 100.0);
else
......@@ -292,8 +292,3 @@ bool TrainInfos::mustSave(const std::string & classifier)
return mustSavePerClassifierPerEpoch.count(classifier) && mustSavePerClassifierPerEpoch[classifier].back();
}
void TrainInfos::setLastIndexTreated(int index)
{
lastIndexTreated = index;
}
......@@ -192,7 +192,6 @@ void Trainer::resetAndShuffle()
{
tm.reset();
trainConfig.reset();
TI.setLastIndexTreated(0);
if(ProgramParameters::shuffleExamples)
trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
......@@ -541,8 +540,6 @@ void Trainer::train()
if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
try {prepareNextEpoch();}
catch (EndOfTraining &) {break;}
TI.setLastIndexTreated(trainConfig.getHead());
}
if (ProgramParameters::debug)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment