diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 0c0f9bbe8c6564049f169e59e02c728ff13d11e0..deddf3fd6e38ac04941e2d985ca7fba821e8aa29 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -19,7 +19,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
       config.printForDebug(stderr);

     auto dictState = machine.getDict(config.getState()).getState();
-    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
     machine.getDict(config.getState()).setState(dictState);

     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 35f0611cba57c6fbebe1134af1d6fe5467db2350..1131db7f238d7cb460c40e8a9868c0b430a6366a 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -3,7 +3,7 @@

 #include <string>
 #include "TransitionSet.hpp"
-#include "TestNetwork.hpp"
+#include "NeuralNetwork.hpp"

 class Classifier
 {
@@ -11,13 +11,17 @@ class Classifier

   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
-  TestNetwork nn{nullptr};
+  std::shared_ptr<NeuralNetworkImpl> nn;
+
+  private :
+
+  void initNeuralNetwork(const std::string & topology);

   public :

   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
   TransitionSet & getTransitionSet();
-  TestNetwork & getNN();
+  NeuralNetwork & getNN();
   const std::string & getName() const;
 };

diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index b18bc88bbc5d03e1f99f0ef425c7e25886c41bb9..310e2f442114567b08f470359dc0629839c970d1 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -107,7 +107,6 @@ class Config
   String getState() const;
   void setState(const std::string state);
   bool stateIsDone() const;
-  std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
   void addPredicted(const std::set<std::string> & predicted);
   bool isPredicted(const std::string & colName) const;
   int getLastPoppedStack() const;
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 13e25b45dcac7da7cefa8066a3019ac84abf3b0d..6ddd264c6a3602d829a072313989ead738560403 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,10 +1,13 @@
 #include "Classifier.hpp"
+#include "util.hpp"
+#include "OneWordNetwork.hpp"
+#include "ConcatWordsNetwork.hpp"

 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
   this->name = name;
   this->transitionSet.reset(new TransitionSet(tsFile));
-  this->nn = TestNetwork(transitionSet->size(), 5);
+  initNeuralNetwork(topology);
 }

 TransitionSet & Classifier::getTransitionSet()
@@ -12,9 +15,9 @@ TransitionSet & Classifier::getTransitionSet()
   return *transitionSet;
 }

-TestNetwork & Classifier::getNN()
+NeuralNetwork & Classifier::getNN()
 {
-  return nn;
+  return reinterpret_cast<NeuralNetwork&>(nn);
 }

 const std::string & Classifier::getName() const
@@ -22,3 +25,36 @@ const std::string & Classifier::getName() const
   return name;
 }

+void Classifier::initNeuralNetwork(const std::string & topology)
+{
+  static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
+  {
+    {
+      std::regex("OneWord\\((\\d+)\\)"),
+      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
+      [this,topology](auto sm)
+      {
+        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
+      }
+    },
+    {
+      std::regex("ConcatWords"),
+      "ConcatWords : Concatenate embeddings of words in context.",
+      [this,topology](auto)
+      {
+        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size()));
+      }
+    },
+  };
+
+  for (auto & initializer : initializers)
+    if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
+      return;
+
+  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
+  for (auto & initializer : initializers)
+    errorMessage += std::get<1>(initializer) + "\n";
+
+  util::myThrow(errorMessage);
+}
+
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 286386a0e1496c329ec1b9697124d9ef097a13b0..dd8ca3ec38c7ce3af5ad0096f6fd97e81ece7403 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -455,33 +455,6 @@ bool Config::stateIsDone() const
   return !has(0, wordIndex+1, 0) and !hasStack(0);
 }

-std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
-{
-  std::stack<int> leftContext;
-  for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
-    if (isToken(index))
-      leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
-
-  std::vector<long> context;
-
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while (!leftContext.empty())
-  {
-    context.emplace_back(leftContext.top());
-    leftContext.pop();
-  }
-
-  for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
-    if (isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
-
-  while ((int)context.size() < leftBorder+rightBorder+1)
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-
-  return context;
-}
-
 void Config::addPredicted(const std::set<std::string> & predicted)
 {
   this->predicted.insert(predicted.begin(), predicted.end());
diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f67b3ac29ae8cb5253548d20e6eed4b2c30f5dd
--- /dev/null
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -0,0 +1,24 @@
+#ifndef CONCATWORDSNETWORK__H
+#define CONCATWORDSNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class ConcatWordsNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear{nullptr};
+
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;
+
+  public :
+
+  ConcatWordsNetworkImpl(int nbOutputs);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<torch::Tensor> & denseParameters() override;
+  std::vector<torch::Tensor> & sparseParameters() override;
+};
+
+#endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..18462190a7bf31ce4294dce41e510d8f0f695f65
--- /dev/null
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -0,0 +1,25 @@
+#ifndef NEURALNETWORK__H
+#define NEURALNETWORK__H
+
+#include <torch/torch.h>
+#include "Config.hpp"
+#include "Dict.hpp"
+
+class NeuralNetworkImpl : public torch::nn::Module
+{
+  private :
+
+  int leftBorder{5};
+  int rightBorder{5};
+
+  public :
+
+  virtual std::vector<torch::Tensor> & denseParameters() = 0;
+  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
+  virtual torch::Tensor forward(torch::Tensor input) = 0;
+  std::vector<long> extractContext(Config & config, Dict & dict) const;
+  int getContextSize() const;
+};
+TORCH_MODULE(NeuralNetwork);
+
+#endif
diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..29edb7d58931627b72b4bc157acbeaeb2ff82ee0
--- /dev/null
+++ b/torch_modules/include/OneWordNetwork.hpp
@@ -0,0 +1,25 @@
+#ifndef ONEWORDNETWORK__H
+#define ONEWORDNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class OneWordNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear{nullptr};
+  int focusedIndex;
+
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;
+
+  public :
+
+  OneWordNetworkImpl(int nbOutputs, int focusedIndex);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<torch::Tensor> & denseParameters() override;
+  std::vector<torch::Tensor> & sparseParameters() override;
+};
+
+#endif
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
deleted file mode 100644
index 27b92e8a00ac3430567ba35c7846ada1aa076d4a..0000000000000000000000000000000000000000
--- a/torch_modules/include/TestNetwork.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#ifndef TESTNETWORK__H
-#define TESTNETWORK__H
-
-#include <torch/torch.h>
-#include "Config.hpp"
-
-class TestNetworkImpl : public torch::nn::Module
-{
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear{nullptr};
-  int focusedIndex;
-
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
-  public :
-
-  TestNetworkImpl(int nbOutputs, int focusedIndex);
-  torch::Tensor forward(torch::Tensor input);
-  std::vector<torch::Tensor> & denseParameters();
-  std::vector<torch::Tensor> & sparseParameters();
-};
-TORCH_MODULE(TestNetwork);
-
-#endif
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1343d53add153012b463dfb682c61d624a160e98
--- /dev/null
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -0,0 +1,37 @@
+#include "ConcatWordsNetwork.hpp"
+
+ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
+{
+  constexpr int embeddingsSize = 30;
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+  auto params = wordEmbeddings->parameters();
+  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+
+  linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
+  params = linear->parameters();
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+}
+
+std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
+{
+  return _denseParameters;
+}
+
+std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
+{
+  return _sparseParameters;
+}
+
+torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
+{
+  // input dim = {batch, sequence, embeddings}
+  auto wordsAsEmb = wordEmbeddings(input);
+  // reshaped dim = {batch, sequence of embeddings}
+  auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
+
+  auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0);
+
+  return res;
+}
+
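[Note, not part of the diff] In ConcatWordsNetworkImpl::forward above, the embedding lookup followed by the reshape turns each context of getContextSize() = 11 dictionary indexes into one flat row of 11*30 = 330 values, which is exactly the input size given to the linear layer in the constructor. Below is a standalone libtorch sketch of that shape arithmetic, with the sizes hard-coded instead of taken from the class; the batch size of 2 and the output count of 4 are arbitrary values chosen for the example.

#include <torch/torch.h>
#include <iostream>

int main()
{
  constexpr int contextSize = 11;    // 1 + leftBorder + rightBorder with the default 5/5 borders
  constexpr int embeddingsSize = 30;
  constexpr int nbOutputs = 4;       // arbitrary transition count, just for the example

  torch::nn::Embedding wordEmbeddings(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true));
  torch::nn::Linear linear(contextSize*embeddingsSize, nbOutputs);

  // A batch of 2 contexts, each holding 11 dictionary indexes.
  auto input = torch::randint(0, 200000, {2, contextSize}, torch::kLong);

  auto wordsAsEmb = wordEmbeddings(input);   // shape {2, 11, 30}
  auto reshaped = torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)});   // shape {2, 330}
  auto res = torch::softmax(linear(reshaped), 1);   // shape {2, 4}, one distribution over outputs per batch element

  std::cout << res.sizes() << std::endl;   // prints [2, 4]
  return 0;
}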
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ab8921eb32a69a7852403016e20d043231c30e69
--- /dev/null
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -0,0 +1,34 @@
+#include "NeuralNetwork.hpp"
+
+std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::stack<int> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (config.isToken(index))
+      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+
+  std::vector<long> context;
+
+  while ((int)context.size() < leftBorder-(int)leftContext.size())
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+    if (config.isToken(index))
+      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+
+  while ((int)context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  return context;
+}
+
+int NeuralNetworkImpl::getContextSize() const
+{
+  return 1 + leftBorder + rightBorder;
+}
+
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
similarity index 76%
rename from torch_modules/src/TestNetwork.cpp
rename to torch_modules/src/OneWordNetwork.cpp
index 19debf8225cdb85af503c07714a839776d0ad480..5cfa4f77e29b85f47bfd55ad65248da5c93ac12d 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -1,6 +1,6 @@
-#include "TestNetwork.hpp"
+#include "OneWordNetwork.hpp"

-TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
+OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
 {
   constexpr int embeddingsSize = 30;

@@ -15,17 +15,17 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
   this->focusedIndex = focusedIndex;
 }

-std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
+std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
 {
   return _denseParameters;
 }

-std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
+std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
 {
   return _sparseParameters;
 }

-torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
+torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
   auto wordsAsEmb = wordEmbeddings(input);
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 69dde8d6819d4dff26eb31a84f2515e21b4d085b..0f9c3ec0425946b15243f318155ba980850923f8 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -4,7 +4,6 @@
 #include "ReadingMachine.hpp"
 #include "ConfigDataset.hpp"
 #include "SubConfig.hpp"
-#include "TestNetwork.hpp"

 class Trainer
 {
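[Note, not part of the diff] Both call sites, Decoder::decode near the top of this diff and Trainer::createDataset in the next file, now ask the classifier's network for its own context, wrap the returned std::vector<long> with torch::from_blob and run the forward pass. The sketch below mirrors that consumption path with a toy module; ToyNetworkImpl, the hard-coded context values and the transition count of 7 are made up for illustration, and only the from_blob wrapping is taken verbatim from the diff.

#include <torch/torch.h>
#include <iostream>
#include <vector>

// Toy stand-in for the networks above; only the pieces needed to show how a
// caller consumes the extracted context (single configuration, no batch).
struct ToyNetworkImpl : torch::nn::Module
{
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear{nullptr};

  ToyNetworkImpl(int nbOutputs)
  {
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, 30).sparse(true)));
    linear = register_module("linear", torch::nn::Linear(11*30, nbOutputs));
  }

  torch::Tensor forward(torch::Tensor input)
  {
    auto emb = wordEmbeddings(input);                            // {11, 30}
    return torch::softmax(linear(torch::reshape(emb, {emb.size(0)*emb.size(1)})), 0);   // {nbOutputs}
  }
};
TORCH_MODULE(ToyNetwork);

int main()
{
  ToyNetwork nn(7);   // pretend the transition set holds 7 transitions

  // Pretend this came from extractContext : 11 dictionary indexes (left border,
  // focused word, right border), possibly padded with the null value index.
  std::vector<long> context{0, 0, 0, 12, 35, 7, 101, 9, 0, 0, 0};

  // Same wrapping as in Decoder::decode and Trainer::createDataset.
  auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
  auto prediction = nn->forward(neuralInput);

  auto chosenTransition = prediction.argmax().item<int64_t>();
  std::cout << "chosen transition index : " << chosenTransition << std::endl;

  return 0;
}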
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 195cc5cd9a4bbbc8e42480f9ce70922124d40c8b..8e84642de20d5bead34bbf80e599992e50dfcb8a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -25,7 +25,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
       util::myThrow("No transition appliable !");
     }

-    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
     contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
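[Note, not part of the diff] The topology-string dispatch added in Classifier::initNeuralNetwork earlier in this diff relies on util::doIfNameMatch, whose implementation is not part of this change set. The standalone sketch below reproduces the assumed pattern with a hypothetical helper matchAndRun (run the callback on the match results when the whole string matches the regex, and report whether it matched), so the selection logic for "OneWord(focusedIndex)" and "ConcatWords" can be read in isolation.

#include <functional>
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Hypothetical stand-in for util::doIfNameMatch (assumed semantics) : run
// 'action' on the match results when 'name' fully matches 'reg', and report
// whether it matched.
bool matchAndRun(const std::regex & reg, const std::string & name,
                 const std::function<void(const std::smatch &)> & action)
{
  std::smatch sm;
  if (!std::regex_match(name, sm, reg))
    return false;
  action(sm);
  return true;
}

int main()
{
  // Same shape as the initializers table in Classifier::initNeuralNetwork :
  // a regex describing the topology string plus a callback building the network.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> initializers
  {
    {std::regex("OneWord\\((\\d+)\\)"), [](const std::smatch & sm)
      { std::cout << "would build OneWordNetworkImpl, focused index = " << sm[1] << std::endl; }},
    {std::regex("ConcatWords"), [](const std::smatch &)
      { std::cout << "would build ConcatWordsNetworkImpl" << std::endl; }},
  };

  for (std::string topology : {"OneWord(2)", "ConcatWords", "Unknown"})
  {
    bool matched = false;
    for (auto & initializer : initializers)
      if ((matched = matchAndRun(initializer.first, topology, initializer.second)))
        break;
    if (!matched)
      std::cout << "no initializer matches '" << topology << "'" << std::endl;
  }

  return 0;
}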