#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "ReadingMachine.hpp"
int main(int argc, char * argv[])
{
  if (argc < 4) // (assumed guard) expects: machineFile mcdFile tsvFile
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }
  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  // Raw-text input is disabled: the gold configuration is built from the
  // annotated TSV only, so rawFile stays empty.
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);
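
  // Dataset generation: replay the oracle. At each configuration, record the
  // extracted feature context as input and the index of the best appliable
  // (gold) transition as the target class, apply that transition, then follow
  // the strategy's movement until it signals the end.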
  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No appliable transition!");

    // Extract the feature context (a 5/5 window around the current word index).
    auto context = config.extractContext(5, 5, dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    // The class to predict is the index of the gold transition.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    // Ask the strategy where to go next; stop when it signals the end.
    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index!");

    if (config.needsUpdate())
      config.update();
  }
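
  // Wrap the collected tensors in a dataset; the Stack<> transform collates
  // individual examples into batched tensors when the data loader draws from it.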
  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

  int batchSize = 50; // (assumption) batchSize is used below but was not defined in this excerpt
  // make_data_loader defaults to a RandomSampler, so batches are shuffled.
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
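
  // Two optimizers because embedding layers emit sparse gradients, which
  // plain Adam cannot consume: SparseAdam handles the sparse parameters
  // while Adam updates the dense layers.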
  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));

  for (int epoch = 1; epoch <= 30; ++epoch)
  {
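    // One pass over the dataset: accumulate the epoch loss and print a
    // running loss roughly every 1000 examples.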
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0]; // sample prediction, unused in this excerpt

      // prediction is expected to hold probabilities: nll_loss needs
      // log-probabilities, hence the explicit log.
      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();

      loss.backward();
      denseOptimizer.step();
      sparseOptimizer.step();

      if ((++currentBatchNumber * batchSize) % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }

  return 0;
}