From d145be52a13ddda6a522338a2edc08ab777ea579 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 25 Feb 2020 23:09:43 +0100 Subject: [PATCH] Resetting strategy between different corpuses --- reading_machine/include/Strategy.hpp | 2 ++ reading_machine/src/Strategy.cpp | 12 ++++++++++-- trainer/src/macaon_train.cpp | 6 ++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 46946ab..9c4edd5 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -21,6 +21,7 @@ class Strategy std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges; std::map<std::string, bool> isDone; std::vector<std::string> defaultCycle; + std::vector<std::string> originalDefaultCycle; std::string initialState{"UNDEFINED"}; private : @@ -33,6 +34,7 @@ class Strategy Strategy(const std::vector<std::string_view> & lines); std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); const std::string getInitialState() const; + void reset(); }; #endif diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index ee1efa6..a13ac5e 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -41,6 +41,7 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) util::myThrow("Strategy is empty"); defaultCycle.pop_back(); std::reverse(defaultCycle.begin(), defaultCycle.end()); + originalDefaultCycle = defaultCycle; } std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) @@ -96,7 +97,7 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); std::string target; - int movement; + int movement = -1; if (foundSpecific != edges.end()) { @@ -113,7 +114,7 @@ std::pair<std::string, int> 
Strategy::getMovementIncremental(const Config & c, c util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); if (!isDone[target]) - return {target, target == defaultCycle.back() ? movement : 0}; + return {target, c.canMoveWordIndex(movement) ? movement : 0}; if (defaultCycle.empty()) return endMovement; @@ -126,3 +127,10 @@ const std::string Strategy::getInitialState() const return initialState; } +void Strategy::reset() +{ + for (auto & it : isDone) + it.second = false; + defaultCycle = originalDefaultCycle; +} + diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 936d16c..34a009a 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -92,12 +92,14 @@ int main(int argc, char * argv[]) for (int i = 0; i < nbEpoch; i++) { float loss = trainer.epoch(!debug); + machine.getStrategy().reset(); auto devConfig = devGoldConfig; if (debug) fmt::print(stderr, "Decoding dev :\n"); else fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); decoder.decode(devConfig, 1, debug); + machine.getStrategy().reset(); decoder.evaluate(devConfig, modelPath, devTsvFile); std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted()); std::string devScoresStr = ""; @@ -117,9 +119,9 @@ int main(int argc, char * argv[]) machine.save(); } if (debug) - fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); + fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); else - fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? 
"SAVED" : ""); + fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); } } -- GitLab