Commit 97b8d2fc authored by Franck Dary

Removed distinction between dense and sparse parameters because it was hurting performance and the advantage in speed was not significant
parent 5b150102
@@ -11,15 +11,10 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
   public :
   ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 #endif
@@ -12,6 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   int leftBorder{5};
   int rightBorder{5};
   int nbStackElements{2};
+  std::vector<std::string> columns{"FORM", "UPOS"};
   protected :
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
-  virtual std::vector<torch::Tensor> & denseParameters() = 0;
-  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
@@ -11,15 +11,10 @@ class OneWordNetworkImpl : public NeuralNetworkImpl
   torch::nn::Linear linear{nullptr};
   int focusedIndex;
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
   public :
   OneWordNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<torch::Tensor> & denseParameters() override;
-  std::vector<torch::Tensor> & sparseParameters() override;
 };
 #endif
@@ -7,25 +7,9 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
-  auto params = wordEmbeddings->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
-  params = linear1->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
-  params = linear2->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
 }
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
-}
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
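The constructor deletions above rely on a libtorch guarantee: every submodule registered through register_module() is tracked by the base torch::nn::Module, and Module::parameters() recursively collects the tensors of all registered submodules. The hand-maintained _denseParameters/_sparseParameters vectors therefore duplicated information the module already had. A minimal self-contained sketch of that behavior (made-up sizes, not the project's real classes):

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Same registration pattern as ConcatWordsNetworkImpl, reduced to a demo.
  struct DemoImpl : torch::nn::Module
  {
    torch::nn::Embedding emb{nullptr};
    torch::nn::Linear l1{nullptr}, l2{nullptr};
    DemoImpl()
    {
      emb = register_module("emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(100, 8)));
      l1 = register_module("l1", torch::nn::Linear(8, 16));
      l2 = register_module("l2", torch::nn::Linear(16, 4));
    }
  };

  auto demo = std::make_shared<DemoImpl>();
  // parameters() walks the registered submodules and returns emb.weight,
  // l1.weight, l1.bias, l2.weight, l2.bias -- no manual bookkeeping needed.
  for (auto & p : demo->parameters())
    std::cout << p.sizes() << '\n';
}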
@@ -3,13 +3,14 @@
 std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
     if (config.isToken(index))
-      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
   std::vector<long> context;
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
+  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
   while (!leftContext.empty())
   {
@@ -17,16 +18,18 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
     leftContext.pop();
   }
-  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
     if (config.isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+      for (auto & column : columns)
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
-  while ((int)context.size() < leftBorder+rightBorder+1)
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
   for (int i = 0; i < nbStackElements; i++)
+    for (auto & column : columns)
       if (config.hasStack(i))
-        context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i))));
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
       else
         context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
@@ -35,7 +38,7 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 int NeuralNetworkImpl::getContextSize() const
 {
-  return 1 + leftBorder + rightBorder + nbStackElements;
+  return columns.size()*(1 + leftBorder + rightBorder + nbStackElements);
 }
 void NeuralNetworkImpl::setRightBorder(int rightBorder)
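With the new columns member, every context position contributes one dictionary index per column, which is why every size check above is scaled by columns.size(). Concretely, with the defaults (columns = {"FORM", "UPOS"}, leftBorder = 5, rightBorder = 5, nbStackElements = 2), getContextSize() goes from 13 to 26. A standalone sketch of the arithmetic, using only the default values visible in the diff:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Defaults copied from NeuralNetworkImpl.
  std::vector<std::string> columns{"FORM", "UPOS"};
  int leftBorder = 5, rightBorder = 5, nbStackElements = 2;

  // One slot for the current word, plus the left/right windows and the stack
  // elements, each slot yielding one index per column.
  std::size_t contextSize = columns.size() * (1 + leftBorder + rightBorder + nbStackElements);
  std::cout << contextSize << '\n';   // prints 26 (was 13 before this commit)
}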
@@ -5,12 +5,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   constexpr int embeddingsSize = 30;
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
-  auto params = wordEmbeddings->parameters();
-  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
-  params = linear->parameters();
-  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   int leftBorder = 0;
   int rightBorder = 0;
@@ -26,16 +21,6 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   setNbStackElements(0);
 }
-std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
-{
-  return _denseParameters;
-}
-std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
-{
-  return _sparseParameters;
-}
 torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
@@ -16,8 +16,7 @@ class Trainer
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> denseOptimizer;
-  std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
+  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{100};
   int nbExamples{0};
@@ -58,8 +58,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(1e-2)));
 }
 float Trainer::epoch(bool printAdvancement)
@@ -74,8 +73,7 @@ float Trainer::epoch(bool printAdvancement)
   for (auto & batch : *dataLoader)
   {
-    denseOptimizer->zero_grad();
-    sparseOptimizer->zero_grad();
+    optimizer->zero_grad();
     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -90,8 +88,7 @@ float Trainer::epoch(bool printAdvancement)
   } catch(std::exception & e) {util::myThrow(e.what());}
   loss.backward();
-  denseOptimizer->step();
-  sparseOptimizer->step();
+  optimizer->step();
   if (printAdvancement)
   {
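After the change, a single optimizer drives the whole update cycle in Trainer::epoch(). A minimal sketch of the resulting loop shape (stand-in model and data; only the optimizer calls mirror the real code):

#include <torch/torch.h>

int main()
{
  // Stand-in for the classifier network; the real code obtains its tensors
  // via machine.getClassifier()->getNN()->parameters().
  torch::nn::Linear model(10, 3);
  torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-2));

  auto data = torch::randn({100, 10});
  auto labels = torch::randint(0, 3, {100}, torch::kLong);

  // One zero_grad()/step() pair per batch now covers every parameter,
  // where the old code had to call both denseOptimizer and sparseOptimizer.
  optimizer.zero_grad();
  auto prediction = model(data);
  auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), labels);
  loss.backward();
  optimizer.step();
}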