diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index deddf3fd6e38ac04941e2d985ca7fba821e8aa29..2996b796806af3a633a72bc182e2da673419578e 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -27,9 +27,12 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) int chosenTransition = -1; - for (unsigned int i = 0; i < prediction.size(0); i++) - if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config)) - chosenTransition = i; + try + { + for (unsigned int i = 0; i < prediction.size(0); i++) + if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config)) + chosenTransition = i; + } catch(std::exception & e) {util::myThrow(e.what());} if (chosenTransition == -1) util::myThrow("No transition appliable !"); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 6ddd264c6a3602d829a072313989ead738560403..242b3caced353589522b4107e41859c4bf928231 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -30,7 +30,7 @@ void Classifier::initNeuralNetwork(const std::string & topology) static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers { { - std::regex("OneWord\\((\\d+)\\)"), + std::regex("OneWord\\(([+\\-]?\\d+)\\)"), "OneWord(focusedIndex) : Only use the word embedding of the focused word.", [this,topology](auto sm) { @@ -38,11 +38,11 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("ConcatWords"), - "ConcatWords : Concatenate embeddings of words in context.", - [this,topology](auto) + std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), + "ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.", + [this,topology](auto sm) { - this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size())); + this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); } }, }; diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp index 4f67b3ac29ae8cb5253548d20e6eed4b2c30f5dd..4dd7aa3a4913fddd66f05936a3bf715abc2c1b00 100644 --- a/torch_modules/include/ConcatWordsNetwork.hpp +++ b/torch_modules/include/ConcatWordsNetwork.hpp @@ -8,14 +8,15 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl private : torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Linear linear{nullptr}; + torch::nn::Linear linear1{nullptr}; + torch::nn::Linear linear2{nullptr}; std::vector<torch::Tensor> _denseParameters; std::vector<torch::Tensor> _sparseParameters; public : - ConcatWordsNetworkImpl(int nbOutputs); + ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; std::vector<torch::Tensor> & denseParameters() override; std::vector<torch::Tensor> & sparseParameters() override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 18462190a7bf31ce4294dce41e510d8f0f695f65..268312211f419f8c066e264243e96ee3455ba6d9 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -11,6 +11,13 @@ class NeuralNetworkImpl : public torch::nn::Module int leftBorder{5}; int rightBorder{5}; + int nbStackElements{2}; + + protected : + + void setRightBorder(int rightBorder); + void setLeftBorder(int leftBorder); + void setNbStackElements(int nbStackElements); public : diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index 1343d53add153012b463dfb682c61d624a160e98..49619151bac810cdd90de910312caf6aede85166 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -1,15 +1,20 @@ #include "ConcatWordsNetwork.hpp" -ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs) +ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) { - constexpr int embeddingsSize = 30; + constexpr int embeddingsSize = 100; + setLeftBorder(leftBorder); + setRightBorder(rightBorder); + setNbStackElements(nbStackElements); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false))); auto params = wordEmbeddings->parameters(); - _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end()); - - linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs)); - params = linear->parameters(); + _denseParameters.insert(_denseParameters.end(), params.begin(), params.end()); + linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); + params = linear1->parameters(); + _denseParameters.insert(_denseParameters.end(), params.begin(), params.end()); + linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs)); + params = linear2->parameters(); _denseParameters.insert(_denseParameters.end(), params.begin(), params.end()); } @@ -30,7 +35,7 @@ torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input) // reshaped dim = {batch, sequence of embeddings} auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)}); - auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0); + auto res = torch::softmax(linear2(torch::relu(linear1(reshaped))), reshaped.dim() == 2 ? 1 : 0); return res; } diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index ab8921eb32a69a7852403016e20d043231c30e69..215fda5b5d032b6f95bf612ae088b83629f67d4f 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -24,11 +24,32 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict while ((int)context.size() < leftBorder+rightBorder+1) context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (int i = 0; i < nbStackElements; i++) + if (config.hasStack(i)) + context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", config.getStack(i)))); + else + context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + return context; } int NeuralNetworkImpl::getContextSize() const { - return 1 + leftBorder + rightBorder; + return 1 + leftBorder + rightBorder + nbStackElements; +} + +void NeuralNetworkImpl::setRightBorder(int rightBorder) +{ + this->rightBorder = rightBorder; +} + +void NeuralNetworkImpl::setLeftBorder(int leftBorder) +{ + this->leftBorder = leftBorder; +} + +void NeuralNetworkImpl::setNbStackElements(int nbStackElements) +{ + this->nbStackElements = nbStackElements; } diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp index 5cfa4f77e29b85f47bfd55ad65248da5c93ac12d..1d9b3869304f48196a610ac085af0c603d69c764 100644 --- a/torch_modules/src/OneWordNetwork.cpp +++ b/torch_modules/src/OneWordNetwork.cpp @@ -12,7 +12,18 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) params = linear->parameters(); _denseParameters.insert(_denseParameters.end(), params.begin(), params.end()); - this->focusedIndex = focusedIndex; + int leftBorder = 0; + int rightBorder = 0; + if (focusedIndex < 0) + leftBorder = -focusedIndex; + if (focusedIndex > 0) + rightBorder = focusedIndex; + + this->focusedIndex = focusedIndex <= 0 ? 0 : focusedIndex; + + setLeftBorder(leftBorder); + setRightBorder(rightBorder); + setNbStackElements(0); } std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters() diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 8e84642de20d5bead34bbf80e599992e50dfcb8a..94d23d2aa9f49f9eba40694e133bf41ba48bd185 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -58,8 +58,8 @@ void Trainer::createDataset(SubConfig & config, bool debug) dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5))); - sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); + denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5))); + sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5))); } float Trainer::epoch(bool printAdvancement) @@ -81,8 +81,11 @@ float Trainer::epoch(bool printAdvancement) auto prediction = machine.getClassifier()->getNN()(data); auto loss = torch::nll_loss(torch::log(prediction), labels); - totalLoss += loss.item<float>(); - lossSoFar += loss.item<float>(); + try + { + totalLoss += loss.item<float>(); + lossSoFar += loss.item<float>(); + } catch(std::exception & e) {util::myThrow(e.what());} loss.backward(); denseOptimizer->step();