From 8a10a8477f0fc54d9446ee1be9c3a162e63ec816 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jan 2020 13:54:55 +0100
Subject: [PATCH] Working training

---
 dev/CMakeLists.txt          |  1 +
 dev/src/dev.cpp             | 93 +++----------------------------------
 trainer/include/Trainer.hpp | 12 ++++-
 trainer/src/Trainer.cpp     | 45 +++++++++++++++++-
 4 files changed, 62 insertions(+), 89 deletions(-)

diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt
index a473806..35eee29 100644
--- a/dev/CMakeLists.txt
+++ b/dev/CMakeLists.txt
@@ -4,3 +4,4 @@ add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
 target_link_libraries(dev torch_modules)
+target_link_libraries(dev trainer)
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 3336afd..0d9738d 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -5,8 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
-#include "TestNetwork.hpp"
-#include "ConfigDataset.hpp"
+#include "Trainer.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -16,8 +15,6 @@ int main(int argc, char * argv[])
     exit(1);
   }
 
-  at::init_num_threads();
-
   std::string machineFile = argv[1];
   std::string mcdFile = argv[2];
   std::string tsvFile = argv[3];
@@ -29,91 +26,13 @@ int main(int argc, char * argv[])
   BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
   SubConfig config(goldConfig);
 
-  config.setState(machine.getStrategy().getInitialState());
-
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
-
-  fmt::print("Generating dataset...\n");
-
-  Dict dict(Dict::State::Open);
-
-  while (true)
-  {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
-
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, at::kLong);
-    gold[0] = goldIndex;
-
-    classes.emplace_back(gold);
-
-    transition->apply(config);
-    config.addToHistory(transition->getName());
-
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
-
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+  Trainer trainer(machine);
+  trainer.createDataset(config);
 
-    if (config.needsUpdate())
-      config.update();
-  }
-
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
-
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
-
-  int batchSize = 1000;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 30; ++epoch)
+  for (int i = 0; i < 5; i++)
   {
-    float totalLoss = 0.0;
-    float lossSoFar = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      denseOptimizer.zero_grad();
-      sparseOptimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      lossSoFar += loss.item<float>();
-      loss.backward();
-      denseOptimizer.step();
-      sparseOptimizer.step();
-
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
-        std::fflush(stdout);
-        lossSoFar = 0;
-      }
-    }
-
-    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
+    float loss = trainer.epoch();
+    fmt::print("\nEpoch {} loss = {}\n", i+1, loss);
   }
 
   return 0;
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index e8bdcba..45fccbe 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -10,15 +10,25 @@ class Trainer
 {
   private :
 
+  using Dataset = ConfigDataset;
+  using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
+
+  private :
+
   ReadingMachine & machine;
-  std::unique_ptr<ConfigDataset> dataset{nullptr};
+  DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> denseOptimizer;
   std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::size_t epochNumber{0};
+  int batchSize{100};
+  int nbExamples{0};
 
   public :
 
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig);
+  float epoch();
+
 };
 
 #endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 19a5320..6279f8e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -42,9 +42,52 @@ void Trainer::createDataset(SubConfig & config)
       config.update();
   }
 
-  dataset.reset(new ConfigDataset(contexts, classes));
+  nbExamples = classes.size();
+
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
   sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
 
+float Trainer::epoch()
+{
+  constexpr int printInterval = 2000;
+  float totalLoss = 0.0;
+  float lossSoFar = 0.0;
+  int nbExamplesUntilPrint = printInterval;
+  int currentBatchNumber = 0;
+
+  for (auto & batch : *dataLoader)
+  {
+    denseOptimizer->zero_grad();
+    sparseOptimizer->zero_grad();
+
+    auto data = batch.data;
+    auto labels = batch.target.squeeze();
+
+    auto prediction = machine.getClassifier()->getNN()(data);
+
+    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    totalLoss += loss.item<float>();
+    lossSoFar += loss.item<float>();
+
+    loss.backward();
+    denseOptimizer->step();
+    sparseOptimizer->step();
+
+    nbExamplesUntilPrint -= labels.size(0);
+
+    ++currentBatchNumber;
+    if (nbExamplesUntilPrint <= 0)
+    {
+      nbExamplesUntilPrint = printInterval;
+      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+      std::fflush(stdout);
+      lossSoFar = 0;
+    }
+  }
+
+  return totalLoss;
+}
+
-- 
GitLab
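
Note: the Dataset alias in Trainer.hpp refers to ConfigDataset, which the patch uses but does not show. A minimal sketch of what such a libtorch dataset could look like, assuming it simply stores the pre-extracted context and gold-class tensors (the class name and constructor arguments come from the patch; the body is an assumption, not the repository's actual implementation):

#include <torch/torch.h>
#include <vector>

// Hypothetical sketch of ConfigDataset: one example per configuration,
// pairing extracted context features with the gold transition index.
class ConfigDataset : public torch::data::datasets::Dataset<ConfigDataset>
{
  private :

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  public :

  ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes)
    : contexts(std::move(contexts)), classes(std::move(classes))
  {
  }

  // (context features, gold transition index) for one example.
  torch::data::Example<> get(std::size_t index) override
  {
    return {contexts[index], classes[index]};
  }

  torch::optional<std::size_t> size() const override
  {
    return contexts.size();
  }
};

Mapping such a dataset through torch::data::transforms::Stack<>() and handing it to torch::data::make_data_loader is what produces the StatelessDataLoader<MapDataset<Dataset, Stack<Example<>>>, RandomSampler> spelled out in the DataLoader alias; the std::default_delete<...> argument there is std::unique_ptr's default deleter and could be dropped without changing the type.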
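
The two-optimizer setup (Adam plus SparseAdam) exists because embedding tables configured for sparse gradients emit gradients that the dense Adam implementation cannot consume. A sketch of the parameter split that denseParameters()/sparseParameters() implies; this module is hypothetical, standing in for the repository's actual network behind machine.getClassifier()->getNN():

#include <torch/torch.h>
#include <vector>

struct SketchNetworkImpl : torch::nn::Module
{
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear outputLayer{nullptr};

  SketchNetworkImpl(int vocabularySize, int embeddingSize, int contextSize, int nbOutputs)
  {
    // sparse(true) makes backward() produce sparse gradients for the table.
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingSize).sparse(true)));
    outputLayer = register_module("output", torch::nn::Linear(contextSize * embeddingSize, nbOutputs));
  }

  // input : {batch, contextSize} word indices -> probabilities over transitions,
  // matching the torch::nll_loss(torch::log(prediction), labels) call in epoch().
  torch::Tensor forward(torch::Tensor input)
  {
    auto embedded = wordEmbeddings(input).view({input.size(0), -1});
    return torch::softmax(outputLayer(embedded), 1);
  }

  // Parameters with dense gradients: fed to torch::optim::Adam.
  std::vector<torch::Tensor> denseParameters()
  {
    return outputLayer->parameters();
  }

  // Parameters with sparse gradients: fed to torch::optim::SparseAdam.
  std::vector<torch::Tensor> sparseParameters()
  {
    return wordEmbeddings->parameters();
  }
};
TORCH_MODULE(SketchNetwork);

With this split, each optimizer only ever sees parameters whose gradient layout it supports, which is why Trainer::epoch() zeroes and steps both of them on every batch.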