diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 61a3d87c4ba4e9da636ae64ceb18372440576df0..096800b01a4db30e06f0a3c2df4dc0a3b2b4f853 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -3,6 +3,7 @@
 #include "OneWordNetwork.hpp"
 #include "ConcatWordsNetwork.hpp"
 #include "RTLSTMNetwork.hpp"
+#include "CNNNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
@@ -46,6 +47,14 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
     }
   },
+  {
+    std::regex("CNN\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
+    "CNN(leftBorder,rightBorder,nbStack) : CNN to capture context.",
+    [this,topology](auto sm)
+    {
+      this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
+    }
+  },
   {
     std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
     "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
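Note: the new `CNN(...)` topology entry follows the pattern of the existing ones: three signed integers are captured and forwarded to the `CNNNetworkImpl` constructor. A minimal standalone sketch of how such a topology string is decoded, using the exact regex from the hunk above (the value `CNN(3,3,2)` is a hypothetical example, not taken from the patch):

```cpp
// Standalone sketch, not part of the patch: decoding "CNN(l,r,s)".
#include <regex>
#include <string>
#include <cstdio>

int main()
{
  std::regex cnnRegex("CNN\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)");
  std::smatch sm;
  std::string topology = "CNN(3,3,2)"; // hypothetical machine-file value

  if (std::regex_match(topology, sm, cnnRegex))
    std::printf("leftBorder=%d rightBorder=%d nbStack=%d\n",
                std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]));
  return 0;
}
```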
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d5ec5bf5a8f6d0f3633e74c110fffb8986454f79
--- /dev/null
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -0,0 +1,27 @@
+#ifndef CNNNETWORK__H
+#define CNNNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class CNNNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  static inline std::vector<long> focusedBufferIndexes{0,1};
+  static inline std::vector<long> windowSizes{2,3,4};
+  static constexpr unsigned int maxNbLetters = 10;
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
+  std::vector<torch::nn::Conv2d> CNNs;
+  std::vector<torch::nn::Conv2d> lettersCNNs;
+
+  public :
+
+  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7612f7c60bf568a4ae60df510e2d3f51b3cecea9
--- /dev/null
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -0,0 +1,138 @@
+#include "CNNNetwork.hpp"
+
+CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
+{
+  constexpr int embeddingsSize = 64;
+  constexpr int hiddenSize = 512;
+  constexpr int nbFilters = 512;
+  constexpr int nbFiltersLetters = 64;
+
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
+  setColumns({"FORM", "UPOS"});
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
+  for (auto & windowSize : windowSizes)
+  {
+    CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
+    lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
+  }
+}
+
+torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
+{
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+
+  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder+nbStackElements));
+  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder+nbStackElements), maxNbLetters*focusedBufferIndexes.size());
+
+  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+
+  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
+  std::vector<torch::Tensor> windows;
+  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
+    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
+    {
+      auto lettersInput = permuted[word];
+      auto convOut = torch::relu(lettersCNNs[i](lettersInput).squeeze(-1));
+      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+      windows.emplace_back(pooled);
+    }
+  auto lettersCnnOut = torch::cat(windows, 2);
+  lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
+
+  windows.clear();
+  for (unsigned int i = 0; i < CNNs.size(); i++)
+  {
+    auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
+    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+    windows.emplace_back(pooled);
+  }
+
+  auto cnnOut = torch::cat(windows, 2);
+  cnnOut = cnnOut.view({cnnOut.size(0), -1});
+
+  auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1);
+
+  return linear2(torch::relu(linear1(totalInput)));
+}
+
+std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::stack<int> leftContext;
+  std::stack<std::string> leftForms;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
+    if (config.isToken(index))
+      for (auto & column : columns)
+      {
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+        if (column == "FORM")
+          leftForms.push(config.getAsFeature(column, index));
+      }
+
+  std::vector<long> context;
+  std::vector<std::string> forms;
+
+  while ((int)context.size() < (int)columns.size()*leftBorder-(int)leftContext.size())
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while ((int)forms.size() < leftBorder-(int)leftForms.size())
+    forms.emplace_back("");
+  while (!leftForms.empty())
+  {
+    forms.emplace_back(leftForms.top());
+    leftForms.pop();
+  }
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
+    if (config.isToken(index))
+      for (auto & column : columns)
+      {
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+        if (column == "FORM")
+          forms.emplace_back(config.getAsFeature(column, index));
+      }
+
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while ((int)forms.size() < leftBorder+rightBorder+1)
+    forms.emplace_back("");
+
+  for (int i = 0; i < nbStackElements; i++)
+    for (auto & column : columns)
+      if (config.hasStack(i))
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      else
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  for (auto index : focusedBufferIndexes)
+  {
+    util::utf8string letters;
+    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
+      letters = util::splitAsUtf8(forms[leftBorder+index]);
+    for (unsigned int i = 0; i < maxNbLetters; i++)
+    {
+      if (i < letters.size())
+      {
+        std::string sLetter = fmt::format("Letter({})", letters[i]);
+        context.emplace_back(dict.getIndexOrInsert(sLetter));
+      }
+      else
+      {
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
+    }
+  }
+
+  return context;
+}
+
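For orientation: `extractContext` lays out each example as `columns.size()*(1+leftBorder+rightBorder+nbStackElements)` embedding indexes for the context and stack features, followed by `maxNbLetters*focusedBufferIndexes.size()` letter indexes, and the two `narrow` calls in `forward` split the batch along exactly that boundary. Each convolution then spans the full embedding width and is max-pooled over the remaining time axis, so every (window size, filter bank) pair yields a fixed-size vector regardless of how many tokens or letters it saw: the word branch contributes `nbFilters` per window size and the letters branch `nbFiltersLetters` per window size per focused word, i.e. 512*3 + 64*3*2 = 1920 features, matching the input width of `linear1`. A minimal standalone sketch of the conv-and-pool pattern, with shapes mirroring the letters branch (batch size and random inputs are illustrative assumptions, not values from the patch):

```cpp
// Standalone sketch, not part of the patch: conv + max-pool-over-time.
#include <torch/torch.h>
#include <iostream>

int main()
{
  constexpr long batch = 5, maxNbLetters = 10, embeddingsSize = 64;
  constexpr long nbFiltersLetters = 64, windowSize = 3;

  // One focused word: {batch, 1 input channel, maxNbLetters, embeddingsSize}.
  auto lettersEmb = torch::randn({batch, 1, maxNbLetters, embeddingsSize});

  // The kernel spans the whole embedding width, so the width axis collapses
  // to 1; padding of windowSize-1 on the time axis keeps short words valid.
  torch::nn::Conv2d conv(torch::nn::Conv2dOptions(1, nbFiltersLetters,
      torch::ExpandingArray<2>({windowSize, embeddingsSize}))
      .padding({windowSize-1, 0}));

  auto convOut = torch::relu(conv(lettersEmb).squeeze(-1)); // {batch, filters, time}
  auto pooled = torch::max_pool1d(convOut, convOut.size(2)); // {batch, filters, 1}

  std::cout << pooled.sizes() << std::endl; // [5, 64, 1]
  return 0;
}
```

Because the pooled size is independent of the time axis, the same module handles any window size and any number of letters up to `maxNbLetters` without changing the width that `linear1` expects.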
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 5a8c30230b0f56b1aeb98b04879d40cf3d51ab20..6e889171c5f1fa461f333605446fb8545d287270 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{1};
+  int batchSize{50};
   int nbExamples{0};
 
   public :
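Raising the default `batchSize` from 1 to 50 means the networks are now typically called on 2-D {batch, features} tensors during training; the `input.dim() == 1` check at the top of `CNNNetworkImpl::forward` keeps single-example calls (e.g. while decoding) working by promoting them to a batch of one. A small sketch of that convention (the feature width 38 is a hypothetical example, not a value from the patch):

```cpp
// Standalone sketch, not part of the patch: the batching convention.
#include <torch/torch.h>
#include <iostream>

torch::Tensor normalizeBatch(torch::Tensor input)
{
  // Mirrors the check at the top of CNNNetworkImpl::forward: a single
  // example {features} becomes a batch of one {1, features}.
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  return input;
}

int main()
{
  auto one = torch::zeros({38});       // one example, as during decoding
  auto batch = torch::zeros({50, 38}); // a training batch (batchSize{50})
  std::cout << normalizeBatch(one).sizes() << " "
            << normalizeBatch(batch).sizes() << std::endl; // [1, 38] [50, 38]
  return 0;
}
```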