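// MacaonTrain.cpp
// Command-line entry point for training a Macaon reading machine: parses
// options, builds gold configurations from a CoNLL-U corpus, and runs the
// epoch-based training loop below.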
#include "MacaonTrain.hpp"
#include <filesystem>
#include <execution>
#include <algorithm>
#include <limits>
#include <cstdio>
#include <cstdlib>
#include "util.hpp"
#include "NeuralNetwork.hpp"
#include "WordEmbeddings.hpp"
namespace po = boost::program_options;
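// Builds the description of all command-line arguments (required and
// optional), used both for parsing and for the --help message.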
po::options_description MacaonTrain::getOptionsDescription()
{
po::options_description desc("Command-Line Arguments ");
po::options_description req("Required");
req.add_options()
("model", po::value<std::string>()->required(),
"Directory containing the machine file to train")
("trainTSV", po::value<std::string>()->required(),
"TSV file of the training corpus, in CONLLU format");
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("memcheck", "Regularly print memory usage on stderr")
("devScore", "Compute score on dev instead of loss (slower)")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format")
("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""),
"TSV file of the development corpus, in CONLLU format")
("devTXT", po::value<std::string>()->default_value(""),
"Raw text file of the development corpus")
("nbEpochs,n", po::value<int>()->default_value(5),
"Number of training epochs")
("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch")
("explorationThreshold", po::value<float>()->default_value(0.1),
"Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
("machine", po::value<std::string>()->default_value(""),
"Reading machine file content")
("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
"Description of what should happen during training")
("seed", po::value<int>()->default_value(100),
"Number of examples per batch")
("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
"Max norm for the embeddings")
("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
("lineByLine", "Process the TXT input as being one different text per line.")
("help,h", "Produce this help message")
("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
desc.add(req).add(opt);
return desc;
}
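// Example invocation (a sketch: the binary name and paths are hypothetical,
// the flags are the ones defined above):
//   macaon_train --model models/fr --trainTSV corpus/train.conllu \
//     --devTSV corpus/dev.conllu --nbEpochs 10 --batchSize 64 --devScore
//
// Parses argv against the given description; prints the help message and
// exits on --help, and reports errors via util::myThrow otherwise.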
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
po::variables_map vm;
try {po::store(po::parse_command_line(argc, argv, od), vm);}
catch(std::exception & e) {util::myThrow(e.what());}
if (vm.count("help"))
{
std::stringstream ss;
ss << od;
fmt::print(stderr, "{}\n", ss.str());
exit(0);
}
try {po::notify(vm);}
catch(std::exception& e) {util::myThrow(e.what());}
return vm;
}
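// Parses a training-strategy string: colon-separated groups, each of the
// form "epoch,Action[,Action...]", mapping an epoch number to the set of
// actions applied when that epoch starts. For example, with the action
// names used in the training loop below:
//   "0,ExtractGold,ResetParameters:5,ExtractDynamic,ResetOptimizer"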
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
Trainer::TrainStrategy ts;
try
{
auto splited = util::split(s, ':');
for (auto & ss : splited)
{
auto elements = util::split(ss, ',');
int epoch = std::stoi(elements[0]);
for (unsigned int i = 1; i < elements.size(); i++)
ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
}
} catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}
return ts;
}
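// Multiplicative learning-rate decay helper: rescales a parameter group's
// learning rate by (1 - rate). Note that, as written, the rescaling is
// applied once per parameter with a defined gradient in the group.
// A minimal usage sketch, assuming the default Adam template arguments:
//   decay(myAdamOptimizer, 0.05); // lr <- lr * 0.95
// (myAdamOptimizer is a placeholder name for a torch::optim::Adam instance.)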
template
<
typename Optimizer = torch::optim::Adam,
typename OptimizerOptions = torch::optim::AdamOptions
>
inline auto decay(Optimizer &optimizer, double rate) -> void
{
for (auto &group : optimizer.param_groups())
{
for (auto & param : group.params())
{
if (!param.grad().defined())
continue;
auto &options = static_cast<OptimizerOptions &>(group.options());
options.lr(options.lr() * (1.0 - rate));
}
}
}
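// Training entry point: reads options, prepares the corpora and the reading
// machine, then runs the training loop (extraction, epochs, dev evaluation,
// checkpointing).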
int MacaonTrain::main()
{
auto od = getOptionsDescription();
auto variables = checkOptions(od);
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / "machine.rm";
auto mcd = variables["mcd"].as<std::string>();
auto trainTsvFile = variables["trainTSV"].as<std::string>();
auto trainRawFile = variables["trainTXT"].as<std::string>();
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>();
bool debug = variables.count("debug") != 0;
bool memcheck = variables.count("memcheck") != 0;
bool printAdvancement = !debug && variables.count("silent") == 0;
bool computeDevScore = variables.count("devScore") != 0;
auto machineContent = variables["machine"].as<std::string>();
auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto seed = variables["seed"].as<int>();
bool oracleMode = variables.count("oracleMode") != 0;
bool lineByLine = variables.count("lineByLine") != 0;
WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
std::srand(seed);
torch::manual_seed(seed);
auto trainStrategy = parseTrainStrategy(trainStrategyStr);
torch::globalContext().setBenchmarkCuDNN(true);
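// If a machine description was given on the command line, write it to the
// model directory so the ReadingMachine below is built from it.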
if (!machineContent.empty())
{
std::FILE * file = std::fopen(machinePath.c_str(), "w");
if (!file)
util::error(fmt::format("can't open file '{}'\n", machinePath.c_str()));
fmt::print(file, "{}", machineContent);
std::fclose(file);
}
fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::getDevice().str());
try
{
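// Load the training and development corpora (TSV in CoNLL-U format and/or
// raw text).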
std::vector<std::vector<std::string>> trainTsv, devTsv, noTsv;
if (!trainTsvFile.empty())
trainTsv = util::readTSV(trainTsvFile);
if (!devTsvFile.empty())
devTsv = util::readTSV(devTsvFile);
ReadingMachine machine(machinePath.string(), true);
std::vector<util::utf8string> trainRawInputs;
if (!trainRawFile.empty())
trainRawInputs = util::readFileAsUtf8(trainRawFile, lineByLine);
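// Build the gold configurations: one per line in line-by-line mode,
// otherwise a single configuration for the whole training corpus.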
std::vector<BaseConfig> goldConfigs;
if (lineByLine)
{
if (trainRawInputs.size())
for (unsigned int i = 0; i < trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < trainTsv.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (trainRawInputs.size())
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[0], std::vector<int>());
else
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>());
}
std::vector<util::utf8string> devRawInputs;
if (!devRawFile.empty())
devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine);
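// Same construction for the development corpus.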
std::vector<BaseConfig> devGoldConfigs;
if (lineByLine)
{
if (devRawInputs.size())
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < devTsv.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
else
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
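// Oracle mode: instead of training, dump the gold transition sequence of
// the (first) gold configuration and exit.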
if (oracleMode)
{
//TODO : handle more than one
trainer.extractActionSequence(goldConfigs[0]);
exit(0);
}
float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
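// Resume support: if train.info exists from a previous run, replay it to
// recover the current epoch and the best dev score so far.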
auto trainInfos = machinePath.parent_path() / "train.info";
int currentEpoch = 0;
if (std::filesystem::exists(trainInfos))
{
std::FILE * f = std::fopen(trainInfos.c_str(), "r");
char buffer[1024];
while (buffer == std::fgets(buffer, 1024, f))
{
bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
float devScoreMean = std::stof(util::split(buffer, '\t').back());
if (saved)
bestDevScore = devScoreMean;
currentEpoch++;
}
std::fclose(f);
}
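// Main training loop: one iteration per remaining epoch.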
for (; currentEpoch < nbEpoch; currentEpoch++)
{
bool saved = false;
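// Apply this epoch's training strategy: optionally delete previously
// extracted examples, (re-)extract them (gold or dynamic oracle), and/or
// reset the classifiers' parameters and optimizers.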
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
{
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
if (!computeDevScore)
for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
{
machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
trainer.createDataset(goldConfigs, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold, memcheck);
if (!computeDevScore)
{
machine.setDictsState(Dict::State::Closed);
trainer.createDataset(devGoldConfigs, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold, memcheck);
}
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
{
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
{
machine.resetClassifiers();
machine.trainMode(currentEpoch == 0);
fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
}
machine.resetOptimizers();
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
{
saved = true;
}
trainer.makeDataLoader(modelPath/"examples/train");
if (!computeDevScore)
trainer.makeDevDataLoader(modelPath/"examples/dev");
float loss = trainer.epoch(printAdvancement);
if (debug)
fmt::print(stderr, "Decoding dev :\n");
std::vector<std::pair<float,std::string>> devScores;
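// Evaluate on dev: either decode the dev corpus and compute F1 scores
// (--devScore) or just compute the loss on the extracted dev examples.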
if (computeDevScore)
{
machine.setDictsState(Dict::State::Closed);
std::vector<BaseConfig> devConfigs;
if (lineByLine)
{
if (devRawInputs.size())
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devConfigs.emplace_back(mcd, noTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < devTsv.size(); i++)
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devConfigs.emplace_back(mcd, noTsv, devRawInputs[0], std::vector<int>());
else
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
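// With several dev configurations, decode sequentially on CPU, then move
// the machine back to the preferred device afterwards.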
if (devConfigs.size() > 1)
{
NeuralNetworkImpl::setDevice(torch::kCPU);
machine.to(NeuralNetworkImpl::getDevice());
std::for_each(std::execution::seq, devConfigs.begin(), devConfigs.end(),
[&decoder, debug, printAdvancement](BaseConfig & devConfig)
{
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
});
NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
machine.to(NeuralNetworkImpl::getDevice());
}
else
{
decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
}
std::vector<const Config *> devConfigsPtrs;
for (auto & devConfig : devConfigs)
devConfigsPtrs.emplace_back(&devConfig);
decoder.evaluate(devConfigsPtrs, modelPath, devTsvFile, machine.getPredicted());
devScores = decoder.getF1Scores(machine.getPredicted());
}
else
{
float devLoss = trainer.evalOnDev(printAdvancement);
devScores.emplace_back(devLoss, "Loss");
}
std::string devScoresStr = "";
float devScoreMean = 0;
int totalLen = 0;
std::string toAdd;
for (auto & score : devScores)
{
if (computeDevScore)
toAdd = fmt::format("{}({}{}),", score.second, util::shrink(fmt::format("{:.2f}", std::abs(score.first)),7), score.first >= 0 ? "%" : "");
else
toAdd = fmt::format("{}({}),", score.second, util::shrink(fmt::format("{}", score.first),7));
devScoreMean += score.first;
devScoresStr += toAdd;
totalLen += util::printedLength(score.second) + 3;
}
if (!devScoresStr.empty())
devScoresStr.pop_back();
devScoresStr = fmt::format("{:{}}", devScoresStr, totalLen+7*devScores.size());
devScoreMean /= devScores.size();
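// Model selection: a higher mean is better for dev scores, a lower mean is
// better for dev loss.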
if (computeDevScore)
saved = saved or devScoreMean >= bestDevScore;
else
saved = saved or devScoreMean <= bestDevScore;
if (saved)
{
bestDevScore = devScoreMean;
machine.saveBest();
}
machine.saveLast();
if (printAdvancement)
fmt::print(stderr, "\r{:80}\r", "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:7} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), util::shrink(fmt::format("{}",loss), 7), devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "{}\n", iterStr);
std::FILE * f = std::fopen(trainInfos.c_str(), "a");
fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
std::fclose(f);
if (memcheck)
fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
}
}
catch(std::exception & e) {util::error(e);}
return 0;
}
MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv)
{
}