// Implementation of the Trainer class declared in Trainer.hpp.
#include "Trainer.hpp"
#include "SubConfig.hpp"
/// Construct a trainer bound to the given reading machine.
///
/// @param machine Reading machine stored in the `machine` member; the other
///                Trainer methods reach the classifier, transition set and
///                strategy through it.
Trainer::Trainer(ReadingMachine & machine)
  : machine(machine)
{}
/// Build the training data loader and (re)create the optimizer.
///
/// Examples are extracted from `config` through extractExamples(); their
/// count is stored in `nbExamples`. The examples are wrapped in a stacked
/// Dataset served in batches of `batchSize`, and `optimizer` is reset to a
/// new Adam instance over the parameters of the classifier's network.
///
/// @param config Configuration the examples are extracted from; it is
///               modified by the extraction.
/// @param debug  Forwarded to extractExamples().
void Trainer::createDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  // Remembered so that processDataset() can report progress as a percentage.
  nbExamples = classes.size();

  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
}
/// Build the development data loader from examples extracted out of `config`.
///
/// Unlike createDataset(), this neither updates `nbExamples` nor touches the
/// optimizer; it only fills `devDataLoader`.
///
/// @param config Configuration the examples are extracted from.
/// @param debug  Forwarded to extractExamples().
void Trainer::createDevDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> devContexts;
  std::vector<torch::Tensor> devClasses;
  extractExamples(config, debug, devContexts, devClasses);

  // Same loader settings as the training loader: batches of `batchSize`,
  // zero workers and zero max jobs.
  torch::data::DataLoaderOptions loaderOptions(batchSize);
  loaderOptions.workers(0).max_jobs(0);

  devDataLoader = torch::data::make_data_loader(
    Dataset(devContexts, devClasses).map(torch::data::transforms::Stack<>()),
    loaderOptions);
}
/// Walk `config` transition by transition and record one training example per step.
///
/// At every step the transition returned by getBestAppliableTransition() is
/// used as the gold label: the network's context for the current
/// configuration is appended to `contexts`, the transition's index is
/// appended to `classes`, then the transition is applied and the strategy
/// decides the next state and word-index movement. The walk stops when the
/// strategy returns Strategy::endMovement.
///
/// @param config   Configuration to walk; it is modified in place.
/// @param debug    When true, prints each (transition, new state, movement) to stderr.
/// @param contexts Output: one 1-D kLong tensor per step, on NeuralNetworkImpl::device.
/// @param classes  Output: one single-element kLong tensor per step holding the gold index.
///
/// Errors (no applicable transition, context extraction failure, impossible
/// word-index move) are reported through util::myThrow.
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
{
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      // Without a transition there is no gold label and nothing to apply.
      // NOTE(review): assumes util::myThrow never returns — confirm.
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    try
    {
      auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
      // from_blob only views `context`'s buffer; clone() takes a copy so the
      // tensor stays valid once `context` goes out of scope.
      contexts.emplace_back(torch::from_blob(context.data(), {static_cast<long>(context.size())}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
    }
    catch (const std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    const int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);

    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }
}
/// Run one pass over `loader`, optionally training, and return the accumulated loss.
///
/// For each batch the network's prediction is scored with cross-entropy
/// against the batch targets. When `train` is true, gradients are enabled
/// and the optimizer takes one step per batch; otherwise gradients are
/// disabled for the whole pass.
///
/// @param loader           Data loader to iterate over (dereferenced as `*loader`).
/// @param train            True to back-propagate and update parameters.
/// @param printAdvancement True to print progress, loss and speed to stderr.
/// @return Sum of the per-batch loss values over the whole pass.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
{
  // NOTE(review): the throttle value is an assumption — progress is
  // refreshed once at least this many examples have been processed since the
  // previous refresh.
  constexpr int printInterval = 50;

  // Examples seen since the last progress refresh (used for the speed figure).
  int nbExamplesProcessed = 0;
  float totalLoss = 0.0;
  // Loss accumulated since the last progress refresh.
  float lossSoFar = 0.0;
  int currentBatchNumber = 0;

  torch::AutoGradMode useGrad(train);

  auto lossFct = torch::nn::CrossEntropyLoss();

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      optimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);
    // A single-example batch may come back without a batch dimension.
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    // squeeze() turns a single-example target into a 0-dim tensor; force 1-D.
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
    try
    {
      // Read the scalar once and reuse it for both accumulators.
      const float lossValue = loss.item<float>();
      totalLoss += lossValue;
      lossSoFar += lossValue;
    }
    catch (const std::exception & e)
    {
      util::myThrow(e.what());
    }

    if (train)
    {
      loss.backward();
      optimizer->step();
    }

    ++currentBatchNumber;

    if (printAdvancement)
    {
      nbExamplesProcessed += static_cast<int>(labels.size(0));

      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        const double seconds = std::chrono::duration<double, std::milli>(actualTime - pastTime).count() / 1000.0;
        pastTime = actualTime;

        // Guard against a zero elapsed time reported by the clock.
        const int speed = seconds > 0.0 ? static_cast<int>(nbExamplesProcessed / seconds) : 0;

        if (train)
          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, speed);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, speed);

        lossSoFar = 0.0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss;
}
/// Run one training pass over the training data loader.
///
/// @param printAdvancement Forwarded to processDataset().
/// @return Whatever processDataset() returns for the pass.
float Trainer::epoch(bool printAdvancement)
{
  constexpr bool trainingMode = true;
  return processDataset(dataLoader, trainingMode, printAdvancement);
}
/// Run one evaluation pass (no training) over the development data loader.
///
/// @param printAdvancement Forwarded to processDataset().
/// @return Whatever processDataset() returns for the pass.
float Trainer::evalOnDev(bool printAdvancement)
{
  constexpr bool trainingMode = false;
  return processDataset(devDataLoader, trainingMode, printAdvancement);
}