#include "Trainer.hpp"
#include "SubConfig.hpp"
Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}
void Trainer::createDataset(SubConfig & config)
{
config.setState(machine.getStrategy().getInitialState());
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
while (true)
{
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
util::myThrow("No transition appliable !");
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, at::kLong);
gold[0] = goldIndex;
classes.emplace_back(gold);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
if (config.needsUpdate())
config.update();
}
nbExamples = classes.size();
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}
float Trainer::epoch()
{
  constexpr int printInterval = 2000;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int nbExamplesUntilPrint = printInterval;
  int currentBatchNumber = 0;

  for (auto & batch : *dataLoader)
  {
    denseOptimizer->zero_grad();
    sparseOptimizer->zero_grad();

    auto data = batch.data;
    // Squeeze only dim 1 so a final batch of size 1 keeps its batch dimension.
    auto labels = batch.target.squeeze(1);

    auto prediction = machine.getClassifier()->getNN()(data);

    // The network outputs probabilities, so take the log before nll_loss
    // (together equivalent to cross-entropy).
    auto loss = torch::nll_loss(torch::log(prediction), labels);
    totalLoss += loss.item<float>();
    lossSoFar += loss.item<float>();

    loss.backward();
    denseOptimizer->step();
    sparseOptimizer->step();

    nbExamplesUntilPrint -= labels.size(0);
    ++currentBatchNumber;

    // Periodically overwrite the progress line on stdout.
    if (nbExamplesUntilPrint <= 0)
    {
      nbExamplesUntilPrint = printInterval;
      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
      std::fflush(stdout);
      lossSoFar = 0;
    }
  }

  return totalLoss;
}
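
// A minimal sketch of how this class might be driven. The constructor
// arguments of ReadingMachine and SubConfig, and `nbEpochs`, are
// hypothetical placeholders for whatever the surrounding program provides;
// only Trainer's own interface (createDataset, epoch) comes from this file:
//
//   ReadingMachine machine(/* model path, hypothetical */);
//   SubConfig config(/* input, hypothetical */);
//   Trainer trainer(machine);
//   trainer.createDataset(config);            // one oracle pass builds the examples
//   for (int e = 0; e < nbEpochs; ++e)        // then train for several epochs
//     fmt::print("epoch {} total loss = {}\n", e, trainer.epoch());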