#include "Trainer.hpp"
#include "SubConfig.hpp"
#include <execution>
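
// A Trainer drives training of a ReadingMachine : it extracts (context, gold decision)
// examples from gold configurations and fits the machine's classifiers on them.
// batchSize is reused by both data loaders below.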
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
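
// Build the data loader for the training examples previously extracted into dir.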
void Trainer::makeDataLoader(std::filesystem::path dir)
{
  trainDataset.reset(new Dataset(dir));
  dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
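
// Same as makeDataLoader, but for the development set.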
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));
  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
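
// Wrap each gold configuration in a SubConfig, extract training examples from them with
// the machine in decoding mode, then save the dictionaries updated during extraction.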
void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
{
  std::vector<SubConfig> configs;
  for (auto & goldConfig : goldConfigs)
    configs.emplace_back(goldConfig, goldConfig.getNbLines());

  machine.trainMode(false);

  extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold, memcheck);

  machine.saveDicts();
}
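
// Decode every configuration with gradients disabled and write the resulting
// (context, gold class) examples to dir, sharded per state. A marker file named
// "extracted.<epoch>.<dynamicOracle>" makes the whole extraction idempotent.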
void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
{
  torch::AutoGradMode useGrad(false);

  int maxNbExamplesPerFile = 50000;
  std::unordered_map<std::string, Examples> examplesPerState;
  std::mutex examplesMutex;

  std::filesystem::create_directories(dir);

  // If the marker file for this (epoch, oracle mode) already exists, extraction was done.
  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
  if (std::filesystem::exists(currentEpochAllExtractedFile))
    return;

  fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

  std::atomic<int> totalNbExamples = 0;

  if (memcheck)
    fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());

  // Extraction runs on CPU; the machine is moved back to its preferred device afterwards.
  NeuralNetworkImpl::setDevice(torch::kCPU);
  machine.to(NeuralNetworkImpl::getDevice());
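
  // Process each configuration sequentially, collecting one training example per decision.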
  std::for_each(std::execution::seq, configs.begin(), configs.end(),
    [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, memcheck, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
    {
      config.addPredicted(machine.getPredicted());
      config.setStrategy(machine.getStrategyDefinition());
      config.setState(config.getStrategy().getInitialState());

      while (true)
      {
        if (debug)
          config.printForDebug(stderr);

        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

        auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
        config.setAppliableTransitions(appliableTransitions);
        torch::Tensor context;
        try
        {
          context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
        } catch (std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }
        Transition * transition = nullptr;
        auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
        Transition * goldTransition = goldTransitions[0];
        // In the parser state, break ties between equally good gold transitions at random.
        if (config.getState() == "parser")
          goldTransition = goldTransitions[std::rand()%goldTransitions.size()];

        int nbClasses = machine.getTransitionSet(config.getState()).size();
        float bestScore = -std::numeric_limits<float>::max();
        float entropy = 0.0;

        // Dynamic oracle : let the model explore its own predictions instead of always
        // following the gold path (never for the tokenizer and segmenter states).
        if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
        {
{
auto & classifier = *machine.getClassifier(config.getState());
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (score > bestScore and appliableTransitions[i])
bestScore = score;
}
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
candidates.emplace_back(i);
}
transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
for (auto & trans : goldTransitions)
if (trans == transition)
goldTransition = trans;
}
        else
        {
          transition = goldTransition;
        }

        if (!transition or !goldTransition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }
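
        // Build the gold targets for this decision.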
        std::vector<long> goldIndexes;
        bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);

        if (machine.getClassifier(config.getState())->isRegression())
        {
          // Regression transitions must look like "WRITESCORE object.index column"; the
          // gold value is the float read from that cell, stored as a scaled long.
          entropy = 0.0;
          auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
          auto splited = util::split(transition->getName(), ' ');
          if (splited.size() != 3 or splited[0] != "WRITESCORE")
            util::myThrow(errMessage);
          auto col = splited[2];
          splited = util::split(splited[1], '.');
          if (splited.size() != 2)
            util::myThrow(errMessage);
          auto object = Config::str2object(splited[0]);
          int index = std::stoi(splited[1]);
          float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
          goldIndexes.emplace_back(util::float2long(regressionTarget));
        }
        else
        {
          for (auto & t : goldTransitions)
            goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
        }
        if (!exampleIsBanned)
        {
          totalNbExamples += 1;
          if (totalNbExamples >= (int)safetyNbExamplesMax)
            util::error(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

          // examplesPerState is shared between configurations, hence the mutex.
          examplesMutex.lock();
          examplesPerState[config.getState()].addContext(context);
          examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
          examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
          examplesMutex.unlock();
        }

        config.setChosenActionScore(bestScore);
        transition->apply(config, entropy);
        config.addToHistory(transition->getName());

        // Ask the strategy where to go next; stop when it signals the end.
        auto movement = config.getStrategy().getMovement(config, transition->getName());
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
        if (movement == Strategy::endMovement)
          break;

        config.setState(movement.first);
        config.moveWordIndexRelaxed(movement.second);

        if (config.needsUpdate())
          config.update();
      } // End while true

      if (memcheck)
        fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
    }); // End for on configs
  // Flush any remaining buffered examples (threshold 0 forces the save).
  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

  NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
  machine.to(NeuralNetworkImpl::getDevice());

  // Create the marker file so extraction is skipped if this epoch is re-run.
  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  if (memcheck)
    fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}
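
// One pass over a dataset : forward each batch through the classifier of its state,
// accumulate the loss and, when train is true, backpropagate and step the optimizer.
// Returns the average loss per example.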
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);

  auto pastTime = std::chrono::high_resolution_clock::now();
  for (auto & batch : *loader)
  {
    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);

    if (train)
      machine.getClassifier(state)->getOptimizer().zero_grad();

    auto prediction = machine.getClassifier(state)->getNN()->forward(data, state);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    // Regression targets were stored as scaled longs; bring them back to float.
    if (machine.getClassifier(state)->isRegression())
    {
      labels = labels.to(torch::kFloat);
      labels /= util::float2longScale;
    }

    // Weight the loss by a learned per-state parameter s : multiplier * loss / exp(s) + s.
    auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state);
    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels)*(1.0/torch::exp(lossParameter)) + lossParameter;
    float lossAsFloat = 0.0;
    try
    {
      lossAsFloat = loss.item<float>();
    } catch (std::exception & e) {util::myThrow(e.what());}

    totalLoss += lossAsFloat;
    lossSoFar += lossAsFloat;

    if (train)
    {
      loss.backward();
      machine.getClassifier(state)->getOptimizer().step();
    }
    totalNbExamplesProcessed += labels.size(0);

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        auto speed = (int)(nbExamplesProcessed/seconds);
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format(lossSoFar/nbExamplesProcessed < 10.0 ? "{:6.2f}% loss={:<7.3f} speed={:<6}ex/s" : "{:6.2f}% loss={:<7.0f} speed={:<6}ex/s", progression, lossSoFar / nbExamplesProcessed, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}
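
// One full pass over the training set; returns the average training loss per example.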
float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}
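
// One evaluation pass over the dev set; returns the average dev loss per example.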
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
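
// Write the buffered (context, classes) pairs to a .tensor file once at least
// `threshold` new examples have been buffered; a threshold of 0 forces the save.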
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  if (currentExampleIndex-lastSavedIndex < (int)threshold)
    return;
  if (contexts.empty())
    return;

  int nbClasses = classes[0].size(0);
  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
  auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
  torch::save(tensorToSave, dir/filename);
  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}
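
// Buffer one feature context and advance the example counter.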
void Trainer::Examples::addContext(torch::Tensor & context)
{
  contexts.emplace_back(context);
  currentExampleIndex += 1;
}
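
// Convert the gold class indexes into the loss function's target representation, and
// replicate it until there is one target per buffered context.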
void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)
{
  auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);

  while (classes.size() < contexts.size())
    classes.emplace_back(gold);
}
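
// Map a training-script keyword to the corresponding TrainAction; throws on unknown input.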
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  if (s == "ExtractGold")
    return TrainAction::ExtractGold;
  if (s == "ExtractDynamic")
    return TrainAction::ExtractDynamic;
  if (s == "DeleteExamples")
    return TrainAction::DeleteExamples;
  if (s == "ResetOptimizer")
    return TrainAction::ResetOptimizer;
  if (s == "ResetParameters")
    return TrainAction::ResetParameters;
  if (s == "Save")
    return TrainAction::Save;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  return TrainAction::ExtractGold; // never reached : myThrow throws
}
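
// Replay the gold action sequence on a configuration, printing to stdout the input span
// and the gold transition indexes for every group of three sentences, and reporting to
// stderr the longest input and output sequences encountered.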
void Trainer::extractActionSequence(BaseConfig & config)
{
  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());

  int curSeq = 0;
  int curSeqStartIndex = -1;
  int curInputIndex = 0;
  int curInputSeqSize = 0;
  int curOutputSeqSize = 0;
  int maxInputSeqSize = 0;
  int maxOutputSeqSize = 0;
  bool newSent = true;
  std::vector<std::string> transitionsIndexes;
  while (true)
  {
    if (config.hasCharacter(0))
      curInputIndex = config.getCharacterIndex();
    else
      curInputIndex = config.getWordIndex();

    if (curSeqStartIndex == -1 or newSent)
    {
      newSent = false;
      curSeqStartIndex = curInputIndex;
    }

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);
    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
    Transition * transition = goldTransitions[0];

    if (machine.getClassifier(config.getState())->isRegression())
      util::myThrow("Regressions are not supported in extract action sequence mode");

    transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
    maxOutputSeqSize = std::max(maxOutputSeqSize, curOutputSeqSize++);
    curInputSeqSize = -curSeqStartIndex + curInputIndex;
    maxInputSeqSize = std::max(maxInputSeqSize, curInputSeqSize++);
    // Every third end of sentence, print the input span and the gold transition indexes.
    if (util::split(transition->getName(), ' ')[0] == "EOS")
      if (++curSeq % 3 == 0)
      {
        newSent = true;
        std::string inputSeq = "";
        for (int i = curSeqStartIndex; i <= curInputIndex; i++)
          inputSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
        fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? inputSeq : util::strip(inputSeq), util::join(" ", transitionsIndexes));
        curOutputSeqSize = 0;
        curInputSeqSize = 0;
        transitionsIndexes.clear();
      }
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);
  }
  // Flush the trailing partial sequence, if any.
  if (curSeqStartIndex != curInputIndex)
  {
    std::string inputSeq = "";
    for (int i = curSeqStartIndex; i <= curInputIndex; i++)
      inputSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
    fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? inputSeq : util::strip(inputSeq), util::join(" ", transitionsIndexes));
    curOutputSeqSize = 0;
    curInputSeqSize = 0;
    curSeqStartIndex = curInputIndex;
  }
fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
}