diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 41f9f1556faf20e5204e6382020c7c514a669a1a..a58c3b19187778b9115fc7bd1b49edc8490e4056 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -1,6 +1,5 @@ #include "Classifier.hpp" #include "util.hpp" -#include "OneWordNetwork.hpp" #include "ConcatWordsNetwork.hpp" #include "RLTNetwork.hpp" #include "CNNNetwork.hpp" @@ -41,14 +40,6 @@ void Classifier::initNeuralNetwork(const std::string & topology) this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); } }, - { - std::regex("OneWord\\(([+\\-]?\\d+)\\)"), - "OneWord(focusedIndex) : Only use the word embedding of the focused word.", - [this,topology](auto sm) - { - this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)))); - } - }, { std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"), "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.", diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp deleted file mode 100644 index 9882b620187fcfaaebff33f1400b98ddc3446aae..0000000000000000000000000000000000000000 --- a/torch_modules/include/OneWordNetwork.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef ONEWORDNETWORK__H -#define ONEWORDNETWORK__H - -#include "NeuralNetwork.hpp" - -class OneWordNetworkImpl : public NeuralNetworkImpl -{ - private : - - torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Linear linear{nullptr}; - - public : - - OneWordNetworkImpl(int nbOutputs, int focusedIndex); - torch::Tensor forward(torch::Tensor input) override; -}; - -#endif diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp deleted file mode 100644 index e3ed3d57c31387bea2d62dc5e230db79f092ec6e..0000000000000000000000000000000000000000 --- a/torch_modules/src/OneWordNetwork.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "OneWordNetwork.hpp" - -OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) -{ - constexpr int embeddingsSize = 64; - - setBufferContext({focusedIndex}); - setStackContext({}); - setColumns({"FORM", "UPOS"}); - - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); - linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs)); -} - -torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input) -{ - if (input.dim() == 1) - input = input.unsqueeze(0); - auto wordAsEmb = wordEmbeddings(input).view({input.size(0),-1}); - auto res = linear(wordAsEmb); - - return res; -} -