#include "CNNNetwork.hpp"

/// @brief CNN-based classifier: word-context CNN + per-word letter CNNs feeding a 2-layer MLP.
///
/// @param nbOutputs            Size of the output layer (number of classes).
/// @param leftBorder           Number of context words to the left of the current word.
/// @param rightBorder          Number of context words to the right of the current word.
/// @param nbStackElements      Number of stack elements whose columns are added to the context.
/// @param columns              Feature columns extracted per context word.
/// @param focusedBufferIndexes Buffer-relative word indexes whose letters are fed to the letter CNN.
/// @param focusedStackIndexes  Stack indexes whose letters are fed to the letter CNN.
/// @param focusedColumns       Stored for later use (not read in this translation unit).
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns)
{
  // Network hyper-parameters.
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 512;
  constexpr int nbFilters = 512;        // filters per window size, context CNN
  constexpr int nbFiltersLetters = 64;  // filters per window size, letter CNN

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns(columns);

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  // NOTE(review): the element size passed here is hard-coded to 2*embeddingsSize, but
  // forward() views each context row as columns.size()*embedding_dim — these only agree
  // when columns.size() == 2. Consider columns.size()*embeddingsSize; confirm with callers.
  contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
  lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
  // linear1 consumes the context CNN output concatenated with one letter-CNN output
  // per focused word (buffer-focused + stack-focused).
  linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

/// @brief Score a (batch of) extracted context(s).
///
/// Input layout per row (as produced by extractContext):
///   [ columns.size()*(1+leftBorder+rightBorder) word feature indexes |
///     maxNbLetters letter indexes per focused word (buffer-focused first, then stack-focused) ]
///
/// @param input 1D (single example, unsqueezed to a batch of 1) or 2D (batch) index tensor.
/// @return Raw class scores of shape {batch, nbOutputs}.
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));

  // {batch, 1, nbWords, columns.size()*embDim} : one row per context word.
  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
  // {batch, 1, nbFocusedWords, maxNbLetters, embDim} : one letter matrix per focused word.
  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
  // Bring the focused-word dimension to the front so permuted[w] is that word's
  // {batch, 1, maxNbLetters, embDim} slice.
  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});

  std::vector<torch::Tensor> cnnOuts;
  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
    cnnOuts.emplace_back(lettersCNN(permuted[word]));
  // BUGFIX: stack-focused words live after the buffer-focused ones in the letters
  // segment; the previous code indexed permuted[word] from 0 again, re-reading the
  // buffer words' letters and never reading the stack words' letters.
  for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
    cnnOuts.emplace_back(lettersCNN(permuted[focusedBufferIndexes.size()+word]));

  auto lettersCnnOut = torch::cat(cnnOuts, 1);
  auto contextCnnOut = contextCNN(embeddings);
  auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);

  return linear2(torch::relu(linear1(totalInput)));
}

/// @brief Extract the dictionary-index context vector for the current configuration.
///
/// Layout of the returned vector (matches forward()'s expectations):
///   1. columns.size() indexes per context word, left-to-right over
///      [wordIndex-leftBorder, wordIndex+rightBorder], padded with nullValueStr
///      where the configuration has no token.
///   2. columns.size() indexes per stack element (nbStackElements of them),
///      nullValueStr-padded when the stack is shallower.
///   3. maxNbLetters "Letter(c)" indexes per focused buffer word, then per focused
///      stack word, nullValueStr-padded past the end of each FORM.
///
/// @param config Current parsing configuration (read only through its accessors).
/// @param dict   Dictionary used to map every string feature to an index (may insert).
/// @return Flat vector of dictionary indexes.
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  // Walk leftwards collecting up to leftBorder tokens; stacks reverse the order back.
  std::stack<int> leftContext;
  std::stack<std::string> leftForms;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
    if (config.isToken(index))
      for (auto & column : columns)
      {
        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
        if (column == "FORM")
          leftForms.push(config.getAsFeature(column, index));
      }

  std::vector<long> context;
  std::vector<std::string> forms;

  // Pad the left side when fewer than leftBorder tokens exist.
  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  // Explicit int casts: the original unsigned expression leftBorder-leftForms.size()
  // would wrap around if leftForms ever exceeded leftBorder.
  while ((int)forms.size() < leftBorder-(int)leftForms.size())
    forms.emplace_back("");
  while (!leftForms.empty())
  {
    forms.emplace_back(leftForms.top());
    leftForms.pop();
  }
  while (!leftContext.empty())
  {
    context.emplace_back(leftContext.top());
    leftContext.pop();
  }

  // Current word and right context, then right-side padding.
  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
    if (config.isToken(index))
      for (auto & column : columns)
      {
        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
        if (column == "FORM")
          forms.emplace_back(config.getAsFeature(column, index));
      }
  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  while ((int)forms.size() < leftBorder+rightBorder+1)
    forms.emplace_back("");

  // Stack-element features.
  for (int i = 0; i < nbStackElements; i++)
    for (auto & column : columns)
      if (config.hasStack(i))
        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
      else
        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));

  // Emit exactly maxNbLetters indexes for one word's letters, nullValueStr-padded.
  auto appendLetters = [&](const util::utf8string & letters)
  {
    for (unsigned int i = 0; i < maxNbLetters; i++)
    {
      if (i < letters.size())
      {
        std::string sLetter = fmt::format("Letter({})", letters[i]);
        context.emplace_back(dict.getIndexOrInsert(sLetter));
      }
      else
      {
        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
      }
    }
  };

  // Letters of the focused buffer words (positions are relative to the current word).
  for (auto index : focusedBufferIndexes)
  {
    util::utf8string letters;
    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
      letters = util::splitAsUtf8(forms[leftBorder+index]);
    appendLetters(letters);
  }

  // Letters of the focused stack words.
  for (auto index : focusedStackIndexes)
  {
    util::utf8string letters;
    if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
      letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
    appendLetters(letters);
  }

  return context;
}