Skip to content
Snippets Groups Projects
Commit 9bc1ea8d authored by Franck Dary's avatar Franck Dary
Browse files

Added movement in strategy

parent 5370f971
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ class Strategy
};
Type type;
std::map<std::pair<std::string, std::string>, std::string> edges;
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle;
std::string initialState{"UNDEFINED"};
......
......@@ -11,26 +11,29 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
auto splited = util::split(lines[i], ' ');
std::pair<std::string, std::string> key;
std::string value;
int movement;
if (splited.size() == 2)
if (splited.size() == 3)
{
key = std::pair<std::string,std::string>(splited[0], "");
value = splited[1];
movement = std::stoi(std::string(splited[2]));
if (defaultCycle.empty())
initialState = splited[0];
defaultCycle.emplace_back(value);
}
else if (splited.size() == 3)
else if (splited.size() == 4)
{
key = std::pair<std::string,std::string>(splited[0], splited[1]);
key = std::pair<std::string,std::string>(splited[0], splited[2]);
value = splited[1];
movement = std::stoi(std::string(splited[3]));
}
else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key))
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second));
edges[key] = value;
edges[key] = std::make_pair(value, movement);
isDone[key.first] = false;
}
......@@ -42,6 +45,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
{
std::string transitionPrefix(util::split(transition, ' ')[0]);
if (c.stateIsDone())
isDone[c.getState()] = true;
......@@ -49,9 +54,9 @@ std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::s
defaultCycle.pop_back();
if (type == Type::Sequential)
return getMovementSequential(c, transition);
return getMovementSequential(c, transitionPrefix);
return getMovementIncremental(c, transition);
return getMovementIncremental(c, transitionPrefix);
}
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition)
......@@ -60,17 +65,24 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement;
if (foundSpecific != edges.end())
target = foundSpecific->second;
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
target = foundGeneric->second;
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone())
return {c.getState(), (c.getState() == target) && edges.size() > 1 ? 0 : 1};
return {c.getState(), (c.getState() == target) && edges.size() > 1 ? movement : 0};
if (!isDone[target])
return {target, -c.getWordIndex()};
......@@ -84,22 +96,29 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement;
if (foundSpecific != edges.end())
target = foundSpecific->second;
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
target = foundGeneric->second;
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!isDone[target])
return {target, target == defaultCycle.back() ? 1 : 0};
return {target, target == defaultCycle.back() ? movement : 0};
if (defaultCycle.empty())
return endMovement;
return {defaultCycle.back(), 1};
return {defaultCycle.back(), movement};
}
const std::string Strategy::getInitialState() const
......
......@@ -73,6 +73,9 @@ int main(int argc, char * argv[])
auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
try
{
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
......@@ -119,6 +122,9 @@ int main(int argc, char * argv[])
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
}
}
catch(std::exception & e) {util::error(e);}
return 0;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment