Skip to content
Snippets Groups Projects
Commit 9bc1ea8d authored by Franck Dary's avatar Franck Dary
Browse files

Added movement in strategy

parent 5370f971
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,7 @@ class Strategy ...@@ -18,7 +18,7 @@ class Strategy
}; };
Type type; Type type;
std::map<std::pair<std::string, std::string>, std::string> edges; std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
std::map<std::string, bool> isDone; std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle; std::vector<std::string> defaultCycle;
std::string initialState{"UNDEFINED"}; std::string initialState{"UNDEFINED"};
......
...@@ -11,26 +11,29 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) ...@@ -11,26 +11,29 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
auto splited = util::split(lines[i], ' '); auto splited = util::split(lines[i], ' ');
std::pair<std::string, std::string> key; std::pair<std::string, std::string> key;
std::string value; std::string value;
int movement;
if (splited.size() == 2) if (splited.size() == 3)
{ {
key = std::pair<std::string,std::string>(splited[0], ""); key = std::pair<std::string,std::string>(splited[0], "");
value = splited[1]; value = splited[1];
movement = std::stoi(std::string(splited[2]));
if (defaultCycle.empty()) if (defaultCycle.empty())
initialState = splited[0]; initialState = splited[0];
defaultCycle.emplace_back(value); defaultCycle.emplace_back(value);
} }
else if (splited.size() == 3) else if (splited.size() == 4)
{ {
key = std::pair<std::string,std::string>(splited[0], splited[1]); key = std::pair<std::string,std::string>(splited[0], splited[2]);
value = splited[1]; value = splited[1];
movement = std::stoi(std::string(splited[3]));
} }
else else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i])); util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key)) if (edges.count(key))
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second)); util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second));
edges[key] = value; edges[key] = std::make_pair(value, movement);
isDone[key.first] = false; isDone[key.first] = false;
} }
...@@ -42,6 +45,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) ...@@ -42,6 +45,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
{ {
std::string transitionPrefix(util::split(transition, ' ')[0]);
if (c.stateIsDone()) if (c.stateIsDone())
isDone[c.getState()] = true; isDone[c.getState()] = true;
...@@ -49,9 +54,9 @@ std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::s ...@@ -49,9 +54,9 @@ std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::s
defaultCycle.pop_back(); defaultCycle.pop_back();
if (type == Type::Sequential) if (type == Type::Sequential)
return getMovementSequential(c, transition); return getMovementSequential(c, transitionPrefix);
return getMovementIncremental(c, transition); return getMovementIncremental(c, transitionPrefix);
} }
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition) std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition)
...@@ -60,17 +65,24 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co ...@@ -60,17 +65,24 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co
auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target; std::string target;
int movement;
if (foundSpecific != edges.end()) if (foundSpecific != edges.end())
target = foundSpecific->second; {
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end()) else if (foundGeneric != edges.end())
target = foundGeneric->second; {
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty()) if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone()) if (!c.stateIsDone())
return {c.getState(), (c.getState() == target) && edges.size() > 1 ? 0 : 1}; return {c.getState(), (c.getState() == target) && edges.size() > 1 ? movement : 0};
if (!isDone[target]) if (!isDone[target])
return {target, -c.getWordIndex()}; return {target, -c.getWordIndex()};
...@@ -84,22 +96,29 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c ...@@ -84,22 +96,29 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target; std::string target;
int movement;
if (foundSpecific != edges.end()) if (foundSpecific != edges.end())
target = foundSpecific->second; {
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end()) else if (foundGeneric != edges.end())
target = foundGeneric->second; {
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty()) if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!isDone[target]) if (!isDone[target])
return {target, target == defaultCycle.back() ? 1 : 0}; return {target, target == defaultCycle.back() ? movement : 0};
if (defaultCycle.empty()) if (defaultCycle.empty())
return endMovement; return endMovement;
return {defaultCycle.back(), 1}; return {defaultCycle.back(), movement};
} }
const std::string Strategy::getInitialState() const const std::string Strategy::getInitialState() const
......
...@@ -73,6 +73,9 @@ int main(int argc, char * argv[]) ...@@ -73,6 +73,9 @@ int main(int argc, char * argv[])
auto nbEpoch = variables["nbEpochs"].as<int>(); auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
try
{
ReadingMachine machine(machinePath.string()); ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
...@@ -119,6 +122,9 @@ int main(int argc, char * argv[]) ...@@ -119,6 +122,9 @@ int main(int argc, char * argv[])
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
} }
}
catch(std::exception & e) {util::error(e);}
return 0; return 0;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment