From ba5742bdfc479346305904d637e6eaa9bf42bde5 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 8 Mar 2020 19:16:48 +0100 Subject: [PATCH] Added rawInput window parameters to CNNNetwork --- reading_machine/src/Classifier.cpp | 6 ++-- torch_modules/include/CNNNetwork.hpp | 7 +++-- torch_modules/src/CNNNetwork.cpp | 47 +++++++++++++++++----------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index e5eee2d..3e0cc65 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -48,8 +48,8 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"), - "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements}) : CNN to capture context.", + std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput,rightBorderRawInput) : CNN to capture context.", [this,topology](auto sm) { std::vector<int> focusedBuffer, focusedStack, maxNbElements; @@ -66,7 +66,7 @@ void Classifier::initNeuralNetwork(const std::string & topology) maxNbElements.push_back(std::stoi(std::string(s))); if (focusedColumns.size() != maxNbElements.size()) util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements)); + this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, 
std::stoi(sm[9]), std::stoi(sm[10]))); } }, { diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 22f4b0c..f193ebc 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -12,8 +12,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; - int leftWindowRawInput{5}; - int rightWindowRawInput{5}; + int leftWindowRawInput; + int rightWindowRawInput; + int rawInputSize; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; @@ -24,7 +25,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl public : - CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements); + CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 5ecccae..2f0cc9f 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -1,6 +1,6 @@ #include "CNNNetwork.hpp" -CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements) : 
focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements) +CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { constexpr int embeddingsSize = 64; constexpr int hiddenSize = 512; @@ -12,10 +12,16 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setNbStackElements(nbStackElements); setColumns(columns); + rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; + if (leftWindowRawInput < 0 or rightWindowRawInput < 0) + rawInputSize = 0; + else + rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); + int rawInputCNNOutputSize = rawInputSize == 0 ? 
0 : rawInputCNN->getOutputSize(); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); - rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); - int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize(); + int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; for (auto & col : focusedColumns) { std::vector<int> windows{2,3,4}; @@ -33,16 +39,18 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) auto embeddings = wordEmbeddings(input); - auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); - - auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder)); + auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); - auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1), input.size(1)-(rawLetters.size(1)+context.size(1))); + auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1))); std::vector<torch::Tensor> cnnOutputs; - cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); + if (rawInputSize != 0) + { + auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); + cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); + } auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) @@ -68,18 +76,21 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> 
context; - for (int i = 0; i < leftWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); - else - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + if (rawInputSize > 0) + { + for (int i = 0; i < leftWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) + context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - for (int i = 0; i <= rightWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) + for (int i = 0; i <= rightWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) - context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); - else - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } for (auto index : contextIndexes) for (auto & col : columns) -- GitLab