#include "Trainer.hpp"
#include "SubConfig.hpp"

/// @brief Creates a trainer operating on the given reading machine.
/// @param machine Machine used to initialize the `machine` member; its
///                classifier, transition set, strategy and dictionaries are
///                queried by createDataset() and epoch().
Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

/// @brief Walks through `config` and records one training example per step,
///        then (re)creates the data loader and the optimizer.
///
/// Starting from the strategy's initial state, every iteration:
///  - asks the transition set for the best appliable transition,
///  - stores the context extracted by the classifier's network together with
///    the index of that transition (the gold class),
///  - applies the transition and records its name in the history,
///  - follows the movement returned by the strategy.
/// The loop stops when the strategy returns Strategy::endMovement.
///
/// @param config Configuration to process; it is modified in place.
/// @param debug  When true, the configuration and each
///               (transition, new state, movement) triple are printed to stderr.
///
/// util::myThrow is called when no transition is appliable or when the word
/// index cannot be moved by the requested amount.
void Trainer::createDataset(SubConfig & config, bool debug)
{
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
    // from_blob only wraps the existing buffer; clone() gives the tensor its
    // own storage so it stays valid after `context` is destroyed.
    contexts.emplace_back(torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone());

    // The target is stored as a one-element long tensor holding the index of
    // the chosen transition.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }

  nbExamples = classes.size();

  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
}

/// @brief Runs one pass over the data loader built by createDataset(),
///        updating the network after every batch.
///
/// For each batch the gradients are cleared, the network output is compared
/// to the labels with a cross-entropy loss, the loss is back-propagated and
/// the optimizer performs one step.
///
/// @param printAdvancement When true, a progress line (percentage of the
///        epoch done, loss accumulated since the previous line, examples per
///        second) is written to stderr each time at least `printInterval`
///        examples have been processed since the previous line.
/// @return The sum of the loss values of all batches of the epoch.
///
/// util::myThrow is called when the data loader or optimizer has not been
/// created yet, and when reading the loss value raises a std::exception.
float Trainer::epoch(bool printAdvancement)
{
  constexpr int printInterval = 50;

  // Both objects are only created by createDataset(); dereferencing them
  // before that would be undefined behaviour.
  if (!dataLoader || !optimizer)
    util::myThrow("Trainer::epoch called before Trainer::createDataset");

  int nbExamplesProcessed = 0; // examples since the last progress line
  long nbExamplesSeen = 0;     // examples since the beginning of the epoch
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  auto lossFct = torch::nn::CrossEntropyLoss();

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *dataLoader)
  {
    optimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);
    // Normalize shapes so that a batch of a single example still gives a
    // 2-D prediction and a 1-D label tensor.
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
    try
    {
      // Read the scalar once: every item() call extracts the value from the
      // tensor again.
      const float lossValue = loss.item<float>();
      totalLoss += lossValue;
      lossSoFar += lossValue;
    } catch(const std::exception & e) {util::myThrow(e.what());}

    loss.backward();
    optimizer->step();

    if (printAdvancement)
    {
      const int batchExamples = static_cast<int>(labels.size(0));
      nbExamplesProcessed += batchExamples;
      nbExamplesSeen += batchExamples;

      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;

        // Avoid converting an infinite value to int when the clock did not
        // advance between two progress lines.
        const int speed = seconds > 0.0 ? static_cast<int>(nbExamplesProcessed/seconds) : 0;

        // Progress is based on the examples actually seen, so a smaller
        // final batch cannot push the percentage above 100.
        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*nbExamplesSeen/nbExamples, lossSoFar, speed);

        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss;
}