From 97b8d2fcb0b7d75a486f7d1d29b590060536b6fa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 23 Feb 2020 17:20:31 +0100
Subject: [PATCH] Removed distinction between dense and sparse parameters
 because it was hurting performance and the speed advantage was not
 significant

---
 torch_modules/include/ConcatWordsNetwork.hpp |  5 -----
 torch_modules/include/NeuralNetwork.hpp      |  3 +--
 torch_modules/include/OneWordNetwork.hpp     |  5 -----
 torch_modules/src/ConcatWordsNetwork.cpp     | 18 +-----------------
 torch_modules/src/NeuralNetwork.cpp          | 25 ++++++++++++++-----------
 torch_modules/src/OneWordNetwork.cpp         | 15 ---------------
 trainer/include/Trainer.hpp                  |  3 +--
 trainer/src/Trainer.cpp                      |  9 +++------
 8 files changed, 20 insertions(+), 63 deletions(-)

diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
index 4dd7aa3..064a00e 100644
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -11,15 +11,10 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
 
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
   public :
 
   ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 2683122..5299d2d 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -12,6 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   int leftBorder{5};
   int rightBorder{5};
   int nbStackElements{2};
+  std::vector<std::string> columns{"FORM", "UPOS"};
 
   protected :
 
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   public :
 
-  virtual std::vector<torch::Tensor> & denseParameters() = 0;
-  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp
index 29edb7d..b4ad475 100644
--- a/torch_modules/include/OneWordNetwork.hpp
+++ b/torch_modules/include/OneWordNetwork.hpp
@@ -11,15 +11,10 @@ class OneWordNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear{nullptr};
   int focusedIndex;
 
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
   public :
 
   OneWordNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index 4c1e366..9d7d0c4 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -7,25 +7,9 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
-  auto params = wordEmbeddings->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
-  params = linear1->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
-  params = linear2->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 215fda5..e39729b 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -3,13 +3,14 @@
 std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
     if (config.isToken(index))
-      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
   std::vector<long> context;
 
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
+  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
   while (!leftContext.empty())
   {
@@ -17,25 +18,27 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict
     leftContext.pop();
   }
 
-  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
     if (config.isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
-  while ((int)context.size() < leftBorder+rightBorder+1)
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   for (int i = 0; i < nbStackElements; i++)
-    if (config.hasStack(i))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i))));
-    else
-      context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+    for (auto & column : columns)
+      if (config.hasStack(i))
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      else
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   return context;
 }
 
 int NeuralNetworkImpl::getContextSize() const
 {
-  return 1 + leftBorder + rightBorder + nbStackElements;
+  return columns.size()*(1 + leftBorder + rightBorder + nbStackElements);
 }
 
 void NeuralNetworkImpl::setRightBorder(int rightBorder)
diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
index c054e6d..c2a11db 100644
--- a/torch_modules/src/OneWordNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -5,12 +5,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   constexpr int embeddingsSize = 30;
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
-  auto params = wordEmbeddings->parameters();
-  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
-
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
-  params = linear->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
 
   int leftBorder = 0;
   int rightBorder = 0;
@@ -26,16 +21,6 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   setNbStackElements(0);
 }
 
-std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
-}
-
 torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 0f9c3ec..a63c977 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,8 +16,7 @@ class Trainer
 
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> denseOptimizer;
-  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{100};
   int nbExamples{0};
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a439c88..6c2222c 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -58,8 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(1e-2)));
 }
 
 float Trainer::epoch(bool printAdvancement)
@@ -74,8 +73,7 @@ float Trainer::epoch(bool printAdvancement)
 
   for (auto & batch : *dataLoader)
   {
-    denseOptimizer->zero_grad();
-    sparseOptimizer->zero_grad();
+    optimizer->zero_grad();
 
     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -90,8 +88,7 @@ float Trainer::epoch(bool printAdvancement)
     } catch(std::exception & e) {util::myThrow(e.what());}
 
     loss.backward();
-    denseOptimizer->step();
-    sparseOptimizer->step();
+    optimizer->step();
 
     if (printAdvancement)
     {
-- 
GitLab
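
Note on the reworked context extraction (torch_modules/src/NeuralNetwork.cpp):
every token slot now contributes one dictionary index per entry of `columns`
({"FORM", "UPOS"}), so the flat context vector returned by extractContext grows
by a factor of columns.size(). The sketch below restates the patched
getContextSize() arithmetic in isolation; the function name `contextSize` and
the layout comment are illustrative, not code from this repository.

    // Each of the (1 + leftBorder + rightBorder + nbStackElements) slots holds
    // one index per column, interleaved in column order, e.g. with two columns:
    //   [FORM(w-5), UPOS(w-5), ..., FORM(w), UPOS(w), ..., FORM(w+5), UPOS(w+5),
    //    FORM(stack0), UPOS(stack0), FORM(stack1), UPOS(stack1)]
    int contextSize(int nbColumns, int leftBorder, int rightBorder, int nbStackElements)
    {
      return nbColumns * (1 + leftBorder + rightBorder + nbStackElements);
    }
    // With the defaults from NeuralNetwork.hpp (2 columns, borders of 5, and
    // 2 stack elements) this gives 2 * (1 + 5 + 5 + 2) == 26 inputs per
    // configuration, versus 13 before this patch.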
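
The trainer-side change collapses the former Adam/SparseAdam pair into a single
torch::optim::Adam built over getNN()->parameters(). Below is a minimal,
self-contained sketch of that single-optimizer training step; the toy
embedding-plus-linear model is a hypothetical stand-in for the ReadingMachine
classifier, and its embedding is left dense because PyTorch's stock Adam (at
least in the Python implementation) rejects sparse gradients.

    #include <torch/torch.h>

    int main()
    {
      // Toy stand-in for the network: an embedding table feeding a linear layer.
      torch::nn::Embedding embeddings(torch::nn::EmbeddingOptions(50000, 100));
      torch::nn::Linear linear(100, 10);

      // Gather every parameter into one list and hand it to a single Adam,
      // mirroring the patched Trainer::createDataset.
      std::vector<torch::Tensor> parameters = embeddings->parameters();
      auto linearParameters = linear->parameters();
      parameters.insert(parameters.end(), linearParameters.begin(), linearParameters.end());
      torch::optim::Adam optimizer(parameters, torch::optim::AdamOptions(1e-2));

      // One training step, as in the patched Trainer::epoch:
      // zero_grad / forward / loss / backward / step.
      auto input = torch::randint(0, 50000, {8, 1}, torch::kLong);
      auto labels = torch::randint(0, 10, {8}, torch::kLong);

      optimizer.zero_grad();
      auto output = linear(embeddings(input).squeeze(1));
      auto loss = torch::nll_loss(torch::log_softmax(output, 1), labels);
      loss.backward();
      optimizer.step();
    }

Note that the patch also changes the optimizer hyperparameters: a single
AdamOptions(1e-2) with default betas replaces the two AdamOptions(2e-3).beta1(0.5)
configurations used previously.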