diff --git a/trainer/include/TrainInfos.hpp b/trainer/include/TrainInfos.hpp index fb0de9727180bbf47ff2e76159479ef2f4b1a71e..3fb815a8901cee1543c047eb63b8853d83f9d42a 100644 --- a/trainer/include/TrainInfos.hpp +++ b/trainer/include/TrainInfos.hpp @@ -18,6 +18,7 @@ class TrainInfos std::string filename; int lastEpoch; int lastSaved; + int lastIndexTreated; std::map< std::string, std::vector<float> > trainLossesPerClassifierPerEpoch; std::map< std::string, std::vector<float> > devLossesPerClassifierPerEpoch; std::map< std::string, std::vector<float> > trainScoresPerClassifierPerEpoch; @@ -55,6 +56,7 @@ class TrainInfos void nextEpoch(); bool mustSave(const std::string & classifier); void printScores(FILE * output); + void setLastIndexTreated(int index); }; #endif diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index daaa46bd583543cf9c4eebaa80a1a78ed8fda250..b8ea23177648fdcf8f2e3ae92c33c86c2cbcb721 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -148,7 +148,7 @@ float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes float res = 0.0; for (auto & tape : tapes) - res += c.getTape(tape).getScore(); + res += c.getTape(tape).getScore(0, lastIndexTreated); return res / tapes.size(); } @@ -292,3 +292,8 @@ bool TrainInfos::mustSave(const std::string & classifier) return mustSavePerClassifierPerEpoch.count(classifier) && mustSavePerClassifierPerEpoch[classifier].back(); } +void TrainInfos::setLastIndexTreated(int index) +{ + lastIndexTreated = index; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 56a2a4997cae7cd785fadbdee2ee3fa4d280b7bb..7d847ddbb81aaffae55105b27ee858ab68557817 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -192,6 +192,7 @@ void Trainer::resetAndShuffle() { tm.reset(); trainConfig.reset(); + TI.setLastIndexTreated(0); if(ProgramParameters::shuffleExamples) trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter); @@ -540,6 +541,8 @@ void Trainer::train() if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize) try {prepareNextEpoch();} catch (EndOfTraining &) {break;} + + TI.setLastIndexTreated(trainConfig.getHead()); } if (ProgramParameters::debug) diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 7d1e71903b0ed6db6f0ba75fc7e8e8245ff50fed..fc7eefc2649d267dec58abdbb03c6f46bdf6db5a 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -132,8 +132,11 @@ class Config void maskIndex(int index); /// @brief Compare hyp and ref to give a matching score. /// + /// @param from first index to evaluate + /// @param to last index to evaluate + /// /// @return The score as a percentage. - float getScore(); + float getScore(int from, int to); }; private : diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index ff3c2bcf830f28483c30c7ff8e9a2a041a9c21da..10fe33b8e40d55d7225e680dc72884ef2d25404c 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -647,14 +647,14 @@ std::vector< std::pair<std::string, int> > & Config::getActionsHistory(std::stri return actionsHistory[state+"_"+std::to_string(head)]; } -float Config::Tape::getScore() +float Config::Tape::getScore(int from, int to) { float res = 0.0; - for (int i = 0; i < refSize()-1; i++) + for (int i = from; i <= to; i++) if (getRef(i-head) == getHyp(i-head)) res += 1; - return 100.0*res / (refSize()-1); + return 100.0*res / (1+to-from); }