Skip to content
Snippets Groups Projects
Commit a38db411 authored by Franck Dary
Browse files

Added Trainer

parent ce420cb5
Branches
No related tags found
No related merge requests found
...@@ -26,6 +26,7 @@ include_directories(fmt/include) ...@@ -26,6 +26,7 @@ include_directories(fmt/include)
include_directories(common/include) include_directories(common/include)
include_directories(reading_machine/include) include_directories(reading_machine/include)
include_directories(torch_modules/include) include_directories(torch_modules/include)
include_directories(trainer/include)
include_directories(utf8) include_directories(utf8)
add_subdirectory(fmt) add_subdirectory(fmt)
...@@ -33,4 +34,5 @@ add_subdirectory(common) ...@@ -33,4 +34,5 @@ add_subdirectory(common)
add_subdirectory(dev) add_subdirectory(dev)
add_subdirectory(reading_machine) add_subdirectory(reading_machine)
add_subdirectory(torch_modules) add_subdirectory(torch_modules)
add_subdirectory(trainer)
...@@ -73,7 +73,7 @@ int main(int argc, char * argv[]) ...@@ -73,7 +73,7 @@ int main(int argc, char * argv[])
int nbExamples = *dataset.size(); int nbExamples = *dataset.size();
fmt::print("Done! size={}\n", nbExamples); fmt::print("Done! size={}\n", nbExamples);
int batchSize = 100; int batchSize = 1000;
auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
TestNetwork nn(machine.getTransitionSet().size(), 5); TestNetwork nn(machine.getTransitionSet().size(), 5);
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <string> #include <string>
#include "TransitionSet.hpp" #include "TransitionSet.hpp"
#include "MLP.hpp" #include "TestNetwork.hpp"
class Classifier class Classifier
{ {
...@@ -11,12 +11,13 @@ class Classifier ...@@ -11,12 +11,13 @@ class Classifier
std::string name; std::string name;
std::unique_ptr<TransitionSet> transitionSet; std::unique_ptr<TransitionSet> transitionSet;
MLP nn{nullptr}; TestNetwork nn{nullptr};
public : public :
Classifier(const std::string & name, const std::string & topology, const std::string & tsFile); Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
TestNetwork & getNN();
}; };
#endif #endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "Classifier.hpp" #include "Classifier.hpp"
#include "Strategy.hpp" #include "Strategy.hpp"
#include "FeatureFunction.hpp" #include "FeatureFunction.hpp"
#include "Dict.hpp"
class ReadingMachine class ReadingMachine
{ {
...@@ -14,12 +15,15 @@ class ReadingMachine ...@@ -14,12 +15,15 @@ class ReadingMachine
std::unique_ptr<Classifier> classifier; std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy; std::unique_ptr<Strategy> strategy;
std::unique_ptr<FeatureFunction> featureFunction; std::unique_ptr<FeatureFunction> featureFunction;
std::map<std::string, Dict> dicts;
public : public :
ReadingMachine(const std::string & filename); ReadingMachine(const std::string & filename);
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
Strategy & getStrategy(); Strategy & getStrategy();
Dict & getDict(const std::string & state);
Classifier * getClassifier();
}; };
#endif #endif
...@@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c ...@@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
{ {
this->name = name; this->name = name;
this->transitionSet.reset(new TransitionSet(tsFile)); this->transitionSet.reset(new TransitionSet(tsFile));
this->nn = MLP(topology); this->nn = TestNetwork(transitionSet->size(), 5);
} }
TransitionSet & Classifier::getTransitionSet() TransitionSet & Classifier::getTransitionSet()
...@@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet() ...@@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet()
return *transitionSet; return *transitionSet;
} }
// Gives access to the network stored in this classifier.
// The returned reference aliases the member, so it is only valid while the
// classifier itself is alive.
TestNetwork & Classifier::getNN()
{
  return this->nn;
}
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
ReadingMachine::ReadingMachine(const std::string & filename) ReadingMachine::ReadingMachine(const std::string & filename)
{ {
dicts.emplace(std::make_pair("", Dict::State::Open));
std::FILE * file = std::fopen(filename.c_str(), "r"); std::FILE * file = std::fopen(filename.c_str(), "r");
char buffer[1024]; char buffer[1024];
...@@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy() ...@@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy()
return *strategy; return *strategy;
} }
// Returns the dictionary registered under `state`.
// When no dictionary is registered for that state, the one stored under the
// empty key "" is returned instead; std::map::at throws std::out_of_range if
// that entry is absent as well.
Dict & ReadingMachine::getDict(const std::string & state)
{
  auto entry = dicts.find(state);

  if (entry != dicts.end())
  {
    return entry->second;
  }

  return dicts.at("");
}
// Returns a non-owning pointer to the classifier held by this machine.
// Ownership stays with the machine's unique_ptr; the result is null when no
// classifier is currently held.
Classifier * ReadingMachine::getClassifier()
{
  return this->classifier.get();
}
# Collect every .cpp file found directly under src/ into the SOURCES variable.
FILE(GLOB SOURCES src/*.cpp)
# Build those sources into a static library target named "trainer".
add_library(trainer STATIC ${SOURCES})
# Link the trainer target against the reading_machine target.
target_link_libraries(trainer reading_machine)
#ifndef TRAINER__H
#define TRAINER__H
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
#include "TestNetwork.hpp"
// Builds the training data and the optimizers for a ReadingMachine.
class Trainer
{
  private:

  // Machine this trainer works on; held by reference, not owned.
  ReadingMachine & machine;
  // Dataset of examples; null until createDataset has been called.
  std::unique_ptr<ConfigDataset> dataset;
  // Optimizer built from the network's denseParameters(); null until createDataset.
  std::unique_ptr<torch::optim::Adam> denseOptimizer;
  // Optimizer built from the network's sparseParameters(); null until createDataset.
  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;

  public:

  // Binds the trainer to `machine`, which must outlive the trainer.
  Trainer(ReadingMachine & machine);
  // Runs the machine over `goldConfig` to fill the dataset, then (re)creates
  // both optimizers. `goldConfig` is modified in the process.
  void createDataset(SubConfig & goldConfig);
};
#endif
#include "Trainer.hpp"
#include "SubConfig.hpp"
// Stores a reference to the machine to train. Nothing else is initialised
// here: the dataset and both optimizers stay null until createDataset runs.
Trainer::Trainer(ReadingMachine & machine)
  : machine(machine)
{}
// Builds the training dataset by driving the machine over `config`, then
// creates the dense and sparse optimizers for the classifier's network.
//
// At each step the transition set is asked for the best appliable transition
// in the current configuration. The extracted context becomes one example and
// the index of that transition becomes its class. The transition is then
// applied and the strategy decides the next state and word-index movement.
// The loop stops when the strategy returns Strategy::endMovement.
//
// `config` is modified in place (state, history, word index, updates).
// util::myThrow is called when no transition is appliable or when the word
// index cannot be moved.
void Trainer::createDataset(SubConfig & config)
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  config.setState(machine.getStrategy().getInitialState());

  for (;;)
  {
    auto * best = machine.getTransitionSet().getBestAppliableTransition(config);
    if (best == nullptr)
      util::myThrow("No transition appliable !");

    // One example: the context extracted around the current position.
    // from_blob only wraps the local buffer, so clone() is needed for the
    // tensor to own its data once `features` goes out of scope.
    auto & dict = machine.getDict(config.getState());
    auto features = config.extractContext(5, 5, dict);
    contexts.push_back(torch::from_blob(features.data(), {static_cast<long>(features.size())}, at::kLong).clone());

    // Matching class: index of the chosen transition, as a 1-element tensor.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(best);
    auto goldClass = torch::zeros(1, at::kLong);
    goldClass[0] = goldIndex;
    classes.emplace_back(goldClass);

    best->apply(config);
    config.addToHistory(best->getName());

    auto movement = machine.getStrategy().getMovement(config, best->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);

    const bool moved = config.moveWordIndex(movement.second);
    if (!moved)
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

  dataset.reset(new ConfigDataset(contexts, classes));

  // Both optimizers work on the same network, each on its own parameter set.
  auto & network = machine.getClassifier()->getNN();
  denseOptimizer.reset(new torch::optim::Adam(network->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
  sparseOptimizer.reset(new torch::optim::SparseAdam(network->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment