Skip to content
Snippets Groups Projects
Commit d145be52 authored by Franck Dary's avatar Franck Dary
Browse files

Reseting strategy between different corpuses

parent 5cc6a43b
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ class Strategy
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle;
std::vector<std::string> originalDefaultCycle;
std::string initialState{"UNDEFINED"};
private :
......@@ -33,6 +34,7 @@ class Strategy
Strategy(const std::vector<std::string_view> & lines);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const;
void reset();
};
#endif
......@@ -41,6 +41,7 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
util::myThrow("Strategy is empty");
defaultCycle.pop_back();
std::reverse(defaultCycle.begin(), defaultCycle.end());
originalDefaultCycle = defaultCycle;
}
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
......@@ -96,7 +97,7 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement;
int movement = -1;
if (foundSpecific != edges.end())
{
......@@ -113,7 +114,7 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!isDone[target])
return {target, target == defaultCycle.back() ? movement : 0};
return {target, c.canMoveWordIndex(movement) ? movement : 0};
if (defaultCycle.empty())
return endMovement;
......@@ -126,3 +127,10 @@ const std::string Strategy::getInitialState() const
return initialState;
}
void Strategy::reset()
{
for (auto & it : isDone)
it.second = false;
defaultCycle = originalDefaultCycle;
}
......@@ -92,12 +92,14 @@ int main(int argc, char * argv[])
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch(!debug);
machine.getStrategy().reset();
auto devConfig = devGoldConfig;
if (debug)
fmt::print(stderr, "Decoding dev :\n");
else
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1, debug);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile);
std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());
std::string devScoresStr = "";
......@@ -117,9 +119,9 @@ int main(int argc, char * argv[])
machine.save();
}
if (debug)
fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
else
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment