#include "Trainer.hpp"
#include "SubConfig.hpp"

#include <memory>

Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

// Build the training dataset by replaying the oracle on `config`, then set up
// the batched data loader and the Adam optimizer over the classifier's
// parameters. Also records nbExamples (size of the training set).
void Trainer::createDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  nbExamples = classes.size();

  // Stack<> collates individual (context, class) examples into batch tensors.
  dataLoader = torch::data::make_data_loader(
      Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  // make_unique instead of reset(new ...) : exception-safe and idiomatic.
  optimizer = std::make_unique<torch::optim::Adam>(
      machine.getClassifier()->getNN()->parameters(),
      torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999));
}

// Same extraction as createDataset but for the dev set : no optimizer is
// created and nbExamples is left untouched (it refers to the training set).
void Trainer::createDevDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  devDataLoader = torch::data::make_data_loader(
      Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

// Replay the oracle on `config` : at each step, record the extracted feature
// context (as a Long tensor on the network's device) and the index of the
// gold transition, then apply that transition and follow the strategy's
// movement until Strategy::endMovement is reached.
// Throws (via util::myThrow) when no transition is appliable, when context
// extraction fails, or when the word index cannot be moved.
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
{
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    try
    {
      auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
      // from_blob does not own its memory : clone() before `context` dies.
      contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
    } catch (std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }
}

// Run one full pass over `loader`. When `train` is true, performs backprop
// and optimizer steps; otherwise gradients are disabled and the network is
// put in eval mode. Returns the loss summed over all batches. When
// `printAdvancement` is set, progress (loss and examples/s) is printed to
// stderr roughly every `printInterval` examples.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int currentBatchNumber = 0;

  torch::AutoGradMode useGrad(train);
  machine.getClassifier()->getNN()->train(train);

  auto lossFct = torch::nn::CrossEntropyLoss();
  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      optimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);
    // A batch of size 1 squeezes down to 1-D : restore the batch dimension.
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
    try
    {
      // item<float>() synchronizes with the device : fetch it once per batch
      // instead of twice.
      const float batchLoss = loss.item<float>();
      totalLoss += batchLoss;
      lossSoFar += batchLoss;
    } catch (std::exception & e)
    {
      util::myThrow(e.what());
    }

    if (train)
    {
      loss.backward();
      optimizer->step();
    }

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);
      ++currentBatchNumber;
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime - pastTime).count() / 1000.0;
        pastTime = actualTime;
        // Guard against a zero-length interval : casting an infinite speed
        // (float) to int is undefined behavior.
        if (seconds <= 0.0)
          seconds = 1e-9;
        if (train)
          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
        else
          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, (int)(nbExamplesProcessed/seconds));
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss;
}

// One training epoch over the training data loader; returns the summed loss.
float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement);
}

// One evaluation pass over the dev data loader; returns the summed loss.
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement);
}