Newer
Older
#include "NeuralNetwork.hpp"
/// Build the description of every command-line option accepted by the trainer.
///
/// Options are grouped into "Required" (model directory, MCD file, training
/// TSV; each marked `required()` so boost reports them from `po::notify` when
/// missing) and "Optional". Both groups are merged into the returned
/// description, which checkOptions() parses argc/argv against.
///
/// The "debug" and "silent" switches are declared here because the training
/// code reads them through `variables.count("debug")` / `variables.count("silent")`.
/// An option that is not registered is rejected by `po::parse_command_line`,
/// so without these declarations the flags could never be set.
///
/// @return The combined options description.
po::options_description MacaonTrain::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("mcd", po::value<std::string>()->required(),
      "Multi Column Description file that describes the input format")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug", "Print debugging information on stderr")
    ("silent", "Don't print training and decoding progress")
    ("devScore", "Compute score on dev instead of loss (slower)")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);
  return desc;
}
/// Parse the stored command line (argc/argv) against @p od.
///
/// Parsing errors and errors raised by `po::notify` (for example a missing
/// option marked `required()`) are reported through util::myThrow.
/// If "--help" / "-h" is present, the option description is printed to stderr
/// and the process exits with status 0, before required options are checked.
///
/// @param od Description of the accepted options.
/// @return The parsed and notified variables map.
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
  po::variables_map parsed;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  // Help is handled before notify() so "--help" works without required options.
  if (parsed.count("help") != 0)
  {
    std::stringstream helpText;
    helpText << od;
    fmt::print(stderr, "{}\n", helpText.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  return parsed;
}
/// Feed the FORM and LEMMA values of @p config into every dictionary of @p rm.
///
/// For each column of interest that @p config provides (tested with
/// `config.has(col, 0, 0)`), every dictionary returned by `rm.getDicts()`
/// goes through the same steps:
///   1. `countOcc(true)` is called on it.
///   2. Every value of that column, over all lines and all
///      `Config::nbHypothesesMax` hypotheses, is passed to `getIndexOrInsert`.
///   3. `countOcc(false)` is called on it.
/// The countOcc toggling presumably makes these inserts count as occurrences;
/// confirm against the Dict implementation.
///
/// @param rm     Reading machine whose dictionaries are filled.
/// @param config Configuration providing the column values.
void MacaonTrain::fillDicts(ReadingMachine & rm, const Config & config)
{
  static const std::vector<std::string> interestingColumns{"FORM", "LEMMA"};

  for (auto & col : interestingColumns)
  {
    // Skip columns this corpus does not provide.
    if (!config.has(col, 0, 0))
      continue;

    for (auto & it : rm.getDicts())
    {
      it.second.countOcc(true);
      for (unsigned int j = 0; j < config.getNbLines(); j++)
        for (unsigned int k = 0; k < Config::nbHypothesesMax; k++)
          it.second.getIndexOrInsert(config.getConst(col, j, k));
      it.second.countOcc(false);
    }
  }
}
// NOTE(review): this region (down to the catch below) looks like the body of the
// trainer's entry point pasted without its surrounding structure. The function
// signature, the opening `try {`, the declarations of `trainer` and `decoder`,
// and the per-epoch loop that defines `i` and `loss` are not visible here.
// Read the parsed option values (`variables` is presumably the result of checkOptions()).
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / "machine.rm";
auto mcdFile = variables["mcd"].as<std::string>();
auto trainTsvFile = variables["trainTSV"].as<std::string>();
auto trainRawFile = variables["trainTXT"].as<std::string>();
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
// NOTE(review): "debug" and "silent" are not declared in getOptionsDescription() as
// shown in this file, so po::parse_command_line rejects them and both stay unset.
bool debug = variables.count("debug") == 0 ? false : true;
// Progress is printed unless debugging or "--silent" was requested.
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
// true: evaluate on dev with F1 scores (decoding); false: evaluate with dev loss.
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
// When scoring on dev and a raw-text dev file is given, the TSV argument is ""
// (the config is built from raw text only); otherwise the dev TSV is loaded.
// devTsvFile is still used below as the reference for decoder.evaluate().
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
// NOTE(review): the next two lines are web-UI residue (blame author / "committed")
// pasted into the source; they are not C++ and must be removed for this to compile.
Franck Dary
committed
fillDicts(machine, goldConfig);
// Loss-based dev evaluation needs a dev dataset built up front.
if (!computeDevScore)
{
SubConfig devConfig(devGoldConfig);
trainer.createDevDataset(devConfig, debug);
}
// Sentinel "worst" value: scores are maximised, losses are minimised.
// NOTE(review): std::numeric_limits<float>::min() is the smallest *positive* float,
// not the most negative one; std::numeric_limits<float>::lowest() is probably
// what was intended, otherwise a score of exactly 0 can never be "SAVED".
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
machine.getStrategy().reset();
std::vector<std::pair<float,std::string>> devScores;
if (computeDevScore)
{
// Decode a copy (presumably so devGoldConfig itself is left untouched), then
// evaluate the prediction against devTsvFile and collect per-column F1 scores.
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted());
}
else
{
float devLoss = trainer.evalOnDev(printAdvancement);
devScores.emplace_back(std::make_pair(devLoss, "Loss"));
}
// Build "name(value),name(value)" — percentages for F1 scores, raw value for loss.
std::string devScoresStr = "";
float devScoreMean = 0;
for (auto & score : devScores)
{
// NOTE(review): devScoreMean is never accumulated in this loop (e.g.
// `devScoreMean += score.first;`), so the mean below is always 0 — or NaN if
// devScores is empty — and the SAVED decision does not reflect the dev results.
if (computeDevScore)
devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
else
devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
}
// Drop the trailing comma.
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoreMean /= devScores.size();
// Higher is better for scores, lower is better for loss.
bool saved = devScoreMean > bestDevScore;
if (!computeDevScore)
saved = devScoreMean < bestDevScore;
// NOTE(review): bestDevScore is overwritten unconditionally; it should most likely
// only be updated when `saved` is true, otherwise "best" just means "previous".
bestDevScore = devScoreMean;
// NOTE(review): the two prints below emit the same epoch summary twice; they look
// like the old and new versions of one line from a diff. Presumably only the second
// (which first blanks the progress line with "\r{:80}\r") should remain.
fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
}
catch(std::exception & e) {util::error(e);}
MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv)
{
}