Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
void Trainer::makeDataLoader(std::filesystem::path dir)
{
trainDataset.reset(new Dataset(dir));
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
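// Builds the on-disk training examples for one epoch by running the oracle over the gold configuration, then saves the dicts populated along the way.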
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
SubConfig config(goldConfig, goldConfig.getNbLines());

extractExamples(config, debug, dir, epoch, dynamicOracle);

machine.saveDicts();
}
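// Runs the transition system over the configuration and dumps one (context, gold class) pair per step, grouped by state, into .tensor files under dir.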
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(config.getState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
if (std::filesystem::exists(currentEpochAllExtractedFile))
return;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
int totalNbExamples = 0;

// Main loop : one iteration per transition, until the strategy signals the end movement.
while (true)
{
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
std::vector<std::vector<long>> context;
try
{
context = machine.getClassifier()->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * transition = nullptr;
Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle);
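// Dynamic oracle : outside of the tokenizer and segmenter states, follow the classifier's best appliable prediction instead of the gold transition (here with probability 1.0, i.e. always); the gold transition still provides the training label below.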
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
transition = goldTransition;
}
if (!transition or !goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
}
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
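// One pass over the dataset : per batch, compute the cross-entropy loss and, when train is true, backpropagate and step the optimizer; returns the accumulated loss.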
torch::AutoGradMode useGrad(train);
machine.setDictsState(Dict::State::Closed);
float totalLoss = 0.0;
float lossSoFar = 0.0;
int totalNbExamplesProcessed = 0;
int nbExamplesProcessed = 0;
auto lossFct = torch::nn::CrossEntropyLoss();
auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *loader)
{
machine.getClassifier()->getOptimizer().zero_grad();
auto data = std::get<0>(batch);
auto labels = std::get<1>(batch);
auto state = std::get<2>(batch);
machine.getClassifier()->setState(state);
auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
try
{
totalLoss += loss.item<float>();
lossSoFar += loss.item<float>();
} catch(std::exception & e) {util::myThrow(e.what());}
if (train)
{
loss.backward();
machine.getClassifier()->getOptimizer().step();
}
totalNbExamplesProcessed += torch::numel(labels);
nbExamplesProcessed += torch::numel(labels);
if (printAdvancement)
{
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
auto speed = (int)(nbExamplesProcessed/seconds);
auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
if (train)
fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
else
fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
lossSoFar = 0.0;
nbExamplesProcessed = 0;
}
}

return totalLoss;
}
float Trainer::epoch(bool printAdvancement)
{
return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}
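// Typical call sequence (sketch; "machine", "goldConfig" and the example directory are assumed to be set up elsewhere) :
//   Trainer trainer(machine, 64);
//   trainer.createDataset(goldConfig, /*debug=*/false, "examples", /*epoch=*/0, /*dynamicOracle=*/false);
//   trainer.makeDataLoader("examples");
//   float loss = trainer.epoch(/*printAdvancement=*/true);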
float Trainer::evalOnDev(bool printAdvancement)
{
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
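// Flushes buffered (context, class) pairs to a .tensor file once at least threshold new examples have accumulated; a threshold of 0 forces the flush.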
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
if (currentExampleIndex-lastSavedIndex < (int)threshold)
return;
if (contexts.empty())
return;
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
for (auto & element : context)
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
currentExampleIndex += context.size();
}
void Trainer::Examples::addClass(int goldIndex)
{
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
while (classes.size() < contexts.size())
classes.emplace_back(gold);
}
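// Gold-oracle pass over a configuration : applies the best gold transition at each step so that the dicts see every training example; no tensors are written.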
void Trainer::fillDicts(SubConfig & config, bool debug)
{
torch::AutoGradMode useGrad(false);
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(config.getState());
while (true)
{
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
try
{
machine.getClassifier()->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);
if (!goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
goldTransition->apply(config);
config.addToHistory(goldTransition->getName());
auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
}
}
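// Parses a TrainAction name from the training strategy definition; throws on unknown names.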
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
if (s == "ExtractGold")
return TrainAction::ExtractGold;
if (s == "ExtractDynamic")
return TrainAction::ExtractDynamic;
if (s == "DeleteExamples")
return TrainAction::DeleteExamples;
if (s == "ResetOptimizer")
return TrainAction::ResetOptimizer;
if (s == "ResetParameters")
return TrainAction::ResetParameters;
if (s == "Save")
return TrainAction::Save;
util::myThrow(fmt::format("unknown TrainAction '{}'", s));
return TrainAction::ExtractGold;
}