#include <boost/program_options.hpp>
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <sstream>
#include <string>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"

namespace po = boost::program_options;

/// @brief Builds the description of every command-line option this program accepts.
///
/// Options are split into a "Required" group (expName, model, mcd, trainTSV) and
/// an "Optional" group (trainTXT, devTSV, devTXT, nbEpochs, help).
///
/// @return The aggregated options description (required group first, then optional).
po::options_description getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("expName", po::value<std::string>()->required(),
      "Name of this experiment")
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    // main() reads variables["mcd"]; without this declaration the lookup yields
    // an empty value and .as<std::string>() throws boost::bad_any_cast.
    ("mcd", po::value<std::string>()->required(),
      "MCD file used to read the training and development corpora")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

/// @brief Parses the command line against @p od and validates it.
///
/// Parsing and validation errors are forwarded to util::myThrow with the
/// exception's message. When "help" is present, the options description is
/// printed to stderr and the process exits with status 0 before validation, so
/// that asking for help does not fail on missing required options.
///
/// @param od   Options description to parse against.
/// @param argc Argument count, as received by main.
/// @param argv Argument vector, as received by main.
/// @return The populated variables map.
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map vm;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), vm);
  }
  catch (const std::exception & e)
  {
    util::myThrow(e.what());
  }

  if (vm.count("help"))
  {
    std::stringstream ss;
    ss << od;
    fmt::print(stderr, "{}\n", ss.str());
    exit(0);
  }

  // notify() is what enforces the required() constraints.
  try
  {
    po::notify(vm);
  }
  catch (const std::exception & e)
  {
    util::myThrow(e.what());
  }

  return vm;
}

/// @brief Trains the reading machine found in the model directory and reports,
///        after each epoch, the training loss and the dev UPOS score.
int main(int argc, char * argv[])
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od, argc, argv);

  // NOTE(review): expName is read but not used anywhere in this file.
  auto expName = variables["expName"].as<std::string>();
  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
  auto mcdFile = variables["mcd"].as<std::string>();
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  // NOTE(review): devTSV/devTXT default to "" yet the dev configuration below is
  // always built and evaluated — confirm BaseConfig accepts empty file names.
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();

  ReadingMachine machine(machinePath.string());

  BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
  SubConfig config(goldConfig);

  Trainer trainer(machine);
  trainer.createDataset(config);

  Decoder decoder(machine);

  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);

  for (int i = 0; i < nbEpoch; i++)
  {
    float loss = trainer.epoch();

    // Decode on a fresh copy each epoch so devGoldConfig itself is never modified.
    auto devConfig = devGoldConfig;
    decoder.decode(devConfig, 1);
    decoder.evaluate(devConfig, modelPath, devTsvFile);

    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
  }

  return 0;
}