Commit ce420cb5 authored by Franck Dary

Added function to get index of a transition

parent d3ecc26c
@@ -8,69 +8,6 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-//constexpr int batchSize = 50;
-//constexpr int nbExamples = 350000;
-//constexpr int embeddingSize = 20;
-//constexpr int nbClasses = 15;
-//constexpr int nbWordsPerDatapoint = 5;
-//constexpr int maxNbEmbeddings = 1000000;
-//
-//struct NetworkImpl : torch::nn::Module
-//{
-//  torch::nn::Linear linear{nullptr};
-//  torch::nn::Embedding wordEmbeddings{nullptr};
-//
-//  std::vector<torch::Tensor> _sparseParameters;
-//  std::vector<torch::Tensor> _denseParameters;
-//  NetworkImpl()
-//  {
-//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-//    auto params = linear->parameters();
-//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-//
-//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
-//    params = wordEmbeddings->parameters();
-//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
-//  };
-//  const std::vector<torch::Tensor> & denseParameters()
-//  {
-//    return _denseParameters;
-//  }
-//  const std::vector<torch::Tensor> & sparseParameters()
-//  {
-//    return _sparseParameters;
-//  }
-//  torch::Tensor forward(const torch::Tensor & input)
-//  {
-//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
-//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
-//    return torch::softmax(linear(embeddingsOfInput),1);
-//  }
-//};
-//TORCH_MODULE(Network);
-//int main(int argc, char * argv[])
-//{
-//  auto nn = Network();
-//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
-//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
-//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
-//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
-//
-//  for (auto & batch : batches)
-//  {
-//    sparseOptimizer.zero_grad();
-//    denseOptimizer.zero_grad();
-//    auto prediction = nn(batch.first);
-//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
-//    loss.backward();
-//    sparseOptimizer.step();
-//    denseOptimizer.step();
-//  }
-//  return 0;
-//}
 
 int main(int argc, char * argv[])
 {
   if (argc != 5)
@@ -110,7 +47,7 @@ int main(int argc, char * argv[])
     auto context = config.extractContext(5,5,dict);
     contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
-    int goldIndex = 3;
+    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
 
     auto gold = torch::zeros(1, at::kLong);
     gold[0] = goldIndex;
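
This change replaces a hardcoded placeholder (every training example used to get gold class 3) with the index of the actual oracle transition, so the targets fed to the loss now reflect the data. A minimal sketch of the resulting label construction, reusing names from the diff (the final loss call is assumed to mirror the training loop later in this file):

    // One gold label per datapoint: nll_loss expects int64 class indices.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong); // 1-element int64 tensor
    gold[0] = goldIndex;
    // later: auto loss = torch::nll_loss(torch::log(prediction), gold);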
@@ -140,10 +77,10 @@ int main(int argc, char * argv[])
   auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
+  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
+  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
 
-  for (int epoch = 1; epoch <= 2; ++epoch)
+  for (int epoch = 1; epoch <= 30; ++epoch)
   {
     float totalLoss = 0.0;
     float lossSoFar = 0.0;
...
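
Besides lowering both learning rates from 2e-1 to 2e-3 and extending training from 2 to 30 epochs, this file keeps two optimizers because the embedding layer is built with .sparse(true): its gradients come back as sparse tensors, which SparseAdam accepts, while the dense linear parameters go through plain Adam. A condensed sketch of one training step over the split parameter groups, lifted from the commented-out example deleted above:

    sparseOptimizer.zero_grad();
    denseOptimizer.zero_grad();
    auto prediction = nn(batch.first);   // forward pass
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();                     // fills sparse and dense grads
    sparseOptimizer.step();              // updates embedding weights
    denseOptimizer.step();               // updates linear weights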
@@ -18,6 +18,7 @@ class TransitionSet
   TransitionSet(const std::string & filename);
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
+  std::size_t getTransitionIndex(const Transition * transition) const;
   std::size_t size() const;
 };
...
@@ -72,3 +72,11 @@ std::size_t TransitionSet::size() const
   return transitions.size();
 }
+
+std::size_t TransitionSet::getTransitionIndex(const Transition * transition) const
+{
+  if (!transition)
+    util::myThrow("transition is null");
+
+  return transition - &transitions[0];
+}
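
The new function recovers the index by pointer arithmetic: assuming transitions is a contiguous container of Transition objects (e.g. a std::vector<Transition>), subtracting the address of the first element from a pointer into the container yields that element's offset. A self-contained sketch of the idiom with illustrative names (Transition's layout and the container type are assumptions, not taken from the repo):

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    struct Transition { int id; };

    // Valid only while t points into this exact vector; a null or
    // foreign pointer makes the subtraction undefined behavior.
    std::size_t indexOf(const std::vector<Transition> & transitions, const Transition * t)
    {
      if (!t)
        throw std::invalid_argument("transition is null");
      return t - transitions.data(); // element offset, not bytes
    }

    int main()
    {
      std::vector<Transition> ts{{1}, {2}, {3}};
      std::cout << indexOf(ts, &ts[2]) << '\n'; // prints 2
    }

One caveat: the result is only meaningful for pointers handed out by the same TransitionSet, and any reallocation of the underlying container invalidates them.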