diff --git a/CMakeLists.txt b/CMakeLists.txt index 54f5cb6d417682a962a23ac5fd04d9c89bca3522..8ad0a5a87c20ce6e0acdb1b72ba35f112b2ead69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,7 @@ include_directories(fmt/include) include_directories(common/include) include_directories(reading_machine/include) include_directories(torch_modules/include) +include_directories(trainer/include) include_directories(utf8) add_subdirectory(fmt) @@ -33,4 +34,5 @@ add_subdirectory(common) add_subdirectory(dev) add_subdirectory(reading_machine) add_subdirectory(torch_modules) +add_subdirectory(trainer) diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index a31fe21a3164818c792b142b563d5ac3b7ae2967..3336afde63f31ae2e40333f3a51046941b410edb 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -73,7 +73,7 @@ int main(int argc, char * argv[]) int nbExamples = *dataset.size(); fmt::print("Done! size={}\n", nbExamples); - int batchSize = 100; + int batchSize = 1000; auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); TestNetwork nn(machine.getTransitionSet().size(), 5); diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 5d38ae88ec4e893b523e6fcd17391e0957cc4934..0e8b120f357673fda318ab138ca8187cff4ba2af 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -3,7 +3,7 @@ #include <string> #include "TransitionSet.hpp" -#include "MLP.hpp" +#include "TestNetwork.hpp" class Classifier { @@ -11,12 +11,13 @@ class Classifier std::string name; std::unique_ptr<TransitionSet> transitionSet; - MLP nn{nullptr}; + TestNetwork nn{nullptr}; public : Classifier(const std::string & name, const std::string & topology, const std::string & tsFile); TransitionSet & getTransitionSet(); + TestNetwork & getNN(); }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index ede9f7a5212bbff43575a1f0b0a9f9164ec23088..41cb82617311e0dfc7448d3c8bb09686257ee045 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -5,6 +5,7 @@ #include "Classifier.hpp" #include "Strategy.hpp" #include "FeatureFunction.hpp" +#include "Dict.hpp" class ReadingMachine { @@ -14,12 +15,15 @@ class ReadingMachine std::unique_ptr<Classifier> classifier; std::unique_ptr<Strategy> strategy; std::unique_ptr<FeatureFunction> featureFunction; + std::map<std::string, Dict> dicts; public : ReadingMachine(const std::string & filename); TransitionSet & getTransitionSet(); Strategy & getStrategy(); + Dict & getDict(const std::string & state); + Classifier * getClassifier(); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index d446be2ce9bfe40b8502ff658a4507dc1d29adc0..34c5f6061123cfb219eb54a1ef46676f338e4866 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c { this->name = name; this->transitionSet.reset(new TransitionSet(tsFile)); - this->nn = MLP(topology); + this->nn = TestNetwork(transitionSet->size(), 5); } TransitionSet & Classifier::getTransitionSet() @@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet() return *transitionSet; } +TestNetwork & Classifier::getNN() +{ + return nn; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index ec04962fcdd7cfb7f799d802397e510665241bb0..334f8bf327acec84adb9091c58f55fc2c0d05d4b 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -3,6 +3,8 @@ ReadingMachine::ReadingMachine(const std::string & filename) { + dicts.emplace(std::make_pair("", Dict::State::Open)); + std::FILE * file = std::fopen(filename.c_str(), "r"); char buffer[1024]; @@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy() return *strategy; } +Dict & ReadingMachine::getDict(const std::string & state) +{ + auto found = dicts.find(state); + + if (found == dicts.end()) + return dicts.at(""); + + return found->second; +} + +Classifier * ReadingMachine::getClassifier() +{ + return classifier.get(); +} + diff --git a/trainer/CMakeLists.txt b/trainer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b673afa03d496fecb188c151d058a314c7b815ed --- /dev/null +++ b/trainer/CMakeLists.txt @@ -0,0 +1,5 @@ +FILE(GLOB SOURCES src/*.cpp) + +add_library(trainer STATIC ${SOURCES}) +target_link_libraries(trainer reading_machine) + diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8bdcba301141271f2b84ca186e8c8ba041c77d5 --- /dev/null +++ b/trainer/include/Trainer.hpp @@ -0,0 +1,24 @@ +#ifndef TRAINER__H +#define TRAINER__H + +#include "ReadingMachine.hpp" +#include "ConfigDataset.hpp" +#include "SubConfig.hpp" +#include "TestNetwork.hpp" + +class Trainer +{ + private : + + ReadingMachine & machine; + std::unique_ptr<ConfigDataset> dataset{nullptr}; + std::unique_ptr<torch::optim::Adam> denseOptimizer; + std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer; + + public : + + Trainer(ReadingMachine & machine); + void createDataset(SubConfig & goldConfig); +}; + +#endif diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19a5320a21643ed1f99dbe7f39fb5b9e84b32537 --- /dev/null +++ b/trainer/src/Trainer.cpp @@ -0,0 +1,50 @@ +#include "Trainer.hpp" +#include "SubConfig.hpp" + +Trainer::Trainer(ReadingMachine & machine) : machine(machine) +{ +} + +void Trainer::createDataset(SubConfig & config) +{ + config.setState(machine.getStrategy().getInitialState()); + + std::vector<torch::Tensor> contexts; + std::vector<torch::Tensor> classes; + + while (true) + { + auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); + if (!transition) + util::myThrow("No transition appliable !"); + + auto context = config.extractContext(5,5,machine.getDict(config.getState())); + contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); + + int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); + auto gold = torch::zeros(1, at::kLong); + gold[0] = goldIndex; + + classes.emplace_back(gold); + + transition->apply(config); + config.addToHistory(transition->getName()); + + auto movement = machine.getStrategy().getMovement(config, transition->getName()); + if (movement == Strategy::endMovement) + break; + + config.setState(movement.first); + if (!config.moveWordIndex(movement.second)) + util::myThrow("Cannot move word index !"); + + if (config.needsUpdate()) + config.update(); + } + + dataset.reset(new ConfigDataset(contexts, classes)); + + denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5))); + sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); +} +