// NOTE(review): the lines that stood here ("Newer", "Older", author/commit labels)
// were git-blame web-UI artifacts from a scrape, not source code.
/// @brief Bind the trainer to the machine it will train and fix the mini-batch size.
/// All real work is done by the init-list; the body is intentionally empty.
/// NOTE(review): the body braces were lost to blame-scrape artifacts in the original;
/// an empty body is the only form consistent with the pure member-init-list above — confirm against upstream.
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
/// @brief Build (or refresh) the training dataset for this epoch and wire up its data loader.
/// @param goldConfig gold configuration the examples are extracted from.
/// @param debug when true, extraction prints per-transition traces to stderr.
/// @param dir directory where the example tensor files are written/read.
/// @param epoch current epoch number (0 = static oracle extraction).
/// @param dynamicOracleInterval re-extract with the dynamic oracle every N epochs (-1 = never).
/// NOTE(review): the original's braces were lost to blame-scrape artifacts; the statement
/// sequence below is exactly the one visible in the scraped source.
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
  SubConfig config(goldConfig, goldConfig.getNbLines());
  // Inference mode while extracting examples: no dropout etc., dicts frozen.
  machine.trainMode(false);
  machine.setDictsState(Dict::State::Closed);
  extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
  trainDataset.reset(new Dataset(dir));
  // Single-threaded loader (workers(0)): Dataset reads tensor files lazily from `dir`.
  dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
/// @brief Build (or refresh) the dev dataset for this epoch and wire up its data loader.
/// Mirrors createDataset but fills devDataset/devDataLoader instead.
/// @param goldConfig gold configuration the examples are extracted from.
/// @param debug when true, extraction prints per-transition traces to stderr.
/// @param dir directory where the example tensor files are written/read.
/// @param epoch current epoch number (0 = static oracle extraction).
/// @param dynamicOracleInterval re-extract with the dynamic oracle every N epochs (-1 = never).
/// NOTE(review): braces were lost to blame-scrape artifacts; the scraped source showed no
/// machine.trainMode(false) call here (unlike createDataset) — an artifact line may hide one; confirm upstream.
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
  SubConfig config(goldConfig, goldConfig.getNbLines());
  machine.setDictsState(Dict::State::Closed);
  extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
  devDataset.reset(new Dataset(dir));
  // Single-threaded loader (workers(0)): Dataset reads tensor files lazily from `dir`.
  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
