Commit 8a10a847 authored by Franck Dary

Working training

parent a38db411
CMakeLists.txt
@@ -4,3 +4,4 @@ add_executable(dev src/dev.cpp)
 target_link_libraries(dev common)
 target_link_libraries(dev reading_machine)
 target_link_libraries(dev torch_modules)
+target_link_libraries(dev trainer)
dev.cpp
@@ -5,8 +5,7 @@
 #include "SubConfig.hpp"
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
-#include "TestNetwork.hpp"
-#include "ConfigDataset.hpp"
+#include "Trainer.hpp"
 
 int main(int argc, char * argv[])
 {
@@ -16,8 +15,6 @@ int main(int argc, char * argv[])
     exit(1);
   }
 
-  at::init_num_threads();
-
   std::string machineFile = argv[1];
   std::string mcdFile = argv[2];
   std::string tsvFile = argv[3];
@@ -29,91 +26,13 @@ int main(int argc, char * argv[])
   BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
   SubConfig config(goldConfig);
   config.setState(machine.getStrategy().getInitialState());
 
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
-
-  fmt::print("Generating dataset...\n");
-
-  Dict dict(Dict::State::Open);
-
-  while (true)
-  {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
-
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, at::kLong);
-    gold[0] = goldIndex;
-    classes.emplace_back(gold);
-
-    transition->apply(config);
-    config.addToHistory(transition->getName());
-
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
-
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+  Trainer trainer(machine);
+  trainer.createDataset(config);
 
-    if (config.needsUpdate())
-      config.update();
-  }
-
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
-
-  int batchSize = 1000;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 30; ++epoch)
+  for (int i = 0; i < 5; i++)
   {
-    float totalLoss = 0.0;
-    float lossSoFar = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      denseOptimizer.zero_grad();
-      sparseOptimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      lossSoFar += loss.item<float>();
-
-      loss.backward();
-      denseOptimizer.step();
-      sparseOptimizer.step();
-
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
-        std::fflush(stdout);
-        lossSoFar = 0;
-      }
-    }
-
-    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
+    float loss = trainer.epoch();
+    fmt::print("\nEpoch {} loss = {}\n", i+1, loss);
   }
 
   return 0;
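
Read together, the two dev.cpp hunks above reduce the driver to a short loop over the new Trainer API. For readability, here is the resulting main() body reassembled from the + lines (argument checking and the construction of machine sit in the collapsed context between the hunks):

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());

  // Dataset generation, the data loader and the optimizers now live inside Trainer.
  Trainer trainer(machine);
  trainer.createDataset(config);

  // Five epochs; epoch() runs one full pass over the data and returns the summed loss.
  for (int i = 0; i < 5; i++)
  {
    float loss = trainer.epoch();
    fmt::print("\nEpoch {} loss = {}\n", i+1, loss);
  }

  return 0;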
Trainer.hpp
@@ -10,15 +10,25 @@ class Trainer
 {
+  private :
+
+  using Dataset = ConfigDataset;
+  using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
+
   private :
 
   ReadingMachine & machine;
   std::unique_ptr<ConfigDataset> dataset{nullptr};
+  DataLoader dataLoader{nullptr};
+  std::unique_ptr<torch::optim::Adam> denseOptimizer;
+  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::size_t epochNumber{0};
+  int batchSize{100};
+  int nbExamples{0};
 
   public :
 
   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig);
+  float epoch();
 };
 
 #endif
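
The new DataLoader alias spells out the full return type of torch::data::make_data_loader, including a std::default_delete argument that std::unique_ptr already defaults to. A shorter equivalent lets the compiler deduce that type; this is a hypothetical simplification, not part of the commit:

  // Hypothetical: name make_data_loader's return type via decltype instead of
  // writing the StatelessDataLoader instantiation out by hand.
  // (Requires <utility> for std::declval.)
  using StackedDataset = torch::data::datasets::MapDataset<
      ConfigDataset, torch::data::transforms::Stack<torch::data::Example<>>>;
  using DataLoader = decltype(torch::data::make_data_loader(
      std::declval<StackedDataset>(), torch::data::DataLoaderOptions{}));

Either spelling names the same std::unique_ptr<torch::data::StatelessDataLoader<..., RandomSampler>> type, since RandomSampler is make_data_loader's default sampler.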
Trainer.cpp
@@ -42,9 +42,52 @@ void Trainer::createDataset(SubConfig & config)
     config.update();
   }
 
   dataset.reset(new ConfigDataset(contexts, classes));
+  nbExamples = classes.size();
+
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
 }
+
+float Trainer::epoch()
+{
+  constexpr int printInterval = 2000;
+  float totalLoss = 0.0;
+  float lossSoFar = 0.0;
+  int nbExamplesUntilPrint = printInterval;
+  int currentBatchNumber = 0;
+
+  for (auto & batch : *dataLoader)
+  {
+    denseOptimizer->zero_grad();
+    sparseOptimizer->zero_grad();
+
+    auto data = batch.data;
+    auto labels = batch.target.squeeze();
+
+    auto prediction = machine.getClassifier()->getNN()(data);
+    auto loss = torch::nll_loss(torch::log(prediction), labels);
+    totalLoss += loss.item<float>();
+    lossSoFar += loss.item<float>();
+
+    loss.backward();
+    denseOptimizer->step();
+    sparseOptimizer->step();
+
+    nbExamplesUntilPrint -= labels.size(0);
+    ++currentBatchNumber;
+    if (nbExamplesUntilPrint <= 0)
+    {
+      nbExamplesUntilPrint = printInterval;
+      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+      std::fflush(stdout);
+      lossSoFar = 0;
+    }
+  }
+
+  return totalLoss;
+}
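
One caveat in Trainer::epoch(): torch::nll_loss(torch::log(prediction), labels) takes an explicit log of softmax probabilities, and log(0) yields -inf once a probability underflows. If the classifier can expose its pre-softmax scores, the numerically safer spelling is log_softmax followed by nll_loss; a sketch, where logits stands for those raw scores (not what getNN() returns in this commit):

  // Stable alternative: fuse the log into the softmax instead of applying it after.
  auto logProbs = torch::log_softmax(logits, /*dim=*/1);
  auto loss = torch::nll_loss(logProbs, labels);

Also worth noting: the data loader's RandomSampler is reset each time a new iteration begins, so successive epoch() calls each see a freshly shuffled pass over the dataset.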