Skip to content
Snippets Groups Projects
Commit 15787f04 authored by Franck Dary's avatar Franck Dary
Browse files

Eval on dev during training

parent 9d7a334b
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) ...@@ -9,6 +9,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{ {
config.setState(machine.getStrategy().getInitialState()); config.setState(machine.getStrategy().getInitialState());
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
while (true) while (true)
{ {
auto dictState = machine.getDict(config.getState()).getState(); auto dictState = machine.getDict(config.getState()).getState();
......
...@@ -10,40 +10,41 @@ ...@@ -10,40 +10,41 @@
int main(int argc, char * argv[]) int main(int argc, char * argv[])
{ {
if (argc != 5) if (argc != 8)
{ {
fmt::print(stderr, "needs 4 arguments.\n"); fmt::print(stderr, "needs 7 arguments.\n");
exit(1); exit(1);
} }
std::string model = argv[1]; std::string model = argv[1];
std::string mcdFile = argv[2]; std::string mcdFile = argv[2];
std::string tsvFile = argv[3]; std::string trainTsvFile = argv[3];
//std::string rawFile = argv[4]; std::string trainRawFile = "";
std::string rawFile = ""; std::string devTsvFile = argv[5];
std::string devRawFile = "";
int nbEpoch = std::stoi(argv[7]);
std::filesystem::path modelPath(model); std::filesystem::path modelPath(model);
auto machinePath = modelPath / "machine.rm"; auto machinePath = modelPath / "machine.rm";
ReadingMachine machine(machinePath.string()); ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, tsvFile, rawFile); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
SubConfig config(goldConfig); SubConfig config(goldConfig);
Trainer trainer(machine); Trainer trainer(machine);
trainer.createDataset(config); trainer.createDataset(config);
Decoder decoder(machine); Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
int nbEpoch = 10;
for (int i = 0; i < nbEpoch; i++) for (int i = 0; i < nbEpoch; i++)
{ {
float loss = trainer.epoch(); float loss = trainer.epoch();
auto devConfig = goldConfig; auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1); decoder.decode(devConfig, 1);
decoder.evaluate(devConfig, modelPath, tsvFile); decoder.evaluate(devConfig, modelPath, devTsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {}/{} loss = {} dev = {}\n", " ", i+1, nbEpoch, loss, decoder.getF1Score("UPOS")); fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
} }
return 0; return 0;
......
...@@ -16,7 +16,10 @@ void Trainer::createDataset(SubConfig & config) ...@@ -16,7 +16,10 @@ void Trainer::createDataset(SubConfig & config)
{ {
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition) if (!transition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !"); util::myThrow("No transition appliable !");
}
//TODO : check if clone is mandatory //TODO : check if clone is mandatory
auto context = config.extractContext(5,5,machine.getDict(config.getState())); auto context = config.extractContext(5,5,machine.getDict(config.getState()));
...@@ -37,7 +40,10 @@ void Trainer::createDataset(SubConfig & config) ...@@ -37,7 +40,10 @@ void Trainer::createDataset(SubConfig & config)
config.setState(movement.first); config.setState(movement.first);
if (!config.moveWordIndex(movement.second)) if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !"); {
config.printForDebug(stderr);
util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
}
if (config.needsUpdate()) if (config.needsUpdate())
config.update(); config.update();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment