Commit 8a10a847 authored by Franck Dary

Working training

parent a38db411

@@ -4,3 +4,4 @@ add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
 target_link_libraries(dev torch_modules)
+target_link_libraries(dev trainer)
@@ -5,8 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
-#include "TestNetwork.hpp"
-#include "ConfigDataset.hpp"
+#include "Trainer.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -16,8 +15,6 @@ int main(int argc, char * argv[])
     exit(1);
   }
 
-  at::init_num_threads();
-
   std::string machineFile = argv[1];
   std::string mcdFile = argv[2];
   std::string tsvFile = argv[3];
@@ -29,91 +26,13 @@ int main(int argc, char * argv[])
   BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
   SubConfig config(goldConfig);
 
-  config.setState(machine.getStrategy().getInitialState());
-
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
-
-  fmt::print("Generating dataset...\n");
-  Dict dict(Dict::State::Open);
-
-  while (true)
-  {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
-
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, at::kLong);
-    gold[0] = goldIndex;
-
-    classes.emplace_back(gold);
-
-    transition->apply(config);
-    config.addToHistory(transition->getName());
-
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
-
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
-
-    if (config.needsUpdate())
-      config.update();
-  }
-
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
-
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
-
-  int batchSize = 1000;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 30; ++epoch)
-  {
-    float totalLoss = 0.0;
-    float lossSoFar = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      denseOptimizer.zero_grad();
-      sparseOptimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      lossSoFar += loss.item<float>();
-
-      loss.backward();
-      denseOptimizer.step();
-      sparseOptimizer.step();
-
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
-        std::fflush(stdout);
-        lossSoFar = 0;
-      }
-    }
-
-    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
-  }
+  Trainer trainer(machine);
+  trainer.createDataset(config);
+
+  for (int i = 0; i < 5; i++)
+  {
+    float loss = trainer.epoch();
+    fmt::print("\nEpoch {} loss = {}\n", i+1, loss);
+  }
 
   return 0;
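
Note on the optimizer setup that this commit moves from dev.cpp into the Trainer: the network's parameters are split into a dense group and a sparse group, each with its own optimizer, because torch::optim::Adam cannot consume the sparse gradients emitted by embeddings built with sparse(true), while torch::optim::SparseAdam handles only those. The following sketch is not code from this repository; the module, its denseParameters()/sparseParameters() helpers and all dimensions are invented to illustrate the same two-optimizer pattern with LibTorch.

// Standalone sketch (hypothetical module, not the project's TestNetwork/classifier).
#include <torch/torch.h>

struct TinyClassifierImpl : torch::nn::Module
{
  torch::nn::Embedding embeddings{nullptr};
  torch::nn::Linear output{nullptr};

  TinyClassifierImpl(int vocabSize, int embSize, int nbClasses)
  {
    // sparse(true) makes the embedding table produce sparse gradients.
    embeddings = register_module("embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabSize, embSize).sparse(true)));
    output = register_module("output", torch::nn::Linear(embSize, nbClasses));
  }

  torch::Tensor forward(torch::Tensor context)
  {
    // Average the embeddings of the context word indices, then classify.
    return torch::softmax(output(embeddings(context).mean(1)), 1);
  }

  // Hypothetical analogues of the denseParameters()/sparseParameters() used above.
  std::vector<torch::Tensor> denseParameters() { return output->parameters(); }
  std::vector<torch::Tensor> sparseParameters() { return embeddings->parameters(); }
};
TORCH_MODULE(TinyClassifier);

int main()
{
  TinyClassifier nn(1000, 64, 10);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3));

  auto contexts = torch::randint(1000, {8, 5}, torch::kLong); // batch of 8 contexts of 5 word indices
  auto labels = torch::randint(10, {8}, torch::kLong);

  denseOptimizer.zero_grad();
  sparseOptimizer.zero_grad();
  auto loss = torch::nll_loss(torch::log(nn(contexts)), labels);
  loss.backward();
  denseOptimizer.step();  // updates the dense (Linear) parameters
  sparseOptimizer.step(); // updates the sparse (Embedding) parameters
  return 0;
}
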
@@ -10,15 +10,25 @@ class Trainer
 {
   private :
 
+  using Dataset = ConfigDataset;
+  using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
+
+  private :
+
   ReadingMachine & machine;
-  std::unique_ptr<ConfigDataset> dataset{nullptr};
+  DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> denseOptimizer;
   std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::size_t epochNumber{0};
+  int batchSize{100};
+  int nbExamples{0};
 
   public :
 
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig);
+  float epoch();
 };
 
 #endif
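
The long DataLoader alias above simply writes out the unique_ptr type returned by torch::data::make_data_loader once the dataset has been wrapped in the Stack<> transform, so the loader can be stored as a member and built later in createDataset(). ConfigDataset itself is not part of this diff; as a point of reference, a map-style dataset with the shape this commit relies on (one pre-extracted context tensor and one class tensor per example) could look roughly like the sketch below. This is an assumption about ConfigDataset, not its actual source.

// Hypothetical ExampleConfigDataset: a map-style dataset over pre-extracted examples,
// compatible with .map(torch::data::transforms::Stack<>()) and make_data_loader.
#include <torch/torch.h>
#include <vector>

class ExampleConfigDataset : public torch::data::datasets::Dataset<ExampleConfigDataset>
{
  private :

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  public :

  ExampleConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes)
    : contexts(std::move(contexts)), classes(std::move(classes)) {}

  // One (input, target) pair per gold transition extracted from the configuration.
  torch::data::Example<> get(size_t index) override
  {
    return {contexts[index], classes[index]};
  }

  torch::optional<size_t> size() const override
  {
    return contexts.size();
  }
};

A dataset of this shape fits the call made in createDataset(): wrapping it with .map(torch::data::transforms::Stack<>()) and passing it to make_data_loader yields batches whose .data and .target members are stacked tensors, as consumed in Trainer::epoch() below.
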
@@ -42,9 +42,52 @@ void Trainer::createDataset(SubConfig & config)
       config.update();
   }
 
-  dataset.reset(new ConfigDataset(contexts, classes));
+  nbExamples = classes.size();
+
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
   sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
+
+float Trainer::epoch()
+{
+  constexpr int printInterval = 2000;
+  float totalLoss = 0.0;
+  float lossSoFar = 0.0;
+  int nbExamplesUntilPrint = printInterval;
+  int currentBatchNumber = 0;
+
+  for (auto & batch : *dataLoader)
+  {
+    denseOptimizer->zero_grad();
+    sparseOptimizer->zero_grad();
+
+    auto data = batch.data;
+    auto labels = batch.target.squeeze();
+
+    auto prediction = machine.getClassifier()->getNN()(data);
+
+    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    totalLoss += loss.item<float>();
+    lossSoFar += loss.item<float>();
+
+    loss.backward();
+    denseOptimizer->step();
+    sparseOptimizer->step();
+
+    nbExamplesUntilPrint -= labels.size(0);
+    ++currentBatchNumber;
+
+    if (nbExamplesUntilPrint <= 0)
+    {
+      nbExamplesUntilPrint = printInterval;
+      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+      std::fflush(stdout);
+      lossSoFar = 0;
+    }
+  }
+
+  return totalLoss;
+}
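
One remark on the loss in Trainer::epoch(): torch::nll_loss(torch::log(prediction), labels) assumes the classifier already outputs softmax probabilities. The same value can be computed from raw logits with log_softmax, which avoids taking the log of probabilities that softmax may round to zero. Whether the network here could expose logits instead is an open question; the snippet below is not project code and only illustrates the equivalence.

#include <torch/torch.h>
#include <iostream>

int main()
{
  auto logits = torch::randn({4, 7});                  // 4 examples, 7 transitions (made-up sizes)
  auto labels = torch::randint(7, {4}, torch::kLong);

  // What the code above computes, given softmax outputs from the network:
  auto lossViaLog = torch::nll_loss(torch::log(torch::softmax(logits, 1)), labels);

  // Equivalent formulation on raw logits, numerically safer:
  auto lossViaLogSoftmax = torch::nll_loss(torch::log_softmax(logits, 1), labels);

  std::cout << lossViaLog.item<float>() << " " << lossViaLogSoftmax.item<float>() << "\n";
  return 0;
}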