Skip to content
Snippets Groups Projects
Commit a38db411 authored by Franck Dary
Browse files

Added Trainer

parent ce420cb5
No related branches found
No related tags found
No related merge requests found
......@@ -26,6 +26,7 @@ include_directories(fmt/include)
include_directories(common/include)
include_directories(reading_machine/include)
include_directories(torch_modules/include)
include_directories(trainer/include)
include_directories(utf8)
add_subdirectory(fmt)
......@@ -33,4 +34,5 @@ add_subdirectory(common)
add_subdirectory(dev)
add_subdirectory(reading_machine)
add_subdirectory(torch_modules)
add_subdirectory(trainer)
......@@ -73,7 +73,7 @@ int main(int argc, char * argv[])
int nbExamples = *dataset.size();
fmt::print("Done! size={}\n", nbExamples);
int batchSize = 100;
int batchSize = 1000;
auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
TestNetwork nn(machine.getTransitionSet().size(), 5);
......
......@@ -3,7 +3,7 @@
#include <string>
#include "TransitionSet.hpp"
#include "MLP.hpp"
#include "TestNetwork.hpp"
class Classifier
{
......@@ -11,12 +11,13 @@ class Classifier
std::string name;
std::unique_ptr<TransitionSet> transitionSet;
MLP nn{nullptr};
TestNetwork nn{nullptr};
public :
Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
TransitionSet & getTransitionSet();
TestNetwork & getNN();
};
#endif
......@@ -5,6 +5,7 @@
#include "Classifier.hpp"
#include "Strategy.hpp"
#include "FeatureFunction.hpp"
#include "Dict.hpp"
class ReadingMachine
{
......@@ -14,12 +15,15 @@ class ReadingMachine
std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy;
std::unique_ptr<FeatureFunction> featureFunction;
std::map<std::string, Dict> dicts;
public :
ReadingMachine(const std::string & filename);
TransitionSet & getTransitionSet();
Strategy & getStrategy();
Dict & getDict(const std::string & state);
Classifier * getClassifier();
};
#endif
......@@ -4,7 +4,7 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
{
this->name = name;
this->transitionSet.reset(new TransitionSet(tsFile));
this->nn = MLP(topology);
this->nn = TestNetwork(transitionSet->size(), 5);
}
TransitionSet & Classifier::getTransitionSet()
......@@ -12,3 +12,8 @@ TransitionSet & Classifier::getTransitionSet()
return *transitionSet;
}
// Gives callers mutable access to the network stored in this classifier.
TestNetwork & Classifier::getNN()
{
  return this->nn;
}
......@@ -3,6 +3,8 @@
ReadingMachine::ReadingMachine(const std::string & filename)
{
dicts.emplace(std::make_pair("", Dict::State::Open));
std::FILE * file = std::fopen(filename.c_str(), "r");
char buffer[1024];
......@@ -57,3 +59,18 @@ Strategy & ReadingMachine::getStrategy()
return *strategy;
}
// Returns the dictionary registered under `state`. When no entry has that
// key, the entry stored under the empty-string key is returned instead
// (std::map::at throws std::out_of_range if that entry is missing as well).
Dict & ReadingMachine::getDict(const std::string & state)
{
  const auto entry = dicts.find(state);

  return entry != dicts.end() ? entry->second : dicts.at("");
}
// Exposes the classifier held by this machine as a non-owning pointer;
// the result is whatever classifier.get() yields (null when none is set).
Classifier * ReadingMachine::getClassifier()
{
  Classifier * const raw = classifier.get();
  return raw;
}
# Gather every .cpp file found directly under src/ as the library sources.
file(GLOB SOURCES src/*.cpp)

# Build those sources into the static library `trainer`,
# then link it against the reading_machine target.
add_library(trainer STATIC ${SOURCES})
target_link_libraries(trainer reading_machine)
#ifndef TRAINER__H
#define TRAINER__H
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
#include "TestNetwork.hpp"
/// Holds the training examples extracted from a gold configuration together
/// with the optimizers built over the network of a ReadingMachine's classifier.
class Trainer
{
  private :

  /// Machine whose transition set, strategy, dicts and classifier are used.
  /// Stored by reference, so the machine must outlive this Trainer.
  ReadingMachine & machine;
  /// Examples produced by createDataset(); null until it has been called.
  std::unique_ptr<ConfigDataset> dataset{nullptr};
  /// Adam optimizer over the network's denseParameters(); set by createDataset().
  std::unique_ptr<torch::optim::Adam> denseOptimizer;
  /// SparseAdam optimizer over the network's sparseParameters(); set by createDataset().
  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;

  public :

  /// Binds the trainer to `machine`.
  /// Declared explicit so a ReadingMachine never converts silently into a Trainer.
  explicit Trainer(ReadingMachine & machine);

  /// Steps `goldConfig` through the machine, one transition at a time, storing
  /// one (context, transition index) example per step, then (re)creates the
  /// dataset and both optimizers. `goldConfig` is modified in the process.
  void createDataset(SubConfig & goldConfig);
};
#endif
#include "Trainer.hpp"
#include <memory>
#include "SubConfig.hpp"
// Binds the trainer to `machine`; only the reference is stored, nothing else
// is initialised here.
Trainer::Trainer(ReadingMachine & machine)
  : machine(machine)
{}
// Builds the training set by stepping `config` through the machine.
//
// Starting from the strategy's initial state, each iteration:
//   1. asks the transition set for the best appliable transition,
//   2. records the extracted context and the index of that transition,
//   3. applies the transition and follows the movement given by the strategy,
// until the strategy returns Strategy::endMovement.
//
// Afterwards the dataset and both optimizers are (re)created.
// `config` is modified along the way (state, history, word index, updates).
//
// Errors are reported through util::myThrow when no transition is appliable,
// when the word index cannot be moved, or when the machine has no classifier.
void Trainer::createDataset(SubConfig & config)
{
  // Hyper-parameters shared by both optimizers.
  const double learningRate = 2e-3;
  const double beta1 = 0.5;

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    // NOTE(review): the two 5s presumably bound the context window - confirm
    // against SubConfig::extractContext.
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));

    // from_blob only wraps the buffer of the local `context`, so clone() is
    // needed for the tensor to own its data once `context` goes out of scope.
    // Assumes the elements of `context` are 64-bit integers matching at::kLong
    // - TODO confirm.
    contexts.push_back(torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone());

    // The gold class is a one-element int64 tensor holding the transition index.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

  dataset = std::make_unique<ConfigDataset>(contexts, classes);

  // getClassifier() hands back a raw pointer that may be null; check it before
  // dereferencing instead of crashing.
  Classifier * classifier = machine.getClassifier();
  if (!classifier)
    util::myThrow("Machine has no classifier !");
  TestNetwork & network = classifier->getNN();

  // NOTE(review): the optimizers are rebuilt on every call, which discards any
  // previously accumulated optimizer state - confirm this is intended.
  denseOptimizer = std::make_unique<torch::optim::Adam>(network->denseParameters(), torch::optim::AdamOptions(learningRate).beta1(beta1));
  sparseOptimizer = std::make_unique<torch::optim::SparseAdam>(network->sparseParameters(), torch::optim::SparseAdamOptions(learningRate).beta1(beta1));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment