Newer
Older
#include <boost/program_options.hpp>
#include <filesystem>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"
#include "NeuralNetwork.hpp"
namespace po = boost::program_options;
/// Build the description of every command-line option accepted by this program.
///
/// Options are declared in two groups, "Required" and "Optional", which are then
/// merged into a single description. That merged description is what
/// checkOptions() parses against and streams out when "help" is requested.
///
/// @return The combined options description (required group first, then optional).
po::options_description getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  // Options declared with ->required() : parsing fails at notify time if any is absent.
  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("mcd", po::value<std::string>()->required(),
      "Multi Column Description file that describes the input format")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  // Options that are either plain switches or carry a default value.
  // NOTE(review): main() also queries "debug" and "silent" through
  // variables.count(...), but neither is declared in this function -- confirm
  // whether they should be added here.
  po::options_description opt("Optional");
  opt.add_options()
    ("devScore", "Compute score on dev instead of loss (slower)")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}
/// Parse the command line against `od` and return the resulting option values.
///
/// Any std::exception raised while parsing or while notifying the variables map
/// is forwarded, as its what() message, to util::myThrow. When the "help" option
/// is present, the option description is printed on stderr and the process exits
/// with status 0 before notification takes place.
///
/// @param od   Description of the accepted options.
/// @param argc Argument count, as received by main.
/// @param argv Argument vector, as received by main.
/// @return The populated variables map.
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map parsed;

  try
  {
    const auto rawOptions = po::parse_command_line(argc, argv, od);
    po::store(rawOptions, parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  // Help is handled before notify() so that it works even when other options are missing.
  const bool helpRequested = parsed.count("help") != 0;
  if (helpRequested)
  {
    std::stringstream usage;
    usage << od;
    fmt::print(stderr, "{}\n", usage.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  return parsed;
}
int main(int argc, char * argv[])
{
auto od = getOptionsDescription();
auto variables = checkOptions(od, argc, argv);
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / "machine.rm";
auto mcdFile = variables["mcd"].as<std::string>();
auto trainTsvFile = variables["trainTSV"].as<std::string>();
auto trainRawFile = variables["trainTXT"].as<std::string>();
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig);
Trainer trainer(machine);
if (!computeDevScore)
{
SubConfig devConfig(devGoldConfig);
trainer.createDevDataset(devConfig, debug);
}
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
machine.getStrategy().reset();
std::vector<std::pair<float,std::string>> devScores;
if (computeDevScore)
{
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted());
}
else
{
float devLoss = trainer.evalOnDev(printAdvancement);
devScores.emplace_back(std::make_pair(devLoss, "Loss"));
}
std::string devScoresStr = "";
float devScoreMean = 0;
for (auto & score : devScores)
{
if (computeDevScore)
devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
else
devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
}
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoreMean /= devScores.size();
bool saved = devScoreMean > bestDevScore;
if (!computeDevScore)
saved = devScoreMean < bestDevScore;
bestDevScore = devScoreMean;
fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
}
catch(std::exception & e) {util::error(e);}