#include "Trainer.hpp"
#include "SubConfig.hpp"
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
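// workers(0)/max_jobs(0): batches are assembled in the calling thread, with
// no background worker threads (the dataset tensors are already on disk/in memory).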
dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
devDataset.reset(new Dataset(dir));
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
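// Simulate the reading machine on the gold configuration and dump, for each
// state, the (feature context, gold transition index) pairs to .tensor files
// in 'dir'. A marker file 'extracted.<epoch>' makes the extraction resumable:
// if it already exists and no dynamic-oracle refresh is due, nothing is redone.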
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
machine.getStrategy().reset();
config.setState(machine.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
mustExtract = false;
if (!mustExtract)
return;
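// From the second epoch onward examples are re-extracted with a dynamic
// oracle: part of the trajectory follows the classifier's own predictions,
// so training configurations resemble those seen at prediction time.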
bool dynamicOracle = epoch != 0;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
for (auto & entry : std::filesystem::directory_iterator(dir))
if (entry.is_regular_file())
std::filesystem::remove(entry.path());
int totalNbExamples = 0;
while (true)
{
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
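// Extract the neural network's view (feature context) of the current configuration.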
std::vector<std::vector<long>> context;
try
{
context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * transition = nullptr;
Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
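// Choose the transition to apply: usually the gold one, but in dynamic-oracle
// mode (and outside the tokenizer/parser states) the classifier's best
// appliable prediction is followed with probability 0.8. The stored label is
// always the gold transition.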
if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "parser")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::lowest();
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
}
else
{
transition = goldTransition;
}
if (!transition or !goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
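// Record the example for the current state and flush this state's buffer to
// disk once it holds maxNbExamplesPerFile examples.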
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
}
for (auto & it : examplesPerState)
it.second.saveIfNeeded(it.first, dir, 0);
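// Touch the marker file so this epoch's extraction is not repeated.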
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

// One pass over the dataset behind 'loader'. In train mode the classifier's
// weights are updated after every batch; otherwise the loss is only
// accumulated. Returns the total loss of the pass.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
int nbExamplesProcessed = 0;
int totalNbExamplesProcessed = 0;
float totalLoss = 0.0;
float lossSoFar = 0.0;
torch::AutoGradMode useGrad(train);
machine.trainMode(train);
machine.setDictsState(Dict::State::Closed);
auto lossFct = torch::nn::CrossEntropyLoss();
auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *loader)
{
if (train)
machine.getClassifier()->getOptimizer().zero_grad();
auto data = std::get<0>(batch);
auto labels = std::get<1>(batch);
auto state = std::get<2>(batch);
machine.getClassifier()->setState(state);
auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
auto loss = lossFct(prediction, labels);
try
{
totalLoss += loss.item<float>();
lossSoFar += loss.item<float>();
} catch(std::exception & e) {util::myThrow(e.what());}
if (train)
{
loss.backward();
machine.getClassifier()->getOptimizer().step();
}
totalNbExamplesProcessed += torch::numel(labels);
nbExamplesProcessed += torch::numel(labels);
if (printAdvancement)
{
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
auto speed = (int)(nbExamplesProcessed/seconds);
auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
if (train)
fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
else
fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
lossSoFar = 0.0;
nbExamplesProcessed = 0;
}
}
return totalLoss;
}
float Trainer::epoch(bool printAdvancement)
{
return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}
float Trainer::evalOnDev(bool printAdvancement)
{
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
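// Append the buffered examples of 'state' to a file named
// '<state>_<first>-<last>.tensor' once at least 'threshold' examples have
// accumulated (threshold 0 forces a flush of whatever remains).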
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
{
if (currentExampleIndex-lastSavedIndex < (int)threshold)
return;
if (contexts.empty())
return;
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
for (auto & element : context)
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
currentExampleIndex += context.size();
}
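// Give the gold class to every context added since the last call, so that
// contexts and classes stay aligned one-to-one.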
void Trainer::Examples::addClass(int goldIndex)
{
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
while (classes.size() < contexts.size())
classes.emplace_back(gold);
}
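// Populate the machine's dictionaries by replaying the gold trajectory with
// occurrence counting enabled, so every string seen in training gets an entry.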
void Trainer::fillDicts(BaseConfig & goldConfig)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
for (auto & it : machine.getDicts())
it.second.countOcc(true);
machine.trainMode(false);
machine.setDictsState(Dict::State::Open);
fillDicts(config);
for (auto & it : machine.getDicts())
it.second.countOcc(false);
}
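// Replay the gold oracle trajectory over 'config' (no learning involved) so
// that every extracted context passes through the open dictionaries.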
void Trainer::fillDicts(SubConfig & config)
{
torch::AutoGradMode useGrad(false);
config.addPredicted(machine.getPredicted());
machine.getStrategy().reset();
config.setState(machine.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
while (true)
{
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
try
{
machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!goldTransition)
{
config.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
goldTransition->apply(config);
config.addToHistory(goldTransition->getName());
auto movement = machine.getStrategy().getMovement(config, goldTransition->getName());
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
machine.getClassifier()->setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (config.needsUpdate())
config.update();
}
}