#include "BaseConfig.hpp"
#include "SubConfig.hpp"
// Headers for TestNetwork, ConfigDataset and Dict (header names inferred from their use in main() below).
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
#include "Dict.hpp"
#include "ReadingMachine.hpp"
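// The block below is a commented-out standalone example: it trains a toy bag-of-words classifier,
// using torch::optim::SparseAdam for the sparse embedding table and torch::optim::Adam for the
// dense linear layer.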
//constexpr int batchSize = 50;
//constexpr int nbExamples = 350000;
//constexpr int embeddingSize = 20;
//constexpr int nbClasses = 15;
//constexpr int nbWordsPerDatapoint = 5;
//constexpr int maxNbEmbeddings = 1000000;
//
//struct NetworkImpl : torch::nn::Module
//{
//  torch::nn::Linear linear{nullptr};
//  torch::nn::Embedding wordEmbeddings{nullptr};
//
//  std::vector<torch::Tensor> _sparseParameters;
//  std::vector<torch::Tensor> _denseParameters;
//  NetworkImpl()
//  {
//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
//    auto params = linear->parameters();
//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
//
//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
//    params = wordEmbeddings->parameters();
//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
//  };
//  const std::vector<torch::Tensor> & denseParameters()
//  {
//    return _denseParameters;
//  }
//  const std::vector<torch::Tensor> & sparseParameters()
//  {
//    return _sparseParameters;
//  }
//  torch::Tensor forward(const torch::Tensor & input)
//  {
//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
//    return torch::softmax(linear(embeddingsOfInput),1);
//  }
//};
//TORCH_MODULE(Network);
//int main(int argc, char * argv[])
//{
//  auto nn = Network();
//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
//
//  for (auto & batch : batches)
//  {
//    sparseOptimizer.zero_grad();
//    denseOptimizer.zero_grad();
//    auto prediction = nn(batch.first);
//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
//    loss.backward();
//    sparseOptimizer.step();
//    denseOptimizer.step();
//  }
//  return 0;
//}
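
// Below: build a training set by running the machine's transition oracle over a gold configuration,
// then train a TestNetwork classifier on the extracted context features.
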
int main(int argc, char * argv[])
{
  // Command-line arguments: reading machine file, MCD file, TSV corpus file (a raw text file argument is currently unused).
  if (argc < 4)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);

  while (true)
  {
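    // For every configuration, apply the best applicable transition and record a (context, label) pair;
    // the label is currently a fixed placeholder (goldIndex = 3). The loop stops when the strategy
    // returns its end movement.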
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    int goldIndex = 3;
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

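  // Stack the recorded examples into a dataset and iterate over it in mini-batches.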
  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));

  for (int epoch = 1; epoch <= 2; ++epoch)
  {
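    // One pass over the data: compute the NLL loss on the log of the network output and update
    // both the dense and the sparse parameters.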
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();

      loss.backward();

      denseOptimizer.step();
      sparseOptimizer.step();

      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }

  return 0;
}