diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index d5ec5bf5a8f6d0f3633e74c110fffb8986454f79..b9a730c7000983231b7c75e76e251f435868a049 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -8,6 +8,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl private : static inline std::vector<long> focusedBufferIndexes{0,1}; + static inline std::vector<long> focusedStackIndexes{0,1}; static inline std::vector<long> windowSizes{2,3,4}; static constexpr unsigned int maxNbLetters = 10; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 7612f7c60bf568a4ae60df510e2d3f51b3cecea9..9633cb4f185ee392fa4299204317224ca692ff84 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -13,7 +13,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); - linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize)); + linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); for (auto & windowSize : windowSizes) { @@ -28,7 +28,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) input = input.unsqueeze(0); auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); - auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size()); + auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size())); auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1); auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1); @@ -43,6 +43,14 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) auto pooled = torch::max_pool1d(convOut, convOut.size(2)); windows.emplace_back(pooled); } + for (unsigned int word = 0; word < focusedStackIndexes.size(); word++) + for (unsigned int i = 0; i < lettersCNNs.size(); i++) + { + auto input = permuted[focusedBufferIndexes.size()+word]; + auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1)); + auto pooled = torch::max_pool1d(convOut, convOut.size(2)); + windows.emplace_back(pooled); + } auto lettersCnnOut = torch::cat(windows, 2); lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1}); @@ -133,6 +141,25 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c } } + for (auto index : focusedStackIndexes) + { + util::utf8string letters; + if (config.hasStack(index) and config.has("FORM", config.getStack(index),0)) + letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get()); + for (unsigned int i = 0; i < maxNbLetters; i++) + { + if (i < letters.size()) + { + std::string sLetter = fmt::format("Letter({})", letters[i]); + context.emplace_back(dict.getIndexOrInsert(sLetter)); + } + else + { + context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } + } + } + return context; }