diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index fc5f386ae9d085c2a0489357085447122a27f476..60fbf2acb20760a6677f01f891ba27f8ea308b16 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -146,6 +146,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const
     return "LAS";
   if (colName == "EOS")
     return "Sentences";
+  if (colName == "FEATS")
+    return "UFeats";
   return colName;
 }
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 27ac7d30bff9d797f2f2b5e96cc7706192b939cb..e5eee2d6c1a4796cbca5db53042667cecb5d1e9e 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -48,11 +48,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
-      "CNN(leftBorder,rightBorder,nbStack,{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.",
+      std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
+      "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements}) : CNN to capture context.",
       [this,topology](auto sm)
       {
-        std::vector<long> focusedBuffer, focusedStack;
+        std::vector<int> focusedBuffer, focusedStack, maxNbElements;
         std::vector<std::string> focusedColumns, columns;
         for (auto s : util::split(std::string(sm[4]), ','))
           columns.emplace_back(s);
@@ -62,7 +62,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
           focusedStack.push_back(std::stoi(std::string(s)));
         for (auto s : util::split(std::string(sm[7]), ','))
           focusedColumns.emplace_back(s);
-        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns));
+        for (auto s : util::split(std::string(sm[8]), ','))
+          maxNbElements.push_back(std::stoi(std::string(s)));
+        if (focusedColumns.size() != maxNbElements.size())
+          util::myThrow("focusedColumns.size() != maxNbElements.size()");
+        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements));
       }
     },
     {
diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp
index e08a869f48dcca2e45911ba7d89266a821dd97f7..509be856b83e40431aef0c138747a70b62847201 100644
--- a/torch_modules/include/CNN.hpp
+++ b/torch_modules/include/CNN.hpp
@@ -8,14 +8,14 @@ class CNNImpl : public torch::nn::Module
 {
   private :

-  std::vector<long> windowSizes;
+  std::vector<int> windowSizes;
   std::vector<torch::nn::Conv2d> CNNs;
   int nbFilters;
   int elementSize;

   public :

-  CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);
+  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize();
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index 0893ff98a7123c107bfc564ed2ff4f880f7f2c85..ebbfefd2464f602f7d105e784fee0847b4e9d04e 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -8,23 +8,20 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 {
   private :

-  static constexpr unsigned int maxNbLetters = 10;
-
-  private :
-
-  std::vector<long> focusedBufferIndexes;
-  std::vector<long> focusedStackIndexes;
+  std::vector<int> focusedBufferIndexes;
+  std::vector<int> focusedStackIndexes;
   std::vector<std::string> focusedColumns;
+  std::vector<int> maxNbElements;
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
-  CNN lettersCNN{nullptr};
+  std::vector<CNN> cnns;

   public :

-  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns);
+  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 47092164e0abd730c081e7306182757b56eb14be..1ca0919cc3118a2ef5b01c0a466c97ed3c3bd6a5 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -13,9 +13,9 @@ class NeuralNetworkImpl : public torch::nn::Module

   protected :

-  int leftBorder{5};
-  int rightBorder{5};
-  int nbStackElements{2};
+  unsigned leftBorder{5};
+  unsigned rightBorder{5};
+  unsigned nbStackElements{2};
   std::vector<std::string> columns{"FORM"};

   protected :
@@ -28,6 +28,7 @@ class NeuralNetworkImpl : public torch::nn::Module

   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
+  std::vector<long> extractContextIndexes(const Config & config) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
 };
diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp
index f033403e3db3d89529653b67d8f388786d37201a..d5ac3ba5b088e06036ff42b0dea2361022af7ec7 100644
--- a/torch_modules/src/CNN.cpp
+++ b/torch_modules/src/CNN.cpp
@@ -1,7 +1,7 @@
 #include "CNN.hpp"
 #include "CNN.hpp"

-CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
+CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
   : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
 {
   for (auto & windowSize : windowSizes)
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 86781c9de9be88aae6fb4b3f9f9d72622b7b8030..71c79b35cd928fdae4a5acf5d69766074aad55e9 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -1,11 +1,11 @@
 #include "CNNNetwork.hpp"

-CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns)
+CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements)
 {
   constexpr int embeddingsSize = 64;
   constexpr int hiddenSize = 512;
-  constexpr int nbFilters = 512;
-  constexpr int nbFiltersLetters = 64;
+  constexpr int nbFiltersContext = 512;
+  constexpr int nbFiltersFocused = 64;

   setLeftBorder(leftBorder);
   setRightBorder(rightBorder);
@@ -13,9 +13,15 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   setColumns(columns);

   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
-  lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
-  linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
+  contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
+  int totalCnnOutputSize = contextCNN->getOutputSize();
+  for (auto & col : focusedColumns)
+  {
+    std::vector<int> windows{2,3,4};
+    cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
+    totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
+  }
+  linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
 }

@@ -25,113 +31,107 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
     input = input.unsqueeze(0);

   auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
-  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
+  auto curIndex = wordIndexes.size(1);

-  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
-  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  std::vector<torch::Tensor> cnnOutputs;
+
+  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+  {
+    long nbElements = input[0][curIndex].item<long>();

-  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
-  std::vector<torch::Tensor> cnnOuts;
-  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
-    cnnOuts.emplace_back(lettersCNN(permuted[word]));
-  for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
-    cnnOuts.emplace_back(lettersCNN(permuted[word]));
-  auto lettersCnnOut = torch::cat(cnnOuts, 1);
+    curIndex++;
+    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
+    {
+      cnnOutputs.emplace_back(cnns[i](wordEmbeddings(input.narrow(1, curIndex, nbElements)).unsqueeze(1)));
+      curIndex += nbElements;
+    }
+  }

-  auto contextCnnOut = contextCNN(embeddings);
+  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  cnnOutputs.emplace_back(contextCNN(embeddings));

-  auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);
+  auto totalInput = torch::cat(cnnOutputs, 1);

   return linear2(torch::relu(linear1(totalInput)));
 }

 std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
-  std::stack<int> leftContext;
-  std::stack<std::string> leftForms;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
-    if (config.isToken(index))
-      for (auto & column : columns)
-      {
-        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
-        if (column == "FORM")
-          leftForms.push(config.getAsFeature(column, index));
-      }
-
+  std::vector<long> contextIndexes = extractContextIndexes(config);
   std::vector<long> context;
-  std::vector<std::string> forms;
-  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while (forms.size() < leftBorder-leftForms.size())
-    forms.emplace_back("");
-  while (!leftForms.empty())
-  {
-    forms.emplace_back(leftForms.top());
-    leftForms.pop();
-  }
-  while (!leftContext.empty())
-  {
-    context.emplace_back(leftContext.top());
-    leftContext.pop();
-  }
+  for (auto & col : columns)
+    for (auto index : contextIndexes)
+      if (index == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));

-  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
-    if (config.isToken(index))
-      for (auto & column : columns)
-      {
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
-        if (column == "FORM")
-          forms.emplace_back(config.getAsFeature(column, index));
-      }
+  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
+  {
+    auto & col = focusedColumns[colIndex];

-  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while ((int)forms.size() < leftBorder+rightBorder+1)
-    forms.emplace_back("");
+    context.push_back(maxNbElements[colIndex]);

-  for (int i = 0; i < nbStackElements; i++)
-    for (auto & column : columns)
-      if (config.hasStack(i))
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+    std::vector<int> focusedIndexes;
+    for (auto relIndex : focusedBufferIndexes)
+    {
+      int index = relIndex + leftBorder;
+      if (index < 0 || index >= (int)contextIndexes.size())
+        focusedIndexes.push_back(-1);
+      else
+        focusedIndexes.push_back(contextIndexes[index]);
+    }
+    for (auto index : focusedStackIndexes)
+    {
+      if (!config.hasStack(index))
+        focusedIndexes.push_back(-1);
+      else if (!config.has(col, config.getStack(index), 0))
+        focusedIndexes.push_back(-1);
       else
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        focusedIndexes.push_back(config.getStack(index));
+    }

-  for (auto index : focusedBufferIndexes)
-  {
-    util::utf8string letters;
-    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
-      letters = util::splitAsUtf8(forms[leftBorder+index]);
-    for (unsigned int i = 0; i < maxNbLetters; i++)
+    for (auto index : focusedIndexes)
     {
-      if (i < letters.size())
+      if (index == -1)
       {
-        std::string sLetter = fmt::format("Letter({})", letters[i]);
-        context.emplace_back(dict.getIndexOrInsert(sLetter));
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        continue;
       }
-      else
+
+      std::vector<std::string> elements;
+      if (col == "FORM")
       {
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      }
-    }
-  }
+        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());

-  for (auto index : focusedStackIndexes)
-  {
-    util::utf8string letters;
-    if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
-      letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
-    for (unsigned int i = 0; i < maxNbLetters; i++)
-    {
-      if (i < letters.size())
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          if (i < (int)asUtf8.size())
+            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+          else
+            elements.emplace_back(Dict::nullValueStr);
+      }
+      else if (col == "FEATS")
       {
-        std::string sLetter = fmt::format("Letter({})", letters[i]);
-        context.emplace_back(dict.getIndexOrInsert(sLetter));
+        auto splited = util::split(config.getAsFeature(col, index).get(), '|');
+
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          if (i < (int)splited.size())
+            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+          else
+            elements.emplace_back(Dict::nullValueStr);
       }
       else
       {
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        elements.emplace_back(config.getAsFeature(col, index));
       }
+
+      if ((int)elements.size() != maxNbElements[colIndex])
+        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+
+      for (auto & element : elements)
+        context.emplace_back(dict.getIndexOrInsert(element));
     }
   }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 0719d84430f7bf194a2b887a022c1c473667572a..3f69b4a3e63753738ac7b7b0921877da03ce6317 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,38 +2,50 @@
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

-std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
 {
-  std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
+  std::stack<long> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
     if (config.isToken(index))
-      for (auto & column : columns)
-        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+      leftContext.push(index);

   std::vector<long> context;
-  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (context.size() < leftBorder-leftContext.size())
+    context.emplace_back(-1);
   while (!leftContext.empty())
   {
     context.emplace_back(leftContext.top());
     leftContext.pop();
   }

-  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
     if (config.isToken(index))
-      for (auto & column : columns)
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+      context.emplace_back(index);
+
+  while (context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(-1);
+
+  for (unsigned int i = 0; i < nbStackElements; i++)
+    if (config.hasStack(i))
+      context.emplace_back(config.getStack(i));
+    else
+      context.emplace_back(-1);

-  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  return context;
+}
+
+std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<long> indexes = extractContextIndexes(config);
+  std::vector<long> context;

-  for (int i = 0; i < nbStackElements; i++)
-    for (auto & column : columns)
-      if (config.hasStack(i))
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+  for (auto & col : columns)
+    for (auto index : indexes)
+      if (index == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
       else
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));

   return context;
 }
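
Note (reviewer addition, not part of the patch): a minimal standalone sketch of a topology line that the widened eight-group CNN(...) regex in Classifier.cpp should accept. The concrete values (borders 5/5, two stack elements, columns FORM and UPOS, focused columns FORM and FEATS with maxNbElements 10 and 8) are invented for illustration only.

    // Hypothetical example, not from the repository: checks that a sample
    // topology string matches the CNN(...) regex introduced in this patch.
    #include <cassert>
    #include <regex>
    #include <string>

    int main()
    {
      std::regex cnnRegex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)");
      // CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements})
      std::string topology = "CNN(5,5,2,{FORM,UPOS},{0,1,2},{0,1},{FORM,FEATS},{10,8})";
      std::smatch sm;
      assert(std::regex_match(topology, sm, cnnRegex));
      assert(sm.size() == 9); // whole match plus 8 capture groups
      // sm[7] ({FORM,FEATS}) and sm[8] ({10,8}) must contain the same number of
      // entries, mirroring the focusedColumns/maxNbElements size check that the
      // patch adds in Classifier.cpp.
      return 0;
    }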