diff --git a/common/include/util.hpp b/common/include/util.hpp index efe509a0a176af75fdc0e0947c4ffad65074d774..1ab894349364cbe1dc29bcff089937c2d0ee885f 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -38,7 +38,7 @@ std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path di std::string_view getFilenameFromPath(std::string_view s); -std::vector<std::string_view> split(std::string_view s, char delimiter); +std::vector<std::string> split(std::string_view s, char delimiter); utf8string splitAsUtf8(std::string_view s); diff --git a/common/src/util.cpp b/common/src/util.cpp index 13d412282aab1d3e14a232dd0b4e7f8ec4c7a18c..23f5a6e0af99c9066d90518a08bb5773e225a82f 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -32,9 +32,9 @@ bool util::isIllegal(utf8char c) return c == '\n' || c == '\t'; } -std::vector<std::string_view> util::split(std::string_view remaining, char delimiter) +std::vector<std::string> util::split(std::string_view remaining, char delimiter) { - std::vector<std::string_view> result; + std::vector<std::string> result; for (auto firstDelimiterIndex = remaining.find_first_of(delimiter); firstDelimiterIndex != std::string_view::npos; firstDelimiterIndex = remaining.find_first_of(delimiter)) { diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 13e1806f5afce69d2c3f475c79d2f3f0cced6242..1e6833e016f5df74968b62b3a5354f9709ca8ed5 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -109,6 +109,7 @@ class Config bool rawInputOnlySeparatorsLeft() const; std::size_t getWordIndex() const; std::size_t getCharacterIndex() const; + long getRelativeWordIndex(int relativeIndex) const; const String & getHistory(int relativeIndex) const; std::size_t getStack(int relativeIndex) const; bool hasHistory(int relativeIndex) const; diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 
7c96c7fc09c8d3ce96a3f36df1ec33a951e894b6..5cba1706646caf2ec4142b70b893abc313ad5dde 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -37,59 +37,72 @@ void Classifier::initNeuralNetwork(const std::string & topology) "OneWord(focusedIndex) : Only use the word embedding of the focused word.", [this,topology](auto sm) { - this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]))); + this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)))); } }, { - std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), - "ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.", + std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"), + "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.", [this,topology](auto sm) { - this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); + std::vector<int> bufferContext, stackContext; + for (auto s : util::split(sm.str(1), ',')) + bufferContext.emplace_back(std::stoi(s)); + for (auto s : util::split(sm.str(2), ',')) + stackContext.emplace_back(std::stoi(s)); + this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext)); } }, { - std::regex("CNN\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), - "CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", + std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + 
"CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", [this,topology](auto sm) { - std::vector<int> focusedBuffer, focusedStack, maxNbElements; + std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext; std::vector<std::string> focusedColumns, columns; - for (auto s : util::split(std::string(sm[5]), ',')) + for (auto s : util::split(sm.str(2), ',')) + bufferContext.emplace_back(std::stoi(s)); + for (auto s : util::split(sm.str(3), ',')) + stackContext.emplace_back(std::stoi(s)); + for (auto s : util::split(sm.str(4), ',')) columns.emplace_back(s); - for (auto s : util::split(std::string(sm[6]), ',')) - focusedBuffer.push_back(std::stoi(std::string(s))); - for (auto s : util::split(std::string(sm[7]), ',')) - focusedStack.push_back(std::stoi(std::string(s))); - for (auto s : util::split(std::string(sm[8]), ',')) + for (auto s : util::split(sm.str(5), ',')) + focusedBuffer.push_back(std::stoi(s)); + for (auto s : util::split(sm.str(6), ',')) + focusedStack.push_back(std::stoi(s)); + for (auto s : util::split(sm.str(7), ',')) focusedColumns.emplace_back(s); - for (auto s : util::split(std::string(sm[9]), ',')) - maxNbElements.push_back(std::stoi(std::string(s))); + for (auto s : util::split(sm.str(8), ',')) + maxNbElements.push_back(std::stoi(s)); if (focusedColumns.size() != maxNbElements.size()) util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11]))); + this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, 
std::stoi(sm.str(9)), std::stoi(sm.str(10)))); } }, { - std::regex("LSTM\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), - "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", + std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : LSTM to capture context.", [this,topology](auto sm) { - std::vector<int> focusedBuffer, focusedStack, maxNbElements; + std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext; std::vector<std::string> focusedColumns, columns; - for (auto s : util::split(std::string(sm[5]), ',')) + for (auto s : util::split(sm.str(2), ',')) + bufferContext.emplace_back(std::stoi(s)); + for (auto s : util::split(sm.str(3), ',')) + stackContext.emplace_back(std::stoi(s)); + for (auto s : util::split(sm.str(4), ',')) columns.emplace_back(s); - for (auto s : util::split(std::string(sm[6]), ',')) - focusedBuffer.push_back(std::stoi(std::string(s))); - for (auto s : util::split(std::string(sm[7]), ',')) - focusedStack.push_back(std::stoi(std::string(s))); - for (auto s : util::split(std::string(sm[8]), ',')) + for (auto s : util::split(sm.str(5), ',')) + focusedBuffer.push_back(std::stoi(s)); + for (auto s : util::split(sm.str(6), ',')) + focusedStack.push_back(std::stoi(s)); + for (auto s : util::split(sm.str(7), ',')) focusedColumns.emplace_back(s); - for (auto s : util::split(std::string(sm[9]), ',')) - maxNbElements.push_back(std::stoi(std::string(s))); + for (auto s : util::split(sm.str(8), ',')) + 
maxNbElements.push_back(std::stoi(s)); if (focusedColumns.size() != maxNbElements.size()) util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11]))); + this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10)))); } }, { @@ -97,19 +110,28 @@ void Classifier::initNeuralNetwork(const std::string & topology) "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", [this,topology](auto sm) { - this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); + this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3)))); } }, }; + std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); + for (auto & initializer : initializers) - if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) + try { - this->nn->to(NeuralNetworkImpl::device); - return; + if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) + { + this->nn->to(NeuralNetworkImpl::device); + return; + } + } + catch (std::exception & e) + { + errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage); + break; } - std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); for (auto & initializer : initializers) errorMessage += std::get<1>(initializer) + "\n"; diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 
ab35b3a58218bacbcc51e40fd063d50317a040a3..03179a020661091f63d51606249c106e4fd55074 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -573,3 +573,29 @@ void Config::addMissingColumns() } } +long Config::getRelativeWordIndex(int relativeIndex) const +{ + if (relativeIndex < 0) + { + for (int index = getWordIndex()-1, counter = 0; has(0,index,0); --index) + if (!isCommentPredicted(index)) + { + --counter; + if (counter == relativeIndex) + return index; + } + } + else + { + for (int index = getWordIndex(), counter = 0; has(0,index,0); ++index) + if (!isCommentPredicted(index)) + { + if (counter == relativeIndex) + return index; + ++counter; + } + } + + return -1; +} + diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 0cd54b8f087034863e1fd8dbb2e07089be221a56..a036e7622940d5eb22f2c8f820df0b0063f19d0c 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -11,8 +11,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl static constexpr int maxNbEmbeddings = 50000; int unknownValueThreshold; - std::vector<int> focusedBufferIndexes; - std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; int leftWindowRawInput; @@ -31,7 +29,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl public : - CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); + CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, 
std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp index 7152eba2e06aa74b909b2136514d64120b6c0b8d..b2c9691dd3a062cf600e99ce1f7ec9ca4d478643 100644 --- a/torch_modules/include/ConcatWordsNetwork.hpp +++ b/torch_modules/include/ConcatWordsNetwork.hpp @@ -14,7 +14,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl public : - ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); + ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext); torch::Tensor forward(torch::Tensor input) override; }; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 5fe8de8461320a83b59537a6fae0660f16a3c334..e276eb23f786bc3e80b7004cb67f893b347fe5cf 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -10,8 +10,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl static constexpr int maxNbEmbeddings = 50000; int unknownValueThreshold; - std::vector<int> focusedBufferIndexes; - std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; int leftWindowRawInput; @@ -30,7 +28,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); + LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, 
std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 34bf14b632cd912e1d0743fc667a14ed49e667c2..8ffa7349d35b2b50816f5574a13a0e6c096c2e9c 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -13,22 +13,25 @@ class NeuralNetworkImpl : public torch::nn::Module protected : - unsigned leftBorder{5}; - unsigned rightBorder{5}; - unsigned nbStackElements{2}; std::vector<std::string> columns{"FORM"}; + std::vector<int> bufferContext{-3,-2,-1,0,1}; + std::vector<int> stackContext{}; + std::vector<int> bufferFocused{}; + std::vector<int> stackFocused{}; protected : - void setRightBorder(int rightBorder); - void setLeftBorder(int leftBorder); - void setNbStackElements(int nbStackElements); + void setBufferContext(const std::vector<int> & bufferContext); + void setStackContext(const std::vector<int> & stackContext); + void setBufferFocused(const std::vector<int> & bufferFocused); + void setStackFocused(const std::vector<int> & stackFocused); public : virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const; std::vector<long> extractContextIndexes(const Config & config) const; + std::vector<long> extractFocusedIndexes(const Config & config) const; int getContextSize() const; void setColumns(const std::vector<std::string> & columns); }; diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp index 
b4ad4753b8bb57a35afa15851742551f7227e0c3..9882b620187fcfaaebff33f1400b98ddc3446aae 100644 --- a/torch_modules/include/OneWordNetwork.hpp +++ b/torch_modules/include/OneWordNetwork.hpp @@ -9,7 +9,6 @@ class OneWordNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear{nullptr}; - int focusedIndex; public : diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp index b996def57a5e540d738bd9db0874a5d8511d2983..71b8aa55dce4740c1f08fd8ecb343826dedd8989 100644 --- a/torch_modules/include/RLTNetwork.hpp +++ b/torch_modules/include/RLTNetwork.hpp @@ -11,6 +11,8 @@ class RLTNetworkImpl : public NeuralNetworkImpl static inline std::vector<long> focusedBufferIndexes{0,1,2}; static inline std::vector<long> focusedStackIndexes{0,1}; + int leftBorder, rightBorder; + torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 68bd9749cd78ed97687a80543e8109a583b2d3da..c03b347ee55f71d3e130be22f4770d3000904847 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -1,16 +1,17 @@ #include "CNNNetwork.hpp" -CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) +CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> 
bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { constexpr int embeddingsSize = 64; constexpr int hiddenSize = 1024; constexpr int nbFiltersContext = 512; constexpr int nbFiltersFocused = 64; - setLeftBorder(leftBorder); - setRightBorder(rightBorder); - setNbStackElements(nbStackElements); + setBufferContext(bufferContext); + setStackContext(stackContext); setColumns(columns); + setBufferFocused(focusedBufferIndexes); + setStackFocused(focusedStackIndexes); rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; if (leftWindowRawInput < 0 or rightWindowRawInput < 0) @@ -42,7 +43,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); - auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); + auto context = embeddings.narrow(1, rawInputSize, getContextSize()); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1))); @@ -59,7 +60,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) for (unsigned int i = 0; i < focusedColumns.size(); i++) { long nbElements = maxNbElements[i]; - for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++) + for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++) { auto cnnInput = elementsEmbeddings.narrow(1, 
curIndex, nbElements).unsqueeze(1); curIndex += nbElements; @@ -119,30 +120,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D } } + std::vector<long> focusedIndexes = extractFocusedIndexes(config); + for (auto & contextElement : context) for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) { auto & col = focusedColumns[colIndex]; - std::vector<int> focusedIndexes; - for (auto relIndex : focusedBufferIndexes) - { - int index = relIndex + leftBorder; - if (index < 0 || index >= (int)contextIndexes.size()) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(contextIndexes[index]); - } - for (auto index : focusedStackIndexes) - { - if (!config.hasStack(index)) - focusedIndexes.push_back(-1); - else if (!config.has(col, config.getStack(index), 0)) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(config.getStack(index)); - } - for (auto index : focusedIndexes) { if (index == -1) diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index 81ee6252f29ac1b1f676f25557cf64da3a413331..b03b8493b0af20679e70cea6cd73af59e4658588 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -1,26 +1,25 @@ #include "ConcatWordsNetwork.hpp" -ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) +ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext) { - constexpr int embeddingsSize = 100; - setLeftBorder(leftBorder); - setRightBorder(rightBorder); - setNbStackElements(nbStackElements); + constexpr int embeddingsSize = 64; + constexpr int hiddenSize = 500; + + setBufferContext(bufferContext); + setStackContext(stackContext); setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, 
embeddingsSize))); - linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); - linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs)); + linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize)); + linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); dropout = register_module("dropout", torch::nn::Dropout(0.3)); } torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input) { - // input dim = {batch, sequence, embeddings} - auto wordsAsEmb = dropout(wordEmbeddings(input)); - // reshaped dim = {batch, sequence of embeddings} - auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)}); - - return linear2(torch::relu(linear1(reshaped))); + if (input.dim() == 1) + input = input.unsqueeze(0); + auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1})); + return linear2(torch::relu(linear1(wordsAsEmb))); } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 84a36f20c292a95500262607c8767fa7494187d7..0ae4cd10cd7866b8d182bfb4ea60da3f47178efc 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), 
leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) +LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { constexpr int embeddingsSize = 64; constexpr int hiddenSize = 1024; @@ -8,10 +8,11 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l constexpr int focusedLSTMSize = 64; constexpr int rawInputLSTMSize = 16; - setLeftBorder(leftBorder); - setRightBorder(rightBorder); - setNbStackElements(nbStackElements); + setBufferContext(bufferContext); + setStackContext(stackContext); setColumns(columns); + setBufferFocused(focusedBufferIndexes); + setStackFocused(focusedStackIndexes); rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; if (leftWindowRawInput < 0 or rightWindowRawInput < 0) @@ -34,7 +35,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l for (auto & col : focusedColumns) { lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true)))); - totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size()); + totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 
4 : 1) * (bufferFocused.size()+stackFocused.size()); } linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize)); @@ -48,7 +49,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); - auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); + auto context = embeddings.narrow(1, rawInputSize, getContextSize()); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); @@ -67,7 +68,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) for (unsigned int i = 0; i < focusedColumns.size(); i++) { long nbElements = maxNbElements[i]; - for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++) + for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++) { auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements); curIndex += nbElements; @@ -136,30 +137,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, } } + std::vector<long> focusedIndexes = extractFocusedIndexes(config); + for (auto & contextElement : context) for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) { auto & col = focusedColumns[colIndex]; - std::vector<int> focusedIndexes; - for (auto relIndex : focusedBufferIndexes) - { - int index = relIndex + leftBorder; - if (index < 0 || index >= (int)contextIndexes.size()) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(contextIndexes[index]); - } - for (auto index : focusedStackIndexes) - { - if (!config.hasStack(index)) - focusedIndexes.push_back(-1); - else if (!config.has(col, config.getStack(index), 0)) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(config.getStack(index)); - } - for (auto index : focusedIndexes) { if (index == -1) diff --git 
a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index fef5519623d91e48bbac5dead3dcd575216c4121..088e5c0260f38fd5bc05df067b344b3e59e349e2 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -4,31 +4,30 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCU std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const { - std::stack<long> leftContext; - for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) - if (!config.isCommentPredicted(index)) - leftContext.push(index); - std::vector<long> context; - while (context.size() < leftBorder-leftContext.size()) - context.emplace_back(-1); - while (!leftContext.empty()) - { - context.emplace_back(leftContext.top()); - leftContext.pop(); - } + for (int index : bufferContext) + context.emplace_back(config.getRelativeWordIndex(index)); + + for (int index : stackContext) + if (config.hasStack(index)) + context.emplace_back(config.getStack(index)); + else + context.emplace_back(-1); + + return context; +} - for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) - if (!config.isCommentPredicted(index)) - context.emplace_back(index); +std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const +{ + std::vector<long> context; - while (context.size() < leftBorder+rightBorder+1) - context.emplace_back(-1); + for (int index : bufferFocused) + context.emplace_back(config.getRelativeWordIndex(index)); - for (unsigned int i = 0; i < nbStackElements; i++) - if (config.hasStack(i)) - context.emplace_back(config.getStack(i)); + for (int index : stackFocused) + if (config.hasStack(index)) + context.emplace_back(config.getStack(index)); else context.emplace_back(-1); @@ -52,22 +51,27 @@ std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config int 
NeuralNetworkImpl::getContextSize() const { - return columns.size()*(1 + leftBorder + rightBorder + nbStackElements); + return columns.size()*(bufferContext.size()+stackContext.size()); +} + +void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext) +{ + this->bufferContext = bufferContext; } -void NeuralNetworkImpl::setRightBorder(int rightBorder) +void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext) { - this->rightBorder = rightBorder; + this->stackContext = stackContext; } -void NeuralNetworkImpl::setLeftBorder(int leftBorder) +void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused) { - this->leftBorder = leftBorder; + this->bufferFocused = bufferFocused; } -void NeuralNetworkImpl::setNbStackElements(int nbStackElements) +void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused) { - this->nbStackElements = nbStackElements; + this->stackFocused = stackFocused; } void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns) diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp index d2d796693c88e2c019c0067b0801d3a69a0d0d2c..e3ed3d57c31387bea2d62dc5e230db79f092ec6e 100644 --- a/torch_modules/src/OneWordNetwork.cpp +++ b/torch_modules/src/OneWordNetwork.cpp @@ -2,36 +2,22 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) { - constexpr int embeddingsSize = 30; + constexpr int embeddingsSize = 64; - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize))); - linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs)); - - int leftBorder = 0; - int rightBorder = 0; - if (focusedIndex < 0) - leftBorder = -focusedIndex; - if (focusedIndex > 0) - rightBorder = focusedIndex; - - this->focusedIndex = focusedIndex <= 0 ? 
0 : focusedIndex; - - setLeftBorder(leftBorder); - setRightBorder(rightBorder); - setNbStackElements(0); + setBufferContext({focusedIndex}); + setStackContext({}); setColumns({"FORM", "UPOS"}); + + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs)); } torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input) { - // input dim = {batch, sequence, embeddings} - auto wordsAsEmb = wordEmbeddings(input); - auto reshaped = wordsAsEmb; - // reshaped dim = {sequence, batch, embeddings} - if (reshaped.dim() == 3) - reshaped = wordsAsEmb.permute({1,0,2}); - - auto res = linear(reshaped[focusedIndex]); + if (input.dim() == 1) + input = input.unsqueeze(0); + auto wordAsEmb = wordEmbeddings(input).view({input.size(0),-1}); + auto res = linear(wordAsEmb); return res; } diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp index 38fe64203cbdb0f9546e96cd1b6ac758265af364..a9a346f635a2f35f10aea8ff6743a320bbc537f6 100644 --- a/torch_modules/src/RLTNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -7,9 +7,11 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i constexpr int treeEmbeddingsSize = 256; constexpr int hiddenSize = 500; - setLeftBorder(leftBorder); - setRightBorder(rightBorder); - setNbStackElements(nbStackElements); + //TODO gerer ces context + this->leftBorder = leftBorder; + this->rightBorder = rightBorder; + setBufferContext({}); + setStackContext({}); setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); @@ -27,10 +29,10 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input) input = input.unsqueeze(0); auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size()); - auto computeOrder = 
input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1); - auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1)); + auto computeOrder = input.narrow(1, focusedIndexes.size(1), getContextSize()/columns.size()); + auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(getContextSize()/columns.size())); auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds}); - auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1)); + auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), getContextSize()); auto baseEmbeddings = wordEmbeddings(wordIndexes); auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()}); auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;