From 4bfe61473b0cba6e89743f284f003d13da8a569f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 10 Apr 2020 12:26:07 +0200
Subject: [PATCH] Updated LSTMNetwork and attach action to work with tree
 embedding networks

---
Review notes (this area is not part of the commit message):

 * Action::attach: the undo lambda used to call
   addToHypothesisRelative(...).apply(), which performs another addition
   on the governor's child list instead of reverting the one made by
   apply. It now recomputes the dependent's line index exactly as apply
   does and calls .undo() with the same arguments, before the head is
   reverted (reverse order of apply).
 * LSTMNetworkImpl::extractContext: removed the two leftover
   fmt::print(stderr, "before/after ...") debug statements.
 * Hunk sizes and the diffstat below were updated for these two changes.
   The "index" blob ids and the commit hash still describe the original
   commit.
 * Line breaks and indentation were rebuilt from a whitespace-collapsed
   copy of the patch; check context lines against the tree before applying.
 * Not changed, worth a follow-up: the literals 1, 3 and 128 passed to
   DepthLayerTreeEmbedding are hard-coded rather than read from the
   classifier definition.

 reading_machine/src/Action.cpp        | 15 +++++++++++++++
 reading_machine/src/Classifier.cpp    | 12 +++++++++---
 torch_modules/include/LSTMNetwork.hpp |  5 ++++-
 torch_modules/src/LSTMNetwork.cpp     | 16 +++++++++++++++-
 4 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index b2e7adb..4e070db 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -570,12 +570,27 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
       lineIndex = config.getWordIndex() + governorIndex;
     else
       lineIndex = config.getStack(governorIndex);
+    int depIndex = 0;
+    if (dependentObject == Object::Buffer)
+      depIndex = config.getWordIndex() + dependentIndex;
+    else
+      depIndex = config.getStack(dependentIndex);
+
     addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a);
+    addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a);
   };
 
   auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
   {
+    int depIndex = 0;
+    if (dependentObject == Object::Buffer)
+      depIndex = config.getWordIndex() + dependentIndex;
+    else
+      depIndex = config.getStack(dependentIndex);
+
+    // Revert in the reverse order of apply: first the child entry, then the head.
+    addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).undo(config, a);
     addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a);
   };
 
   auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action)
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 27a4ef1..aeb6ac8 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -71,9 +71,8 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
 {
   int unknownValueThreshold;
   std::vector<int> bufferContext, stackContext;
-  std::vector<std::string> columns;
+  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
   std::vector<int> focusedBuffer, focusedStack;
-  std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
   std::vector<std::pair<int, float>> mlp;
   int rawInputLeftWindow, rawInputRightWindow;
@@ -223,6 +222,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
+        {
+          treeEmbeddingColumns = util::split(sm.str(1), ' ');
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
+
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns));
 }
 
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 07d689e..9cb051d 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -7,6 +7,7 @@
 #include "SplitTransLSTM.hpp"
 #include "FocusedColumnLSTM.hpp"
 #include "MLP.hpp"
+#include "DepthLayerTreeEmbedding.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
@@ -20,13 +21,15 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   ContextLSTM contextLSTM{nullptr};
   RawInputLSTM rawInputLSTM{nullptr};
   SplitTransLSTM splitTransLSTM{nullptr};
+  DepthLayerTreeEmbedding treeEmbedding{nullptr};
   std::vector<FocusedColumnLSTM> focusedLstms;
 
   bool hasRawInputLSTM{false};
+  bool hasTreeEmbedding{false};
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index ab06806..2a9b1b4 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -23,6 +23,15 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
     currentInputSize += rawInputLSTM->getInputSize();
   }
 
+  if (!treeEmbeddingColumns.empty())
+  {
+    hasTreeEmbedding = true;
+    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll));
+    treeEmbedding->setFirstInputIndex(currentInputSize);
+    currentOutputSize += treeEmbedding->getOutputSize();
+    currentInputSize += treeEmbedding->getInputSize();
+  }
+
   splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
   splitTransLSTM->setFirstInputIndex(currentInputSize);
   currentOutputSize += splitTransLSTM->getOutputSize();
@@ -56,6 +65,9 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
   if (hasRawInputLSTM)
     outputs.emplace_back(rawInputLSTM(embeddings));
 
+  if (hasTreeEmbedding)
+    outputs.emplace_back(treeEmbedding(embeddings));
+
   outputs.emplace_back(splitTransLSTM(embeddings));
 
   for (auto & lstm : focusedLstms)
@@ -79,6 +91,8 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   contextLSTM->addToContext(context, dict, config);
   if (hasRawInputLSTM)
     rawInputLSTM->addToContext(context, dict, config);
+  if (hasTreeEmbedding)
+    treeEmbedding->addToContext(context, dict, config);
   splitTransLSTM->addToContext(context, dict, config);
   for (auto & lstm : focusedLstms)
     lstm->addToContext(context, dict, config);
-- 
GitLab