diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 136a545c20be789d90566352be52aefc5b231201..3191cbf054860846afd27ed39e8ee503fbfe0586 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -9,6 +9,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) { config.setState(machine.getStrategy().getInitialState()); + fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); + while (true) { auto dictState = machine.getDict(config.getState()).getState(); diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index 51b1381aca77ef307afc5353bd78281abfe5b5af..d6b94334c8adc7f0a1d78acb42c718911205d571 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -10,40 +10,41 @@ int main(int argc, char * argv[]) { - if (argc != 5) + if (argc != 8) { - fmt::print(stderr, "needs 4 arguments.\n"); + fmt::print(stderr, "needs 7 arguments.\n"); exit(1); } std::string model = argv[1]; std::string mcdFile = argv[2]; - std::string tsvFile = argv[3]; - //std::string rawFile = argv[4]; - std::string rawFile = ""; + std::string trainTsvFile = argv[3]; + std::string trainRawFile = ""; + std::string devTsvFile = argv[5]; + std::string devRawFile = ""; + int nbEpoch = std::stoi(argv[7]); std::filesystem::path modelPath(model); auto machinePath = modelPath / "machine.rm"; ReadingMachine machine(machinePath.string()); - BaseConfig goldConfig(mcdFile, tsvFile, rawFile); + BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); SubConfig config(goldConfig); Trainer trainer(machine); trainer.createDataset(config); Decoder decoder(machine); - - int nbEpoch = 10; + BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); for (int i = 0; i < nbEpoch; i++) { float loss = trainer.epoch(); - auto devConfig = goldConfig; + auto devConfig = devGoldConfig; decoder.decode(devConfig, 1); - decoder.evaluate(devConfig, modelPath, tsvFile); - fmt::print(stderr, "\r{:80}\rEpoch {}/{} loss = {} dev = {}\n", " ", i+1, nbEpoch, loss, decoder.getF1Score("UPOS")); + decoder.evaluate(devConfig, modelPath, devTsvFile); + fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS")); } return 0; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 65fb9b98e061d22107b46352375977183dea70c5..059da2baacf46fb17d61facbcf15286655e48031 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -16,7 +16,10 @@ void Trainer::createDataset(SubConfig & config) { auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) + { + config.printForDebug(stderr); util::myThrow("No transition appliable !"); + } //TODO : check if clone is mandatory auto context = config.extractContext(5,5,machine.getDict(config.getState())); @@ -37,7 +40,10 @@ void Trainer::createDataset(SubConfig & config) config.setState(movement.first); if (!config.moveWordIndex(movement.second)) - util::myThrow("Cannot move word index !"); + { + config.printForDebug(stderr); + util::myThrow(fmt::format("Cannot move word index by {}", movement.second)); + } if (config.needsUpdate()) config.update();