From 9db9288677d9fd53b32b5dc19fa4d34c9ae017dd Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 7 Mar 2020 22:35:06 +0100 Subject: [PATCH] Added rawInput to CNNNetwork --- torch_modules/include/CNNNetwork.hpp | 3 +++ torch_modules/src/CNNNetwork.cpp | 31 +++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index ebbfefd..22f4b0c 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -12,11 +12,14 @@ class CNNNetworkImpl : public NeuralNetworkImpl std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; + int leftWindowRawInput{5}; + int rightWindowRawInput{5}; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; + CNN rawInputCNN{nullptr}; std::vector<CNN> cnns; public : diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 2981b74..7d2a4fc 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -14,7 +14,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); - int totalCnnOutputSize = contextCNN->getOutputSize(); + rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); + int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize(); for (auto & col : focusedColumns) { std::vector<int> windows{2,3,4}; @@ -30,11 +31,19 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - 
auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); + auto embeddings = wordEmbeddings(input); - auto elementsEmbeddings = wordEmbeddings(input.narrow(1, wordIndexes.size(1), input.size(1)-wordIndexes.size(1))); + auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); + + auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder)); /* context starts right after the raw-letter window on dim 1 (size(0) is the batch size) */ + context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); + + auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1)*(int)columns.size(), input.size(1)-(rawLetters.size(1)+context.size(1)*(int)columns.size())); /* the view above divided context.size(1) by columns.size(); multiply back to skip the whole context region */ std::vector<torch::Tensor> cnnOutputs; + + cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); + auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) { @@ -47,8 +56,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) } } - auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1); - cnnOutputs.emplace_back(contextCNN(embeddings)); + cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); auto totalInput = torch::cat(cnnOutputs, 1); @@ -60,6 +68,19 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> context; + for (int i = 0; i < leftWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) + context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindowRawInput; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) +
context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto index : contextIndexes) for (auto & col : columns) if (index == -1) -- GitLab