#include "Trainer.hpp" #include "SubConfig.hpp" Trainer::Trainer(ReadingMachine & machine) : machine(machine) { } void Trainer::createDataset(SubConfig & config) { config.setState(machine.getStrategy().getInitialState()); std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; while (true) { auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } //TODO : check if clone is mandatory auto context = config.extractContext(5,5,machine.getDict(config.getState())); contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, at::kLong); gold[0] = goldIndex; classes.emplace_back(gold); transition->apply(config); config.addToHistory(transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName()); if (movement == Strategy::endMovement) break; config.setState(movement.first); if (!config.moveWordIndex(movement.second)) { config.printForDebug(stderr); util::myThrow(fmt::format("Cannot move word index by {}", movement.second)); } if (config.needsUpdate()) config.update(); } nbExamples = classes.size(); dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5))); sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); } float Trainer::epoch() { constexpr int printInterval = 2000; float totalLoss = 0.0; float lossSoFar = 0.0; int nbExamplesUntilPrint = printInterval; int currentBatchNumber = 0; for (auto & batch : *dataLoader) { denseOptimizer->zero_grad(); sparseOptimizer->zero_grad(); auto data = batch.data; auto labels = batch.target.squeeze(); auto prediction = machine.getClassifier()->getNN()(data); auto loss = torch::nll_loss(torch::log(prediction), labels); totalLoss += loss.item<float>(); lossSoFar += loss.item<float>(); loss.backward(); denseOptimizer->step(); sparseOptimizer->step(); nbExamplesUntilPrint -= labels.size(0); ++currentBatchNumber; if (nbExamplesUntilPrint <= 0) { nbExamplesUntilPrint = printInterval; fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); lossSoFar = 0; } } return totalLoss; }