#include "CNNNetwork.hpp"

/// Builds the network: one shared embedding table, one CNN over the whole
/// context window, one CNN per focused column, and a two-layer classifier.
///
/// @param nbOutputs            Output size of the last linear layer.
/// @param leftBorder           Forwarded to setLeftBorder().
/// @param rightBorder          Forwarded to setRightBorder().
/// @param nbStackElements      Forwarded to setNbStackElements().
/// @param columns              Columns used for the context window; forwarded to setColumns().
/// @param focusedBufferIndexes Buffer positions (relative) that get a dedicated focused CNN pass.
/// @param focusedStackIndexes  Stack positions that get a dedicated focused CNN pass.
/// @param focusedColumns       Columns for which a focused CNN is created.
/// @param maxNbElements        For each focused column, the fixed number of elements fed to its CNN.
///                             Must hold at least one entry per focused column (checked in extractContext()).
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements,
                               std::vector<std::string> columns,
                               std::vector<int> focusedBufferIndexes,
                               std::vector<int> focusedStackIndexes,
                               std::vector<std::string> focusedColumns,
                               std::vector<int> maxNbElements)
  : focusedBufferIndexes(focusedBufferIndexes),
    focusedStackIndexes(focusedStackIndexes),
    focusedColumns(focusedColumns),
    maxNbElements(maxNbElements)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 512;
  constexpr int nbFiltersContext = 512;
  constexpr int nbFiltersFocused = 64;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns(columns);

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));

  // Size of the concatenated CNN outputs, i.e. the input size of linear1.
  int totalCnnOutputSize = contextCNN->getOutputSize();

  for (auto & col : focusedColumns)
  {
    std::vector<int> windows{2,3,4};
    cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
    // The same focused CNN is applied once per focused buffer/stack position.
    totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
  }

  linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

/// Runs the network on an input built by extractContext().
///
/// Expected layout along dimension 1 of `input`:
///   - columns.size()*(1+leftBorder+rightBorder) context indexes;
///   - then, for each focused column: one cell holding the number of elements
///     per focused position, followed by that many indexes for each focused
///     buffer position and each focused stack position.
///
/// A 1-D input is treated as a batch of size one.
///
/// @param input Tensor of dictionary indexes.
/// @return Output of linear2 applied to relu(linear1(concatenated CNN outputs)).
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
  auto curIndex = wordIndexes.size(1);

  std::vector<torch::Tensor> cnnOutputs;

  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    // The element count is read from the first row only; every row of the
    // batch is assumed to carry the same value - TODO confirm with callers.
    long nbElements = input[0][curIndex].item<long>();
    curIndex++;

    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
    {
      cnnOutputs.emplace_back(cnns[i](wordEmbeddings(input.narrow(1, curIndex, nbElements)).unsqueeze(1)));
      curIndex += nbElements;
    }
  }

  // NOTE(review): extractContext() emits the context column-major (all positions
  // of the first column, then all positions of the second, ...), whereas this view
  // groups columns.size() *consecutive* cells into one row. These only agree when
  // there is a single column - confirm this grouping is intended.
  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/static_cast<int>(columns.size()), static_cast<int>(columns.size())*static_cast<int>(wordEmbeddings->options.embedding_dim())}).unsqueeze(1);
  cnnOutputs.emplace_back(contextCNN(embeddings));

  auto totalInput = torch::cat(cnnOutputs, 1);

  return linear2(torch::relu(linear1(totalInput)));
}

/// Converts a configuration into the flat list of dictionary indexes consumed by forward().
///
/// Output layout:
///   - for each column, for each index returned by extractContextIndexes():
///     the dictionary index of that feature (Dict::nullValueStr when the index is -1);
///   - for each focused column: maxNbElements[colIndex], followed by exactly
///     maxNbElements[colIndex] dictionary indexes per focused position
///     (all Dict::nullValueStr when the position is unavailable).
///
/// "FORM" is decomposed into UTF-8 letters and "FEATS" is split on '|'; both are
/// truncated/padded to maxNbElements[colIndex]. Any other column contributes a
/// single element, so its maxNbElements entry must be 1.
///
/// @param config Configuration to extract features from.
/// @param dict   Dictionary used to map (and insert) feature strings to indexes.
/// @return The flat context vector.
/// @throws via util::myThrow if maxNbElements has fewer entries than focusedColumns,
///         or if a column does not produce exactly maxNbElements[colIndex] elements.
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  // Guard against out-of-bounds reads of maxNbElements[colIndex] below.
  if (maxNbElements.size() < focusedColumns.size())
    util::myThrow(fmt::format("maxNbElements.size ({}) < focusedColumns.size ({})", maxNbElements.size(), focusedColumns.size()));

  std::vector<long> contextIndexes = extractContextIndexes(config);
  std::vector<long> context;

  for (auto & col : columns)
    for (auto index : contextIndexes)
      if (index == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));

  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
  {
    auto & col = focusedColumns[colIndex];

    // forward() reads this cell to know how many elements follow per focused position.
    context.push_back(maxNbElements[colIndex]);

    std::vector<int> focusedIndexes;

    for (auto relIndex : focusedBufferIndexes)
    {
      // Relative buffer position -> slot in contextIndexes (presumably the
      // buffer window starts at -leftBorder; verify against extractContextIndexes()).
      int index = relIndex + leftBorder;
      if (index < 0 || index >= static_cast<int>(contextIndexes.size()))
        focusedIndexes.push_back(-1);
      else
        focusedIndexes.push_back(static_cast<int>(contextIndexes[index]));
    }

    for (auto index : focusedStackIndexes)
    {
      if (!config.hasStack(index))
        focusedIndexes.push_back(-1);
      else if (!config.has(col, config.getStack(index), 0))
        focusedIndexes.push_back(-1);
      else
        focusedIndexes.push_back(config.getStack(index));
    }

    for (auto index : focusedIndexes)
    {
      if (index == -1)
      {
        // Unavailable position: fill the whole slot with the null value.
        for (int i = 0; i < maxNbElements[colIndex]; i++)
          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
        continue;
      }

      std::vector<std::string> elements;

      if (col == "FORM")
      {
        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());

        for (int i = 0; i < maxNbElements[colIndex]; i++)
          if (i < static_cast<int>(asUtf8.size()))
            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else if (col == "FEATS")
      {
        auto splited = util::split(config.getAsFeature(col, index).get(), '|');

        for (int i = 0; i < maxNbElements[colIndex]; i++)
          if (i < static_cast<int>(splited.size()))
            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else
      {
        elements.emplace_back(config.getAsFeature(col, index));
      }

      if (static_cast<int>(elements.size()) != maxNbElements[colIndex])
        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex] ({},{})", elements.size(), maxNbElements[colIndex], col));

      for (auto & element : elements)
        context.emplace_back(dict.getIndexOrInsert(element));
    }
  }

  return context;
}