#include <boost/program_options.hpp>
#include <filesystem>
#include <cstdlib>
#include <exception>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"

namespace po = boost::program_options;

/// @brief Builds the description of every command-line option this program accepts.
///
/// Options are split in two groups: "Required" (model, mcd, trainTSV) and
/// "Optional" (debug, trainTXT, devTSV, devTXT, nbEpochs, help).
///
/// @return The combined options description (required group first, then optional).
po::options_description getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("mcd", po::value<std::string>()->required(),
      "Multi Column Description file that describes the input format")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debuging infos on stderr")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

/// @brief Parses the command line against `od` and validates the result.
///
/// If `--help` is present, the options description is printed on stderr and the
/// process exits with status 0 before required options are checked.
/// Any exception raised while parsing or validating is forwarded, as its
/// message, to util::myThrow.
///
/// @param od   Description of the accepted options.
/// @param argc Argument count, as received by main.
/// @param argv Argument vector, as received by main.
/// @return The parsed and notified variables map.
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map vm;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), vm);
  }
  catch (const std::exception & e)
  {
    util::myThrow(e.what());
  }

  // Help is handled before notify() so that missing required options
  // do not prevent the help message from being displayed.
  if (vm.count("help"))
  {
    std::stringstream ss;
    ss << od;
    fmt::print(stderr, "{}\n", ss.str());
    std::exit(0);
  }

  try
  {
    po::notify(vm);
  }
  catch (const std::exception & e)
  {
    util::myThrow(e.what());
  }

  return vm;
}

/// @brief Program entry point: trains a ReadingMachine and evaluates it on the dev corpus.
///
/// For each epoch, the trainer runs one epoch, the dev corpus is decoded and
/// evaluated, and the machine is saved whenever the mean of the scores returned
/// by Decoder::getF1Scores is strictly greater than the best mean seen so far.
/// The best mean starts at 0, so a mean of 0 never triggers a save.
///
/// @return 0 once the handler has returned or training has completed.
int main(int argc, char * argv[])
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od, argc, argv);

  const std::filesystem::path modelPath(variables["model"].as<std::string>());
  const auto machinePath = modelPath / "machine.rm";
  const auto mcdFile = variables["mcd"].as<std::string>();
  const auto trainTsvFile = variables["trainTSV"].as<std::string>();
  const auto trainRawFile = variables["trainTXT"].as<std::string>();
  const auto devTsvFile = variables["devTSV"].as<std::string>();
  const auto devRawFile = variables["devTXT"].as<std::string>();
  const auto nbEpoch = variables["nbEpochs"].as<int>();
  const bool debug = variables.count("debug") != 0;

  try
  {
    ReadingMachine machine(machinePath.string());

    BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
    SubConfig config(goldConfig);

    Trainer trainer(machine);
    trainer.createDataset(config, debug);

    Decoder decoder(machine);
    BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);

    float bestDevScore = 0;

    for (int i = 0; i < nbEpoch; i++)
    {
      const float loss = trainer.epoch(!debug);
      machine.getStrategy().reset();

      // Work on a fresh copy of the dev gold configuration at every epoch.
      auto devConfig = devGoldConfig;

      if (debug)
        fmt::print(stderr, "Decoding dev :\n");
      else
        fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");

      decoder.decode(devConfig, 1, debug);
      machine.getStrategy().reset();
      decoder.evaluate(devConfig, modelPath, devTsvFile);

      const std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());

      // Build "name(score%),name(score%)" and accumulate the sum of the scores.
      std::string devScoresStr = "";
      float devScoreMean = 0;
      for (const auto & [scoreValue, scoreName] : devScores)
      {
        devScoresStr += fmt::format("{}({:5.2f}%),", scoreName, scoreValue);
        devScoreMean += scoreValue;
      }
      if (!devScoresStr.empty())
        devScoresStr.pop_back(); // drop the trailing comma

      // Guard against an empty score list: dividing by zero would yield NaN.
      if (!devScores.empty())
        devScoreMean /= devScores.size();

      const bool saved = devScoreMean > bestDevScore;
      if (saved)
      {
        bestDevScore = devScoreMean;
        machine.save();
      }

      // Outside debug mode, first blank the current line (it holds the
      // "Decoding dev..." message) before printing the epoch report.
      if (!debug)
        fmt::print(stderr, "\r{:80}\r", " ");
      fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n",
        fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    }
  }
  // Caught by non-const reference on purpose: the signature of util::error
  // is not visible from this file.
  catch (std::exception & e)
  {
    util::error(e);
  }

  return 0;
}