diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index ebbfefd2464f602f7d105e784fee0847b4e9d04e..22f4b0c933e618fd75c44494bc270383e27c73a1 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -12,11 +12,14 @@ class CNNNetworkImpl : public NeuralNetworkImpl std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; + int leftWindowRawInput{5}; /* number of characters taken before the current character index */ + int rightWindowRawInput{5}; /* number of characters taken after the current character index */ torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; + CNN rawInputCNN{nullptr}; std::vector<CNN> cnns; public : diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 2981b744a4aec17fbe3c996f0559058b903088cc..7d2a4fc6f7a03feb098147e53e69e2bbd898f8bd 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -14,7 +14,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); - int totalCnnOutputSize = contextCNN->getOutputSize(); + rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); + int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize(); for (auto & col : focusedColumns) { std::vector<int> windows{2,3,4}; @@ -30,11 +31,19 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); + auto embeddings = wordEmbeddings(input); - auto elementsEmbeddings = wordEmbeddings(input.narrow(1, wordIndexes.size(1), 
input.size(1)-wordIndexes.size(1))); + auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); /* the raw-letter window occupies the first positions of dim 1 */ + + auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder)); /* starts right after the raw letters: the offset is their length along dim 1, not the batch size */ + auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1), input.size(1)-(rawLetters.size(1)+context.size(1))); /* sliced before the reshape below, while context.size(1) still counts individual indexes */ + + context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); /* groups the embeddings of all columns of one context position into a single row */ std::vector<torch::Tensor> cnnOutputs; + + cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); + auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) { @@ -47,8 +56,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) } } - auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1); - cnnOutputs.emplace_back(contextCNN(embeddings)); + cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); auto totalInput = torch::cat(cnnOutputs, 1); @@ -60,6 +68,19 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> context; + for (int i = 0; i < leftWindowRawInput; i++) /* characters preceding the current one; positions without a character get the null value */ + if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) + context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindowRawInput; i++) /* current character (i == 0) followed by the right window */ + if (config.hasCharacter(config.getCharacterIndex()+i)) + + context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + else + context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto index : contextIndexes) 
for (auto & col : columns) if (index == -1)