From 9bc1ea8d7c65733a06029b5112fc54abdda2b5b4 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 14 Feb 2020 13:25:32 +0100 Subject: [PATCH] Added movement in strategy --- reading_machine/include/Strategy.hpp | 2 +- reading_machine/src/Strategy.cpp | 45 ++++++++++++++++++++-------- trainer/src/macaon_train.cpp | 6 ++++ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 10d9628..46946ab 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -18,7 +18,7 @@ class Strategy }; Type type; - std::map<std::pair<std::string, std::string>, std::string> edges; + std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges; std::map<std::string, bool> isDone; std::vector<std::string> defaultCycle; std::string initialState{"UNDEFINED"}; diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index ca045f1..ea537e0 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -11,26 +11,29 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) auto splited = util::split(lines[i], ' '); std::pair<std::string, std::string> key; std::string value; + int movement; - if (splited.size() == 2) + if (splited.size() == 3) { key = std::pair<std::string,std::string>(splited[0], ""); value = splited[1]; + movement = std::stoi(std::string(splited[2])); if (defaultCycle.empty()) initialState = splited[0]; defaultCycle.emplace_back(value); } - else if (splited.size() == 3) + else if (splited.size() == 4) { - key = std::pair<std::string,std::string>(splited[0], splited[1]); + key = std::pair<std::string,std::string>(splited[0], splited[2]); value = splited[1]; + movement = std::stoi(std::string(splited[3])); } else util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i])); if (edges.count(key)) util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second)); - edges[key] = value; + edges[key] = std::make_pair(value, movement); isDone[key.first] = false; } @@ -42,6 +45,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) { + std::string transitionPrefix(util::split(transition, ' ')[0]); + if (c.stateIsDone()) isDone[c.getState()] = true; @@ -49,9 +54,9 @@ std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::s defaultCycle.pop_back(); if (type == Type::Sequential) - return getMovementSequential(c, transition); + return getMovementSequential(c, transitionPrefix); - return getMovementIncremental(c, transition); + return getMovementIncremental(c, transitionPrefix); } std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition) @@ -60,17 +65,24 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); std::string target; + int movement; if (foundSpecific != edges.end()) - target = foundSpecific->second; + { + target = foundSpecific->second.first; + movement = foundSpecific->second.second; + } else if (foundGeneric != edges.end()) - target = foundGeneric->second; + { + target = foundGeneric->second.first; + movement = foundGeneric->second.second; + } if (target.empty()) util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); if (!c.stateIsDone()) - return {c.getState(), (c.getState() == target) && edges.size() > 1 ? 0 : 1}; + return {c.getState(), (c.getState() == target) && edges.size() > 1 ? movement : 0}; if (!isDone[target]) return {target, -c.getWordIndex()}; @@ -84,22 +96,29 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); std::string target; + int movement; if (foundSpecific != edges.end()) - target = foundSpecific->second; + { + target = foundSpecific->second.first; + movement = foundSpecific->second.second; + } else if (foundGeneric != edges.end()) - target = foundGeneric->second; + { + target = foundGeneric->second.first; + movement = foundGeneric->second.second; + } if (target.empty()) util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); if (!isDone[target]) - return {target, target == defaultCycle.back() ? 1 : 0}; + return {target, target == defaultCycle.back() ? movement : 0}; if (defaultCycle.empty()) return endMovement; - return {defaultCycle.back(), 1}; + return {defaultCycle.back(), movement}; } const std::string Strategy::getInitialState() const diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index fff80b3..f762c3d 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -73,6 +73,9 @@ int main(int argc, char * argv[]) auto nbEpoch = variables["nbEpochs"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; + try + { + ReadingMachine machine(machinePath.string()); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); @@ -119,6 +122,9 @@ int main(int argc, char * argv[]) fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); } + } + catch(std::exception & e) {util::error(e);} + return 0; } -- GitLab