// Extracts (context, gold-class) training examples by walking `config` with the strategy,
// writing them in chunks of `maxNbExamplesPerFile` per state to `dir`, and touching a
// marker file "extracted.<epoch>" when done. From epoch 1 on, with probability 0.8 the
// *predicted* transition is followed instead of the gold one (dynamic oracle), except in
// the tokenizer/segmenter states.
// NOTE(review): this block is corrupted by blame-scrape artifacts ("Franck Dary"/"committed"
// lines) and is missing structural code (the main while-loop opener, an else-body, a guard
// `if`, and the loop/close braces). Code is left byte-identical; only comments are added.
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
// Examples are extracted without gradients: pure inference over the classifier.
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
// One Examples accumulator per machine state (e.g. tagger, parser, ...).
std::map<std::string, Examples> examplesPerState;
Franck Dary
committed
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
Franck Dary
committed
// Start the walk from the strategy's initial state.
machine.getStrategy().reset();
config.setState(machine.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
Franck Dary
committed
// Skip extraction entirely if this epoch's marker file already exists, or if this is not
// a dynamic-oracle refresh epoch (epoch > 0 and not a multiple of dynamicOracleInterval).
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
if (!mustExtract)
return;
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
// Wipe any stale example files from a previous extraction before writing new ones.
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
int totalNbExamples = 0;
// NOTE(review): the enclosing `while (true)` loop opener appears to have been lost to the
// scrape here — the lines below through the endMovement break read as one loop iteration.
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
// Feature extraction for the current configuration; one row per example variant.
std::vector<std::vector<long>> context;
try
{
context = machine.getClassifier()->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Franck Dary
committed
Transition * transition = nullptr;
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
Franck Dary
committed
// Dynamic oracle: 80% of the time follow the classifier's own prediction instead of the
// gold transition, but never inside the tokenizer or segmenter states.
if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "segmenter")
Franck Dary
committed
{
// from_blob borrows the vector's memory, hence the clone before moving to the device.
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
// numeric_limits<float>::min() is the smallest *positive* float, but the
// `chosenTransition == -1` guard below makes the initial value irrelevant.
float bestScore = std::numeric_limits<float>::min();
// Argmax over scores, restricted to transitions appliable in this configuration.
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
// NOTE(review): the else-body was lost to the scrape; presumably `transition = goldTransition;` — confirm upstream.
Franck Dary
committed
}
// NOTE(review): the guard condition for the error block below (likely `if (!transition or !goldTransition)`) was lost to the scrape.
Franck Dary
committed
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
// The stored class is always the gold transition's index, even when the dynamic oracle
// chose a different transition to advance the configuration.
int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
Franck Dary
committed
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex);
// Spill to disk once this state's buffer exceeds maxNbExamplesPerFile.
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
// Advance the configuration with the chosen (gold or predicted) transition.
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
// Final flush: threshold 0 forces every remaining buffered example to disk.
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0);
Franck Dary
committed
// Touch the marker file so the next call for this epoch can skip re-extraction.
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
// Runs one full pass of `loader` through the classifier, accumulating cross-entropy loss;
// when `train` is true it also backpropagates and steps the optimizer. Prints a progress
// line ("training"/"eval on dev") with percentage, loss-so-far and examples/second.
// NOTE(review): this block is corrupted by blame-scrape artifacts — the opening brace,
// the declarations of totalLoss/lossSoFar/totalNbExamplesProcessed/nbExamplesProcessed,
// the loop braces, the printAdvancement conditional and the final `return` are missing.
// Code is left byte-identical; only comments are added.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
// Gradients only when training; eval passes run under no-grad.
torch::AutoGradMode useGrad(train);
machine.setDictsState(Dict::State::Closed);
Franck Dary
committed
auto lossFct = torch::nn::CrossEntropyLoss();
auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *loader)
Franck Dary
committed
machine.getClassifier()->getOptimizer().zero_grad();
// Each batch is a (data, labels, state) tuple produced by the Dataset.
auto data = std::get<0>(batch);
auto labels = std::get<1>(batch);
auto state = std::get<2>(batch);
machine.getClassifier()->setState(state);
auto prediction = machine.getClassifier()->getNN()(data);
// Normalize shapes: ensure a batch dimension on prediction and a 1-D label tensor.
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
try
{
totalLoss += loss.item<float>();
lossSoFar += loss.item<float>();
} catch(std::exception & e) {util::myThrow(e.what());}
if (train)
{
loss.backward();
Franck Dary
committed
machine.getClassifier()->getOptimizer().step();
// Throughput bookkeeping for the progress line below.
totalNbExamplesProcessed += torch::numel(labels);
nbExamplesProcessed += torch::numel(labels);
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
auto speed = (int)(nbExamplesProcessed/seconds);
auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
// "\r{:80}\r" blanks the current terminal line before rewriting the status.
// NOTE(review): the train/eval branch selecting between these two prints was lost to the scrape.
fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
/// @brief Run one training epoch over the whole training dataset.
/// @param printAdvancement forwarded to processDataset's progress display.
/// @return the loss accumulated by processDataset over the epoch.
float Trainer::epoch(bool printAdvancement)
{
return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}
/// @brief Run one evaluation pass over the whole dev dataset (no gradient updates).
/// @param printAdvancement forwarded to processDataset's progress display.
/// @return the loss accumulated by processDataset over the dev set.
/// NOTE(review): the closing brace was replaced by a run of bare line numbers in the
/// scraped original (blame-UI artifact); it is restored here.
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
/// @brief Flush buffered (context, class) pairs to a tensor file once at least
/// `threshold` new examples have accumulated since the last save.
/// @param state machine state the examples belong to (used in the file name).
/// @param dir output directory.
/// @param threshold minimum number of unsaved examples before flushing (0 forces a flush
/// of whatever is buffered).
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
{
  // Nothing to do while below the batch size, or when the buffers are empty.
  if (currentExampleIndex - lastSavedIndex < (int)threshold or contexts.empty())
    return;
  // One row per example: context features concatenated with the gold class column.
  auto rows = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
  // File name encodes the state and the inclusive example-index range it covers.
  auto outName = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
  torch::save(rows, dir/outName);
  // Reset the buffer; indices keep growing so the next file continues the range.
  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}
/// @brief Append each feature vector of `context` to the buffer as an owning 1-D Long tensor.
/// @param context one feature vector per example variant for the current configuration.
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
  for (auto & features : context)
  {
    // from_blob only borrows the vector's memory, so clone() to take ownership
    // before the vector goes out of scope.
    auto borrowed = torch::from_blob(features.data(), {(long)features.size()}, torch::TensorOptions(torch::kLong));
    contexts.emplace_back(borrowed.clone());
  }
  currentExampleIndex += context.size();
}
/// @brief Record `goldIndex` as the gold class for every context added since the last call.
/// Pads `classes` until it is as long as `contexts`, so the two buffers stay aligned
/// one-to-one for torch::stack in saveIfNeeded.
void Trainer::Examples::addClass(int goldIndex)
{
  auto label = torch::zeros(1, torch::TensorOptions(torch::kLong));
  label[0] = goldIndex;
  while (classes.size() < contexts.size())
    classes.emplace_back(label);
}
// Prepares a fresh SubConfig over the gold data and opens the machine's dicts so that
// a subsequent gold walk can register new vocabulary entries.
// NOTE(review): blame-scrape artifact lines inside the body may hide the call that
// actually performs the walk (presumably the gold-walk helper below) — confirm upstream.
void Trainer::fillDicts(BaseConfig & goldConfig, bool debug)
Franck Dary
committed
{
SubConfig config(goldConfig, goldConfig.getNbLines());
Franck Dary
committed
// Inference mode, but dicts are Open: unseen tokens get added instead of mapped to unknown.
machine.trainMode(false);
machine.setDictsState(Dict::State::Open);
Franck Dary
committed
Franck Dary
committed
}
// Gold walk over a configuration: repeatedly extracts the context (to populate the open
// dicts), applies the best gold transition, and follows the strategy's movement until
// the end movement is reached. No examples are stored and no gradients are used.
// NOTE(review): the function signature line was lost to blame-scrape artifacts —
// presumably a fillDicts overload taking (SubConfig & config, bool debug); confirm upstream.
Franck Dary
committed
{
torch::AutoGradMode useGrad(false);
config.addPredicted(machine.getPredicted());
// Start the walk from the strategy's initial state.
machine.getStrategy().reset();
config.setState(machine.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
while (true)
{
Franck Dary
committed
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
Franck Dary
committed
// The extracted context is discarded: the call's side effect (feeding the open dicts)
// is what matters here.
try
{
machine.getClassifier()->getNN()->extractContext(config);
Franck Dary
committed
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
// Always follow the gold transition during the dict-filling walk.
goldTransition->apply(config);
config.addToHistory(goldTransition->getName());
auto movement = machine.getStrategy().getMovement(config, goldTransition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
Franck Dary
committed
if (movement == Strategy::endMovement)
break;
// Move to the next state and shift the word index as the strategy dictates.
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
}
}