diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index f056ce2a28d64fa4ec9a95cd60a47893b3a52904..a31fe21a3164818c792b142b563d5ac3b7ae2967 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,69 +8,6 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-//constexpr int batchSize = 50;
-//constexpr int nbExamples = 350000;
-//constexpr int embeddingSize = 20;
-//constexpr int nbClasses = 15;
-//constexpr int nbWordsPerDatapoint = 5;
-//constexpr int maxNbEmbeddings = 1000000;
-//
-//struct NetworkImpl : torch::nn::Module
-//{
-//  torch::nn::Linear linear{nullptr};
-//  torch::nn::Embedding wordEmbeddings{nullptr};
-//
-//  std::vector<torch::Tensor> _sparseParameters;
-//  std::vector<torch::Tensor> _denseParameters;
-//  NetworkImpl()
-//  {
-//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-//    auto params = linear->parameters();
-//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-//
-//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
-//    params = wordEmbeddings->parameters();
-//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
-//  };
-//  const std::vector<torch::Tensor> & denseParameters()
-//  {
-//    return _denseParameters;
-//  }
-//  const std::vector<torch::Tensor> & sparseParameters()
-//  {
-//    return _sparseParameters;
-//  }
-//  torch::Tensor forward(const torch::Tensor & input)
-//  {
-//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
-//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
-//    return torch::softmax(linear(embeddingsOfInput),1);
-//  }
-//};
-//TORCH_MODULE(Network);
-
-//int main(int argc, char * argv[])
-//{
-//  auto nn = Network();
-//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
-//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
-//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
-//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
-//
-//  for (auto & batch : batches)
-//  {
-//    sparseOptimizer.zero_grad();
-//    denseOptimizer.zero_grad();
-//    auto prediction = nn(batch.first);
-//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
-//    loss.backward();
-//    sparseOptimizer.step();
-//    denseOptimizer.step();
-//  }
-//  return 0;
-//}
-
 int main(int argc, char * argv[])
 {
   if (argc != 5)
@@ -110,7 +47,7 @@ int main(int argc, char * argv[])
       auto context = config.extractContext(5,5,dict);
       contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
 
-      int goldIndex = 3;
+      int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
       auto gold = torch::zeros(1, at::kLong);
       gold[0] = goldIndex;
 
@@ -140,10 +77,10 @@ int main(int argc, char * argv[])
   auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
-  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
+  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
+  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
 
-  for (int epoch = 1; epoch <= 2; ++epoch)
+  for (int epoch = 1; epoch <= 30; ++epoch)
   {
     float totalLoss = 0.0;
     float lossSoFar = 0.0;
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index 898ce837988ae39bac99e557599044225dd954d1..2b17c7a22358fb87256a0c65e572fd1aa13a0d21 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -18,6 +18,7 @@ class TransitionSet
   TransitionSet(const std::string & filename);
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
+  std::size_t getTransitionIndex(const Transition * transition) const;
   std::size_t size() const;
 };
 
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index 55b67a0425ea6245063356d8b2a9295b115a7c61..92bd4d0444d1a9f9a5d332fd2048530732058ee2 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -72,3 +72,11 @@ std::size_t TransitionSet::size() const
   return transitions.size();
 }
 
+std::size_t TransitionSet::getTransitionIndex(const Transition * transition) const
+{
+  if (!transition)
+    util::myThrow("transition is null");
+
+  return transition - &transitions[0];
+}
+
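
Note on the new TransitionSet::getTransitionIndex: it recovers the index of a transition by pointer arithmetic, which is only well defined when the argument points into the contiguous storage that backs the transitions container (as the `transition - &transitions[0]` expression in the patch implies). Below is a minimal standalone sketch of that idea, assuming the transitions are held in a std::vector<Transition>; the names indexOf and storage are illustrative only and not part of the patch.

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    struct Transition {};

    // Recover the index of an element from a pointer into a contiguous vector.
    // Precondition: element is non-null and points into storage's buffer.
    std::size_t indexOf(const std::vector<Transition> & storage, const Transition * element)
    {
      if (!element)
        throw std::invalid_argument("element is null");
      return static_cast<std::size_t>(element - storage.data());
    }

    int main()
    {
      std::vector<Transition> storage(4);
      return indexOf(storage, &storage[2]) == 2 ? 0 : 1;  // exits 0 if the recovered index is 2
    }

The precondition matters: the pointer passed to getTransitionIndex must come from the same TransitionSet (e.g. presumably one returned by getBestAppliableTransition or getAppliableTransitionsCosts); a pointer to a copied Transition would yield a meaningless index.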