Skip to content
Snippets Groups Projects
Commit 5370f971 authored by Franck Dary's avatar Franck Dary
Browse files

During training, dev scores correspond to what column machine is predictiong

parent af973750
No related branches found
No related tags found
No related merge requests found
......@@ -12,16 +12,25 @@ class Decoder
ReadingMachine & machine;
std::map<std::string, std::array<float,4>> evaluation;
private :
std::string getMetricOfColName(const std::string & colName) const;
std::vector<float> getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const;
float getMetricScore(const std::string & metric, std::size_t scoreIndex) const;
float getPrecision(const std::string & metric) const;
float getF1Score(const std::string & metric) const;
float getRecall(const std::string & metric) const;
float getAlignedAcc(const std::string & metric) const;
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize, bool debug);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
float getMetricScore(const std::string & metric, std::size_t scoreIndex);
float getPrecision(const std::string & metric);
float getF1Score(const std::string & metric);
float getRecall(const std::string & metric);
float getAlignedAcc(const std::string & metric);
std::vector<float> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<float> getAlignedAccs(const std::set<std::string> & colNames) const;
std::vector<float> getRecalls(const std::set<std::string> & colNames) const;
std::vector<float> getPrecisions(const std::set<std::string> & colNames) const;
};
#endif
......@@ -50,7 +50,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
} catch(std::exception & e) {util::myThrow(e.what());}
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
{
auto found = evaluation.find(metric);
......@@ -60,26 +60,68 @@ float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex
return found->second[scoreIndex];
}
float Decoder::getPrecision(const std::string & metric)
float Decoder::getPrecision(const std::string & metric) const
{
return getMetricScore(metric, 0);
}
float Decoder::getRecall(const std::string & metric)
float Decoder::getRecall(const std::string & metric) const
{
return getMetricScore(metric, 1);
}
float Decoder::getF1Score(const std::string & metric)
float Decoder::getF1Score(const std::string & metric) const
{
return getMetricScore(metric, 2);
}
float Decoder::getAlignedAcc(const std::string & metric)
float Decoder::getAlignedAcc(const std::string & metric) const
{
return getMetricScore(metric, 3);
}
std::vector<float> Decoder::getF1Scores(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getF1Score);
}
std::vector<float> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getAlignedAcc);
}
std::vector<float> Decoder::getRecalls(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getRecall);
}
std::vector<float> Decoder::getPrecisions(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getPrecision);
}
std::vector<float> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
{
std::vector<float> scores;
for (auto & colName : colNames)
scores.push_back((this->*metric2score)(getMetricOfColName(colName)));
return scores;
}
std::string Decoder::getMetricOfColName(const std::string & colName) const
{
if (colName == "HEAD")
return "UAS";
if (colName == "DEPREL")
return "LAS";
if (colName == "EOS")
return "Sentences";
return colName;
}
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
evaluation.clear();
......
......@@ -96,17 +96,27 @@ int main(int argc, char * argv[])
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1, debug);
decoder.evaluate(devConfig, modelPath, devTsvFile);
float devScore = decoder.getF1Score("UPOS");
bool saved = devScore > bestDevScore;
std::vector<float> devScores = decoder.getF1Scores(machine.getPredicted());
std::string devScoresStr = "";
float devScoreMean = 0;
for (auto & score : devScores)
{
devScoresStr += fmt::format("{:5.2f}%,", score);
devScoreMean += score;
}
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoreMean /= devScores.size();
bool saved = devScoreMean > bestDevScore;
if (saved)
{
bestDevScore = devScore;
bestDevScore = devScoreMean;
machine.save();
}
if (debug)
fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
else
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
}
return 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment