#include "MacaonTrain.hpp" #include <filesystem> #include "util.hpp" #include "NeuralNetwork.hpp" namespace po = boost::program_options; po::options_description MacaonTrain::getOptionsDescription() { po::options_description desc("Command-Line Arguments "); po::options_description req("Required"); req.add_options() ("model", po::value<std::string>()->required(), "Directory containing the machine file to train") ("trainTSV", po::value<std::string>()->required(), "TSV file of the training corpus, in CONLLU format"); po::options_description opt("Optional"); opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") ("devScore", "Compute score on dev instead of loss (slower)") ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), "Comma separated column names that describes the input/output format") ("trainTXT", po::value<std::string>()->default_value(""), "Raw text file of the training corpus") ("devTSV", po::value<std::string>()->default_value(""), "TSV file of the development corpus, in CONLLU format") ("devTXT", po::value<std::string>()->default_value(""), "Raw text file of the development corpus") ("nbEpochs,n", po::value<int>()->default_value(5), "Number of training epochs") ("batchSize", po::value<int>()->default_value(64), "Number of examples per batch") ("explorationThreshold", po::value<float>()->default_value(0.1), "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), "Description of what should happen during training") ("loss", po::value<std::string>()->default_value("CrossEntropy"), "Loss function to use during training : CrossEntropy | bce | mse | hinge") ("help,h", "Produce this help message"); desc.add(req).add(opt); return desc; } po::variables_map MacaonTrain::checkOptions(po::options_description & od) { po::variables_map vm; try {po::store(po::parse_command_line(argc, argv, od), vm);} catch(std::exception & e) {util::myThrow(e.what());} if (vm.count("help")) { std::stringstream ss; ss << od; fmt::print(stderr, "{}\n", ss.str()); exit(0); } try {po::notify(vm);} catch(std::exception& e) {util::myThrow(e.what());} return vm; } Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s) { Trainer::TrainStrategy ts; try { auto splited = util::split(s, ':'); for (auto & ss : splited) { auto elements = util::split(ss, ','); int epoch = std::stoi(elements[0]); for (unsigned int i = 1; i < elements.size(); i++) ts[epoch].insert(Trainer::str2TrainAction(elements[i])); } } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));} return ts; } template < typename Optimizer = torch::optim::Adam, typename OptimizerOptions = torch::optim::AdamOptions > inline auto decay(Optimizer &optimizer, double rate) -> void { for (auto &group : optimizer.param_groups()) { for (auto ¶m : group.params()) { if (!param.grad().defined()) continue; auto &options = static_cast<OptimizerOptions &>(group.options()); options.lr(options.lr() * (1.0 - rate)); } } } int MacaonTrain::main() { auto od = getOptionsDescription(); auto variables = checkOptions(od); std::filesystem::path modelPath(variables["model"].as<std::string>()); auto machinePath = modelPath / "machine.rm"; auto mcd = variables["mcd"].as<std::string>(); auto trainTsvFile = variables["trainTSV"].as<std::string>(); auto trainRawFile = variables["trainTXT"].as<std::string>(); auto devTsvFile = variables["devTSV"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); auto batchSize = variables["batchSize"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? false : true; auto machineContent = variables["machine"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>(); auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto trainStrategy = parseTrainStrategy(trainStrategyStr); torch::globalContext().setBenchmarkCuDNN(true); if (!machineContent.empty()) { std::FILE * file = fopen(machinePath.c_str(), "w"); if (!file) util::error(fmt::format("can't open file '{}'\n", machinePath.c_str())); fmt::print(file, "{}", machineContent); std::fclose(file); } fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str()); try { ReadingMachine machine(machinePath.string()); BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); Trainer trainer(machine, batchSize, lossFunction); Decoder decoder(machine); if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty()) { machine.loadDicts(); machine.getClassifier()->getNN()->registerEmbeddings(); machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); } float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); auto trainInfos = machinePath.parent_path() / "train.info"; int currentEpoch = 0; if (std::filesystem::exists(trainInfos)) { std::FILE * f = std::fopen(trainInfos.c_str(), "r"); char buffer[1024]; while (!std::feof(f)) { if (buffer != std::fgets(buffer, 1024, f)) break; bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED"; float devScoreMean = std::stof(util::split(buffer, '\t').back()); if (saved) bestDevScore = devScoreMean; currentEpoch++; } std::fclose(f); } auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; if (std::filesystem::exists(trainInfos)) { machine.getClassifier()->resetOptimizer(); machine.getClassifier()->loadOptimizer(optimizerCheckpoint); } for (; currentEpoch < nbEpoch; currentEpoch++) { bool saved = false; if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples)) { for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train")) if (entry.is_regular_file()) std::filesystem::remove(entry.path()); if (!computeDevScore) for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev")) if (entry.is_regular_file()) std::filesystem::remove(entry.path()); } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); if (!computeDevScore) { machine.setDictsState(Dict::State::Closed); trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); } } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) { if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters)) { machine.resetClassifier(); machine.trainMode(currentEpoch == 0); machine.getClassifier()->getNN()->registerEmbeddings(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); } machine.getClassifier()->resetOptimizer(); } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save)) { saved = true; } trainer.makeDataLoader(modelPath/"examples/train"); if (!computeDevScore) trainer.makeDevDataLoader(modelPath/"examples/dev"); float loss = trainer.epoch(printAdvancement); if (debug) fmt::print(stderr, "Decoding dev :\n"); std::vector<std::pair<float,std::string>> devScores; if (computeDevScore) { BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); decoder.evaluate(devConfig, modelPath, devTsvFile); devScores = decoder.getF1Scores(machine.getPredicted()); } else { float devLoss = trainer.evalOnDev(printAdvancement); devScores.emplace_back(std::make_pair(devLoss, "Loss")); } std::string devScoresStr = ""; float devScoreMean = 0; for (auto & score : devScores) { if (computeDevScore) devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : ""); else devScoresStr += fmt::format("{}({:6.4f}{}),", score.second, 100.0*score.first, computeDevScore ? "%" : ""); devScoreMean += score.first; } if (!devScoresStr.empty()) devScoresStr.pop_back(); devScoreMean /= devScores.size(); if (computeDevScore) saved = saved or devScoreMean >= bestDevScore; else saved = saved or devScoreMean <= bestDevScore; if (saved) { bestDevScore = devScoreMean; machine.saveBest(); } machine.saveLast(); machine.getClassifier()->saveOptimizer(optimizerCheckpoint); if (printAdvancement) fmt::print(stderr, "\r{:80}\r", ""); std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : ""); fmt::print(stderr, "{}\n", iterStr); std::FILE * f = std::fopen(trainInfos.c_str(), "a"); fmt::print(f, "{}\t{}\n", iterStr, devScoreMean); std::fclose(f); } } catch(std::exception & e) {util::error(e);} return 0; } MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv) { }