Commit 096b59d9 authored by Franck Dary

Changed extract context

parent 82827388
@@ -27,9 +27,12 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
   int chosenTransition = -1;
-  for (unsigned int i = 0; i < prediction.size(0); i++)
-    if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config))
-      chosenTransition = i;
+  try
+  {
+    for (unsigned int i = 0; i < prediction.size(0); i++)
+      if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config))
+        chosenTransition = i;
+  } catch(std::exception & e) {util::myThrow(e.what());}

   if (chosenTransition == -1)
     util::myThrow("No transition appliable !");
......
@@ -30,7 +30,7 @@ void Classifier::initNeuralNetwork(const std::string & topology)
   static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
   {
     {
-      std::regex("OneWord\\((\\d+)\\)"),
+      std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
       "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
       [this,topology](auto sm)
       {
@@ -38,11 +38,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("ConcatWords"),
-      "ConcatWords : Concatenate embeddings of words in context.",
-      [this,topology](auto)
+      std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
+      "ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.",
+      [this,topology](auto sm)
       {
-        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size()));
+        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
       }
     },
   };
......
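
For illustration, here is how the updated ConcatWords initializer would parse a topology line. This is a minimal standalone sketch; the topology string "ConcatWords(5,5,2)" is a hypothetical input, and only the regex itself comes from the commit:

#include <iostream>
#include <regex>
#include <string>

int main()
{
  // Hypothetical topology line; the initializer's regex now requires
  // explicit arguments: ConcatWords(leftBorder,rightBorder,nbStack).
  std::string topology = "ConcatWords(5,5,2)";
  std::regex r("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)");
  std::smatch sm;
  if (std::regex_match(topology, sm, r))
    std::cout << "leftBorder=" << std::stoi(sm[1]) << " rightBorder="
              << std::stoi(sm[2]) << " nbStack=" << std::stoi(sm[3]) << '\n';
  // Prints: leftBorder=5 rightBorder=5 nbStack=2
}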
@@ -8,14 +8,15 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   private :

   torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
   std::vector<torch::Tensor> _denseParameters;
   std::vector<torch::Tensor> _sparseParameters;

   public :

-  ConcatWordsNetworkImpl(int nbOutputs);
+  ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<torch::Tensor> & denseParameters() override;
   std::vector<torch::Tensor> & sparseParameters() override;
......
@@ -11,6 +11,13 @@ class NeuralNetworkImpl : public torch::nn::Module
   int leftBorder{5};
   int rightBorder{5};
+  int nbStackElements{2};
+
+  protected :
+
+  void setRightBorder(int rightBorder);
+  void setLeftBorder(int leftBorder);
+  void setNbStackElements(int nbStackElements);

   public :
......
 #include "ConcatWordsNetwork.hpp"

-ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
+ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
 {
-  constexpr int embeddingsSize = 30;
+  constexpr int embeddingsSize = 100;
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
   auto params = wordEmbeddings->parameters();
-  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
-  linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
-  params = linear->parameters();
+  linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
+  params = linear1->parameters();
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+  linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
+  params = linear2->parameters();
   _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
 }
@@ -30,7 +35,7 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
-  auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0);
+  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0);
   return res;
 }
......
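
As a sanity check on the new two-layer network, the following standalone sketch reproduces the tensor shapes through forward. The context size 13 and embedding size 100 follow from the defaults in this commit (leftBorder=5, rightBorder=5, nbStackElements=2); batch and nbOutputs are arbitrary illustrative values, not taken from the repository:

#include <torch/torch.h>
#include <iostream>

int main()
{
  const int batch = 4, contextSize = 13, embeddingsSize = 100, nbOutputs = 20;
  auto emb = torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize));
  auto linear1 = torch::nn::Linear(contextSize*embeddingsSize, 500);
  auto linear2 = torch::nn::Linear(500, nbOutputs);
  // Fake batch of word indices, standing in for extractContext output.
  auto input = torch::randint(0, 50000, {batch, contextSize}, torch::kLong);
  auto wordsAsEmb = emb(input);                                             // {4, 13, 100}
  auto reshaped = torch::reshape(wordsAsEmb, {batch, contextSize*embeddingsSize}); // {4, 1300}
  auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), 1);    // {4, 20}
  std::cout << res.sizes() << '\n'; // [4, 20]
}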
@@ -24,11 +24,32 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict)
   while ((int)context.size() < leftBorder+rightBorder+1)
     context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));

+  for (int i = 0; i < nbStackElements; i++)
+    if (config.hasStack(i))
+      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i))));
+    else
+      context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
   return context;
 }

 int NeuralNetworkImpl::getContextSize() const
 {
-  return 1 + leftBorder + rightBorder;
+  return 1 + leftBorder + rightBorder + nbStackElements;
+}
+
+void NeuralNetworkImpl::setRightBorder(int rightBorder)
+{
+  this->rightBorder = rightBorder;
+}
+
+void NeuralNetworkImpl::setLeftBorder(int leftBorder)
+{
+  this->leftBorder = leftBorder;
+}
+
+void NeuralNetworkImpl::setNbStackElements(int nbStackElements)
+{
+  this->nbStackElements = nbStackElements;
 }
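
Concretely, extractContext now appends the FORM of the top nbStackElements stack items after the word window. A sketch of the returned layout with the default borders (an illustrative diagram, not code from the commit):

// Context vector for leftBorder=5, rightBorder=5, nbStackElements=2:
//
//   [w-5 ... w-1]   [w0]   [w+1 ... w+5]   [stack0] [stack1]
//    5 left words   focus  5 right words   top two stack FORMs
//
// Positions outside the sentence and empty stack slots are padded with
// dict.getIndexOrInsert(Dict::nullValueStr), so the vector size is always
// getContextSize() == 1 + leftBorder + rightBorder + nbStackElements.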
@@ -12,7 +12,18 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
   params = linear->parameters();
   _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());

-  this->focusedIndex = focusedIndex;
+  int leftBorder = 0;
+  int rightBorder = 0;
+  if (focusedIndex < 0)
+    leftBorder = -focusedIndex;
+  if (focusedIndex > 0)
+    rightBorder = focusedIndex;
+  this->focusedIndex = focusedIndex <= 0 ? 0 : focusedIndex;
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(0);
 }

 std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
......
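
The new constructor derives its borders from the sign of focusedIndex and, presumably, re-targets focusedIndex to the focused word's position inside the extracted context. A few example values, worked out from the code above (the table itself is illustrative):

// focusedIndex argument -> leftBorder, rightBorder, stored focusedIndex
//          -2           ->      2,          0,              0
//           0           ->      0,          0,              0
//          +3           ->      0,          3,              3
//
// A negative index only needs left context (the focused word then sits at
// position 0 of the context vector); a positive one only needs right
// context; nbStackElements is forced to 0 for this network.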
@@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
   dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

-  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
-  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
+  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)));
+  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5)));
 }

@@ -81,8 +81,11 @@ float Trainer::epoch(bool printAdvancement)
   auto prediction = machine.getClassifier()->getNN()(data);
   auto loss = torch::nll_loss(torch::log(prediction), labels);
-  totalLoss += loss.item<float>();
-  lossSoFar += loss.item<float>();
+  try
+  {
+    totalLoss += loss.item<float>();
+    lossSoFar += loss.item<float>();
+  } catch(std::exception & e) {util::myThrow(e.what());}

   loss.backward();

   denseOptimizer->step();
......