From a38db411f9c9a4ce4c9a2d19c7cb762e844b7ba5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Jan 2020 15:33:25 +0100
Subject: [PATCH] Added Trainer

---
 CMakeLists.txt                             |  2 +
 dev/src/dev.cpp                            |  2 +-
 reading_machine/include/Classifier.hpp     |  5 ++-
 reading_machine/include/ReadingMachine.hpp |  4 ++
 reading_machine/src/Classifier.cpp         |  7 ++-
 reading_machine/src/ReadingMachine.cpp     | 17 ++++++++
 trainer/CMakeLists.txt                     |  5 +++
 trainer/include/Trainer.hpp                | 24 +++++++++++
 trainer/src/Trainer.cpp                    | 50 ++++++++++++++++++++++
 9 files changed, 112 insertions(+), 4 deletions(-)
 create mode 100644 trainer/CMakeLists.txt
 create mode 100644 trainer/include/Trainer.hpp
 create mode 100644 trainer/src/Trainer.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 54f5cb6..8ad0a5a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ include_directories(fmt/include)
 include_directories(common/include)
 include_directories(reading_machine/include)
 include_directories(torch_modules/include)
+include_directories(trainer/include)
 include_directories(utf8)
 
 add_subdirectory(fmt)
@@ -33,4 +34,5 @@ add_subdirectory(common)
 add_subdirectory(dev)
 add_subdirectory(reading_machine)
 add_subdirectory(torch_modules)
+add_subdirectory(trainer)
 
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index a31fe21..3336afd 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -73,7 +73,7 @@ int main(int argc, char * argv[])
   int nbExamples = *dataset.size();
   fmt::print("Done! size={}\n", nbExamples);
 
-  int batchSize = 100;
+  int batchSize = 1000;
   auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   TestNetwork nn(machine.getTransitionSet().size(), 5);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 5d38ae8..0e8b120 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -3,7 +3,7 @@
 
 #include <string>
 #include "TransitionSet.hpp"
-#include "MLP.hpp"
+#include "TestNetwork.hpp"
 
 class Classifier
 {
@@ -11,12 +11,13 @@ class Classifier
 
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
-  MLP nn{nullptr};
+  TestNetwork nn{nullptr};
 
   public :
 
   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
   TransitionSet & getTransitionSet();
+  TestNetwork & getNN();
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index ede9f7a..41cb826 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -5,6 +5,7 @@
 #include "Classifier.hpp"
 #include "Strategy.hpp"
 #include "FeatureFunction.hpp"
+#include "Dict.hpp"
 
 class ReadingMachine
 {
@@ -14,12 +15,15 @@ class ReadingMachine
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
   std::unique_ptr<FeatureFunction> featureFunction;
+  std::map<std::string, Dict> dicts;
 
   public :
 
   ReadingMachine(const std::string & filename);
   TransitionSet & getTransitionSet();
   Strategy & getStrategy();
+  Dict & getDict(const std::string & state);
+  Classifier * getClassifier();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index d446be2..34c5f60 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
 {
   this->name = name;
   this->transitionSet.reset(new TransitionSet(tsFile));
-  this->nn = MLP(topology);
+  this->nn = TestNetwork(transitionSet->size(), 5);
 }
 
 TransitionSet & Classifier::getTransitionSet()
@@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet()
   return *transitionSet;
 }
 
+TestNetwork & Classifier::getNN()
+{
+  return nn;
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index ec04962..334f8bf 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -3,6 +3,8 @@
 
 ReadingMachine::ReadingMachine(const std::string & filename)
 {
+  dicts.emplace(std::make_pair("", Dict::State::Open));
+
   std::FILE * file = std::fopen(filename.c_str(), "r");
 
   char buffer[1024];
@@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy()
   return *strategy;
 }
 
+Dict & ReadingMachine::getDict(const std::string & state)
+{
+  auto found = dicts.find(state);
+
+  if (found == dicts.end())
+    return dicts.at("");
+
+  return found->second;
+}
+
+Classifier * ReadingMachine::getClassifier()
+{
+  return classifier.get();
+}
+
diff --git a/trainer/CMakeLists.txt b/trainer/CMakeLists.txt
new file mode 100644
index 0000000..b673afa
--- /dev/null
+++ b/trainer/CMakeLists.txt
@@ -0,0 +1,5 @@
+FILE(GLOB SOURCES src/*.cpp)
+
+add_library(trainer STATIC ${SOURCES})
+target_link_libraries(trainer reading_machine)
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
new file mode 100644
index 0000000..e8bdcba
--- /dev/null
+++ b/trainer/include/Trainer.hpp
@@ -0,0 +1,24 @@
+#ifndef TRAINER__H
+#define TRAINER__H
+
+#include "ReadingMachine.hpp"
+#include "ConfigDataset.hpp"
+#include "SubConfig.hpp"
+#include "TestNetwork.hpp"
+
+class Trainer
+{
+  private :
+
+  ReadingMachine & machine;
+  std::unique_ptr<ConfigDataset> dataset{nullptr};
+  std::unique_ptr<torch::optim::Adam> denseOptimizer;
+  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+
+  public :
+
+  Trainer(ReadingMachine & machine);
+  void createDataset(SubConfig & goldConfig);
+};
+
+#endif
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
new file mode 100644
index 0000000..19a5320
--- /dev/null
+++ b/trainer/src/Trainer.cpp
@@ -0,0 +1,50 @@
+#include "Trainer.hpp"
+#include "SubConfig.hpp"
+
+Trainer::Trainer(ReadingMachine & machine) : machine(machine)
+{
+}
+
+void Trainer::createDataset(SubConfig & config)
+{
+  config.setState(machine.getStrategy().getInitialState());
+
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  while (true)
+  {
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+    if (!transition)
+      util::myThrow("No transition appliable !");
+
+    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+
+    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
+    auto gold = torch::zeros(1, at::kLong);
+    gold[0] = goldIndex;
+
+    classes.emplace_back(gold);
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index !");
+
+    if (config.needsUpdate())
+      config.update();
+  }
+
+  dataset.reset(new ConfigDataset(contexts, classes));
+
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
+}
+
-- 
GitLab