Skip to content
Snippets Groups Projects
Commit 15787f04 authored by Franck Dary's avatar Franck Dary
Browse files

Eval on dev during training

parent 9d7a334b
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
config.setState(machine.getStrategy().getInitialState());
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
while (true)
{
auto dictState = machine.getDict(config.getState()).getState();
......
......@@ -10,40 +10,41 @@
int main(int argc, char * argv[])
{
if (argc != 5)
if (argc != 8)
{
fmt::print(stderr, "needs 4 arguments.\n");
fmt::print(stderr, "needs 7 arguments.\n");
exit(1);
}
std::string model = argv[1];
std::string mcdFile = argv[2];
std::string tsvFile = argv[3];
//std::string rawFile = argv[4];
std::string rawFile = "";
std::string trainTsvFile = argv[3];
std::string trainRawFile = "";
std::string devTsvFile = argv[5];
std::string devRawFile = "";
int nbEpoch = std::stoi(argv[7]);
std::filesystem::path modelPath(model);
auto machinePath = modelPath / "machine.rm";
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
SubConfig config(goldConfig);
Trainer trainer(machine);
trainer.createDataset(config);
Decoder decoder(machine);
int nbEpoch = 10;
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch();
auto devConfig = goldConfig;
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1);
decoder.evaluate(devConfig, modelPath, tsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {}/{} loss = {} dev = {}\n", " ", i+1, nbEpoch, loss, decoder.getF1Score("UPOS"));
decoder.evaluate(devConfig, modelPath, devTsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
}
return 0;
......
......@@ -16,7 +16,10 @@ void Trainer::createDataset(SubConfig & config)
{
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
//TODO : check if clone is mandatory
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
......@@ -37,7 +40,10 @@ void Trainer::createDataset(SubConfig & config)
config.setState(movement.first);
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
{
config.printForDebug(stderr);
util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
}
if (config.needsUpdate())
config.update();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment