Skip to content
Snippets Groups Projects
Commit f97310f3 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed accuracy computing when using iterationSize argument

parent f6f90445
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ class TrainInfos ...@@ -18,6 +18,7 @@ class TrainInfos
std::string filename; std::string filename;
int lastEpoch; int lastEpoch;
int lastSaved; int lastSaved;
int lastIndexTreated;
std::map< std::string, std::vector<float> > trainLossesPerClassifierPerEpoch; std::map< std::string, std::vector<float> > trainLossesPerClassifierPerEpoch;
std::map< std::string, std::vector<float> > devLossesPerClassifierPerEpoch; std::map< std::string, std::vector<float> > devLossesPerClassifierPerEpoch;
std::map< std::string, std::vector<float> > trainScoresPerClassifierPerEpoch; std::map< std::string, std::vector<float> > trainScoresPerClassifierPerEpoch;
...@@ -55,6 +56,7 @@ class TrainInfos ...@@ -55,6 +56,7 @@ class TrainInfos
void nextEpoch(); void nextEpoch();
bool mustSave(const std::string & classifier); bool mustSave(const std::string & classifier);
void printScores(FILE * output); void printScores(FILE * output);
void setLastIndexTreated(int index);
}; };
#endif #endif
...@@ -148,7 +148,7 @@ float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes ...@@ -148,7 +148,7 @@ float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes
float res = 0.0; float res = 0.0;
for (auto & tape : tapes) for (auto & tape : tapes)
res += c.getTape(tape).getScore(); res += c.getTape(tape).getScore(0, lastIndexTreated);
return res / tapes.size(); return res / tapes.size();
} }
...@@ -292,3 +292,8 @@ bool TrainInfos::mustSave(const std::string & classifier) ...@@ -292,3 +292,8 @@ bool TrainInfos::mustSave(const std::string & classifier)
return mustSavePerClassifierPerEpoch.count(classifier) && mustSavePerClassifierPerEpoch[classifier].back(); return mustSavePerClassifierPerEpoch.count(classifier) && mustSavePerClassifierPerEpoch[classifier].back();
} }
void TrainInfos::setLastIndexTreated(int index)
{
lastIndexTreated = index;
}
...@@ -192,6 +192,7 @@ void Trainer::resetAndShuffle() ...@@ -192,6 +192,7 @@ void Trainer::resetAndShuffle()
{ {
tm.reset(); tm.reset();
trainConfig.reset(); trainConfig.reset();
TI.setLastIndexTreated(0);
if(ProgramParameters::shuffleExamples) if(ProgramParameters::shuffleExamples)
trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter); trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
...@@ -540,6 +541,8 @@ void Trainer::train() ...@@ -540,6 +541,8 @@ void Trainer::train()
if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize) if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
try {prepareNextEpoch();} try {prepareNextEpoch();}
catch (EndOfTraining &) {break;} catch (EndOfTraining &) {break;}
TI.setLastIndexTreated(trainConfig.getHead());
} }
if (ProgramParameters::debug) if (ProgramParameters::debug)
......
...@@ -132,8 +132,11 @@ class Config ...@@ -132,8 +132,11 @@ class Config
void maskIndex(int index); void maskIndex(int index);
/// @brief Compare hyp and ref to give a matching score. /// @brief Compare hyp and ref to give a matching score.
/// ///
/// @param from first index to evaluate
/// @param to last index to evaluate
///
/// @return The score as a percentage. /// @return The score as a percentage.
float getScore(); float getScore(int from, int to);
}; };
private : private :
......
...@@ -647,14 +647,14 @@ std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::stri ...@@ -647,14 +647,14 @@ std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::stri
return actionsHistory[state+"_"+std::to_string(head)]; return actionsHistory[state+"_"+std::to_string(head)];
} }
float Config::Tape::getScore() float Config::Tape::getScore(int from, int to)
{ {
float res = 0.0; float res = 0.0;
for (int i = 0; i < refSize()-1; i++) for (int i = from; i <= to; i++)
if (getRef(i-head) == getHyp(i-head)) if (getRef(i-head) == getHyp(i-head))
res += 1; res += 1;
return 100.0*res / (refSize()-1); return 100.0*res / (1+to-from);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment