Commit d3ecc26c authored by Franck Dary

Working version with SparseAdam

parent 92e9fda7
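The point of this commit is to train the sparse embedding table and the dense layers with two separate optimizers: parameters of the torch::nn::Embedding created with .sparse(true) go to torch::optim::SparseAdam, everything else goes to a regular torch::optim::Adam. Below is a minimal self-contained sketch of that pattern; it assumes the same libtorch options API (.beta1(...)) used in this commit, and the module name ToyNet, the sizes and the learning rates are illustrative, not the project's actual values.

#include <torch/torch.h>
#include <vector>

// Toy module that keeps sparse (embedding) and dense (linear) parameters in separate lists.
struct ToyNetImpl : torch::nn::Module
{
  torch::nn::Embedding embeddings{nullptr};
  torch::nn::Linear linear{nullptr};
  std::vector<torch::Tensor> sparseParams, denseParams;

  ToyNetImpl(int64_t vocabSize, int64_t embeddingSize, int64_t nbClasses)
  {
    // sparse(true) makes backward() produce sparse gradients for the embedding weight
    embeddings = register_module("embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabSize, embeddingSize).sparse(true)));
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));

    auto e = embeddings->parameters();
    sparseParams.insert(sparseParams.end(), e.begin(), e.end());
    auto l = linear->parameters();
    denseParams.insert(denseParams.end(), l.begin(), l.end());
  }

  torch::Tensor forward(torch::Tensor input)
  {
    // mean over the word dimension, then a linear classifier
    return torch::softmax(linear(embeddings(input).mean(1)), 1);
  }
};
TORCH_MODULE(ToyNet);

int main()
{
  ToyNet nn(1000, 20, 5);

  // SparseAdam only handles parameters whose gradients are sparse;
  // the dense parameters get their own Adam instance.
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParams, torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam denseOptimizer(nn->denseParams, torch::optim::AdamOptions(2e-4).beta1(0.5));

  auto input = torch::randint(1000, {8, 5}, at::kLong); // batch of 8 datapoints, 5 word ids each
  auto target = torch::randint(5, {8}, at::kLong);

  sparseOptimizer.zero_grad();
  denseOptimizer.zero_grad();
  auto loss = torch::nll_loss(torch::log(nn(input)), target);
  loss.backward();
  sparseOptimizer.step();
  denseOptimizer.step();

  return 0;
}

This mirrors the updated training loop in the diff below: both optimizers are zeroed before a single backward() and stepped afterwards, so the sparse embedding gradients never reach the dense Adam instance.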
@@ -8,155 +8,177 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;
//3m15s
struct NetworkImpl : torch::nn::Module
//constexpr int batchSize = 50;
//constexpr int nbExamples = 350000;
//constexpr int embeddingSize = 20;
//constexpr int nbClasses = 15;
//constexpr int nbWordsPerDatapoint = 5;
//constexpr int maxNbEmbeddings = 1000000;
//
//struct NetworkImpl : torch::nn::Module
//{
// torch::nn::Linear linear{nullptr};
// torch::nn::Embedding wordEmbeddings{nullptr};
//
// std::vector<torch::Tensor> _sparseParameters;
// std::vector<torch::Tensor> _denseParameters;
// NetworkImpl()
// {
// linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
// auto params = linear->parameters();
// _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
//
// wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
// params = wordEmbeddings->parameters();
// _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
// };
// const std::vector<torch::Tensor> & denseParameters()
// {
// return _denseParameters;
// }
// const std::vector<torch::Tensor> & sparseParameters()
// {
// return _sparseParameters;
// }
// torch::Tensor forward(const torch::Tensor & input)
// {
// // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
// auto embeddingsOfInput = wordEmbeddings(input).mean(1);
// return torch::softmax(linear(embeddingsOfInput),1);
// }
//};
//TORCH_MODULE(Network);
//int main(int argc, char * argv[])
//{
// auto nn = Network();
// torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
// torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
// std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
// for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
// batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
//
// for (auto & batch : batches)
// {
// sparseOptimizer.zero_grad();
// denseOptimizer.zero_grad();
// auto prediction = nn(batch.first);
// auto loss = torch::nll_loss(torch::log(prediction), batch.second);
// loss.backward();
// sparseOptimizer.step();
// denseOptimizer.step();
// }
// return 0;
//}
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  };
  torch::Tensor forward(const torch::Tensor & input)
  {
    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput),1);
  }
};
TORCH_MODULE(Network);
int main(int argc, char * argv[])
{
  auto nn = Network();
  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
  for (auto & batch : batches)
  {
    sparseOptimizer.zero_grad();
    denseOptimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
    sparseOptimizer.step();
    denseOptimizer.step();
  }
  return 0;
}

int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }
  at::init_num_threads();
  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";
  ReadingMachine machine(machineFile);
  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;
  fmt::print("Generating dataset...\n");
  Dict dict(Dict::State::Open);
  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");
    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
    int goldIndex = 3;
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;
    classes.emplace_back(gold);
    transition->apply(config);
    config.addToHistory(transition->getName());
    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;
    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");
    if (config.needsUpdate())
      config.update();
  }
  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);
  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
  for (int epoch = 1; epoch <= 2; ++epoch)
  {
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;
    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();
      auto data = batch.data;
      auto labels = batch.target.squeeze();
      auto prediction = nn(data);
      example = prediction[0];
      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
      loss.backward();
      denseOptimizer.step();
      sparseOptimizer.step();
      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }
    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }
  return 0;
}
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//
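The new main converts every parser configuration into a context tensor plus a gold-class tensor, wraps the two vectors in a ConfigDataset, and batches them through torch::data. ConfigDataset itself is not shown in this commit; the sketch below is only an assumption of what a minimal compatible dataset could look like (the name ExampleDataset and the toy shapes are made up for illustration), since any dataset yielding torch::data::Example<> pairs can be mapped through Stack<>() and handed to make_data_loader the way main.cpp does.

#include <torch/torch.h>
#include <vector>

// Hypothetical stand-in for ConfigDataset: serves pre-extracted (context, class) tensor pairs.
class ExampleDataset : public torch::data::datasets::Dataset<ExampleDataset>
{
  std::vector<torch::Tensor> contexts_, classes_;

  public :
  ExampleDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes)
    : contexts_(std::move(contexts)), classes_(std::move(classes)) {}

  torch::data::Example<> get(size_t index) override
  {
    return {contexts_[index], classes_[index]};
  }

  torch::optional<size_t> size() const override
  {
    return contexts_.size();
  }
};

int main()
{
  std::vector<torch::Tensor> contexts, classes;
  for (int i = 0; i < 1000; ++i)
  {
    contexts.emplace_back(torch::randint(100, {11}, at::kLong)); // fake context of 11 word ids
    classes.emplace_back(torch::randint(15, {1}, at::kLong));    // fake gold class
  }

  // Stack<> collates the individual examples into batched data / target tensors.
  auto dataset = ExampleDataset(contexts, classes).map(torch::data::transforms::Stack<>());
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(100).workers(0));

  for (auto & batch : *dataLoader)
  {
    auto data = batch.data;               // {batchSize, 11}
    auto labels = batch.target.squeeze(); // {batchSize}
  }

  return 0;
}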
@@ -12,10 +12,15 @@ class TestNetworkImpl : public torch::nn::Module
torch::nn::Linear linear{nullptr};
int focusedIndex;
std::vector<torch::Tensor> _denseParameters;
std::vector<torch::Tensor> _sparseParameters;
public :
TestNetworkImpl(int nbOutputs, int focusedIndex);
torch::Tensor forward(torch::Tensor input);
std::vector<torch::Tensor> & denseParameters();
std::vector<torch::Tensor> & sparseParameters();
};
TORCH_MODULE(TestNetwork);
......
@@ -3,11 +3,28 @@
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
constexpr int embeddingsSize = 30;
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
auto params = wordEmbeddings->parameters();
_sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
params = linear->parameters();
_denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
this->focusedIndex = focusedIndex;
}
std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
{
return _denseParameters;
}
std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
{
return _sparseParameters;
}
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence, embeddings}
......