Commit 97b8d2fc authored by Franck Dary
Removed distinction between dense and sparse parameters, because it was hurting performance and the speed advantage was not significant.
parent 5b150102
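In short: the trainer previously maintained two optimizers, an Adam for the dense parameters and a SparseAdam for the sparse embedding gradients, and every network had to hand-sort its tensors into the two buckets. After this commit a single Adam covers all registered parameters. A minimal before/after sketch, lifted from the Trainer hunks below (here `nn` stands in for `machine.getClassifier()->getNN()`):

    // Before: one optimizer per parameter kind, each stepped separately.
    auto denseOptimizer  = std::make_unique<torch::optim::Adam>(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
    auto sparseOptimizer = std::make_unique<torch::optim::SparseAdam>(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));

    // After: one Adam over everything parameters() reports.
    auto optimizer = std::make_unique<torch::optim::Adam>(nn->parameters(), torch::optim::AdamOptions(1e-2));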
@@ -11,15 +11,10 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
 
   public :
 
   ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
@@ -12,6 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   int leftBorder{5};
   int rightBorder{5};
   int nbStackElements{2};
+  std::vector<std::string> columns{"FORM", "UPOS"};
 
   protected :
 
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
 
-  virtual std::vector<torch::Tensor> & denseParameters() = 0;
-  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
...
@@ -11,15 +11,10 @@ class OneWordNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear{nullptr};
   int focusedIndex;
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
 
   public :
 
   OneWordNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 
 #endif
@@ -7,25 +7,9 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
-  auto params = wordEmbeddings->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
-  params = linear1->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
-  params = linear2->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
...
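The removed `_denseParameters`/`_sparseParameters` bookkeeping was largely redundant anyway: `register_module` already attaches each submodule's tensors to the parent `torch::nn::Module`, so `parameters()` collects them all. A self-contained sketch of that mechanism (`TinyNet` is a hypothetical module, not part of this repository):

    #include <torch/torch.h>
    #include <iostream>

    struct TinyNetImpl : torch::nn::Module
    {
      torch::nn::Embedding emb{nullptr};
      torch::nn::Linear lin{nullptr};

      TinyNetImpl()
      {
        // register_module() is all the bookkeeping needed.
        emb = register_module("emb", torch::nn::Embedding(100, 8));
        lin = register_module("lin", torch::nn::Linear(8, 3));
      }
    };
    TORCH_MODULE(TinyNet);

    int main()
    {
      TinyNet net;
      // Embedding weight + Linear weight + Linear bias.
      std::cout << net->parameters().size() << "\n"; // prints 3
    }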
@@ -3,13 +3,14 @@
 std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
     if (config.isToken(index))
-      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
   std::vector<long> context;
 
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
+  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   while (!leftContext.empty())
   {
@@ -17,25 +18,27 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict
     leftContext.pop();
   }
 
-  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
     if (config.isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
 
-  while ((int)context.size() < leftBorder+rightBorder+1)
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   for (int i = 0; i < nbStackElements; i++)
-    if (config.hasStack(i))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i))));
-    else
-      context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+    for (auto & column : columns)
+      if (config.hasStack(i))
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      else
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   return context;
 }
 
 int NeuralNetworkImpl::getContextSize() const
 {
-  return 1 + leftBorder + rightBorder + nbStackElements;
+  return columns.size()*(1 + leftBorder + rightBorder + nbStackElements);
 }
 
 void NeuralNetworkImpl::setRightBorder(int rightBorder)
...
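Each context position now contributes one dictionary index per column instead of a single FORM index, which is why every size bound in `extractContext` and the result of `getContextSize` are multiplied by `columns.size()`. With the defaults declared in the NeuralNetworkImpl header above, the arithmetic works out as follows (a standalone check, not repository code):

    #include <iostream>
    #include <string>
    #include <vector>

    int main()
    {
      // Defaults from the NeuralNetworkImpl hunk: columns = {FORM, UPOS}.
      std::vector<std::string> columns{"FORM", "UPOS"};
      int leftBorder = 5, rightBorder = 5, nbStackElements = 2;
      // One index per (position, column) pair: 2 * (1 + 5 + 5 + 2) = 26,
      // where it was 13 before this commit.
      std::cout << columns.size() * (1 + leftBorder + rightBorder + nbStackElements) << "\n";
    }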
@@ -5,12 +5,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   constexpr int embeddingsSize = 30;
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
-  auto params = wordEmbeddings->parameters();
-  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
-  params = linear->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
 
   int leftBorder = 0;
   int rightBorder = 0;
@@ -26,16 +21,6 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   setNbStackElements(0);
 }
 
-std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-
-std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
-}
-
 torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
...
@@ -16,8 +16,7 @@ class Trainer
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> denseOptimizer;
-  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{100};
   int nbExamples{0};
...
@@ -58,8 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(1e-2)));
 }
 
 float Trainer::epoch(bool printAdvancement)
@@ -74,8 +73,7 @@ float Trainer::epoch(bool printAdvancement)
   for (auto & batch : *dataLoader)
   {
-    denseOptimizer->zero_grad();
-    sparseOptimizer->zero_grad();
+    optimizer->zero_grad();
 
     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -90,8 +88,7 @@ float Trainer::epoch(bool printAdvancement)
     } catch(std::exception & e) {util::myThrow(e.what());}
 
     loss.backward();
-    denseOptimizer->step();
-    sparseOptimizer->step();
+    optimizer->step();
 
     if (printAdvancement)
     {
...
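With a single optimizer, the per-batch update in `Trainer::epoch` collapses to the usual pattern. A hedged sketch of one step, assuming a negative log-likelihood classification loss (the actual loss computation sits outside this hunk):

    optimizer->zero_grad();  // clear gradients accumulated on the previous batch
    auto prediction = machine.getClassifier()->getNN()->forward(data);
    auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), labels);
    loss.backward();         // one backward pass fills every parameter's gradient
    optimizer->step();       // one step updates every parameter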