diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index aeb6ac87cc665db30c5cbb3ec4bc17fa99989214..6925568e06b8829cdf130f7a44886561ca7f10e7 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -73,10 +73,12 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size std::vector<int> bufferContext, stackContext; std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns; std::vector<int> focusedBuffer, focusedStack; + std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack; std::vector<int> maxNbElements; + std::vector<int> treeEmbeddingNbElems; std::vector<std::pair<int, float>> mlp; int rawInputLeftWindow, rawInputRightWindow; - int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers; + int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize; bool bilstm; float lstmDropout; @@ -229,6 +231,37 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}")); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns)); + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingNbElems.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm) + { + treeEmbeddingSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); + + this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize)); } diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp index 6eb069b99d308982327142c15acc85c6cb50c042..436a082a06121a2c62f50da0b5f5ef4b79b99ba8 100644 --- a/torch_modules/include/DepthLayerTreeEmbedding.hpp +++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp @@ -9,17 +9,15 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule { private : - std::vector<std::string> columns{"DEPREL"}; - std::vector<int> focusedBuffer{0}; - std::vector<int> focusedStack{0}; - std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"}; + std::vector<int> maxElemPerDepth; + std::vector<std::string> columns; + std::vector<int> focusedBuffer; + std::vector<int> focusedStack; std::vector<LSTM> depthLstm; - int maxDepth; - int maxElemPerDepth; public : - DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 9cb051d45f15252c304d84174db1965f357eb721..f0b58dc099d512e73ce95c9d9c808cf206006cc5 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -29,7 +29,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns); + LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index 3f1926d72dd2f8dad1109ca15b2eed32e957ff43..aa4c7aef388105493ce3315d7761acb96b7693ba 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -1,9 +1,10 @@ #include "DepthLayerTreeEmbedding.hpp" -DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth) +DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) : + maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack) { - for (int i = 0; i < maxDepth; i++) - depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options))); + for (unsigned int i = 0; i < maxElemPerDepth.size(); i++) + depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options))); } torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) @@ -12,9 +13,13 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) std::vector<torch::Tensor> outputs; - for (unsigned int i = 0; i < depthLstm.size(); i++) - for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++) - outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth))); + int offset = 0; + for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++) + for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) + { + outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)}))); + offset += maxElemPerDepth[depth]*columns.size(); + } return torch::cat(outputs, 1); } @@ -23,15 +28,18 @@ std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize() { std::size_t outputSize = 0; - for (auto & lstm : depthLstm) - outputSize += lstm->getOutputSize(maxElemPerDepth); + for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) + outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]); - return outputSize; + return outputSize*(focusedBuffer.size()+focusedStack.size()); } std::size_t DepthLayerTreeEmbeddingImpl::getInputSize() { - return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth; + int inputSize = 0; + for (int maxElem : maxElemPerDepth) + inputSize += (focusedBuffer.size()+focusedStack.size())*maxElem*columns.size(); + return inputSize; } void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const @@ -48,11 +56,27 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & focusedIndexes.emplace_back(-1); for (auto & contextElement : context) - { for (auto index : focusedIndexes) { + std::vector<std::string> childs{std::to_string(index)}; + for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) + { + std::vector<std::string> newChilds; + for (auto & child : childs) + if (config.has(Config::childsColName, std::stoi(child), 0)) + { + auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|'); + newChilds.insert(newChilds.end(), val.begin(), val.end()); + } + childs = newChilds; + for (int i = 0; i < maxElemPerDepth[depth]; i++) + for (auto & col : columns) + if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0)) + contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i])))); + else + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } } - } } diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 2a9b1b4dca054e89a22322cdf7301998e720ee37..7a25024b67d90bd2f5f2acde0839baa308e4bc74 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns) +LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize) { LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; @@ -26,7 +26,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: if (!treeEmbeddingColumns.empty()) { hasTreeEmbedding = true; - treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll)); + treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptionsAll)); treeEmbedding->setFirstInputIndex(currentInputSize); currentOutputSize += treeEmbedding->getOutputSize(); currentInputSize += treeEmbedding->getInputSize(); @@ -89,13 +89,15 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, context.back().emplace_back(dict.getIndexOrInsert(config.getState())); contextLSTM->addToContext(context, dict, config); + if (hasRawInputLSTM) rawInputLSTM->addToContext(context, dict, config); - fmt::print(stderr, "before {}\n", context.back().size()); + if (hasTreeEmbedding) treeEmbedding->addToContext(context, dict, config); - fmt::print(stderr, "after {}\n", context.back().size()); + splitTransLSTM->addToContext(context, dict, config); + for (auto & lstm : focusedLstms) lstm->addToContext(context, dict, config);