diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index ad03c314757c302ccd73ae5768fc4c9b42c631f2..27ac7d30bff9d797f2f2b5e96cc7706192b939cb 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -48,11 +48,21 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("CNN\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), - "CNN(leftBorder,rightBorder,nbStack) : CNN to capture context.", + std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"), + "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.", [this,topology](auto sm) { - this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]))); + std::vector<long> focusedBuffer, focusedStack; /* capture groups: sm[4]=columns, sm[5]=focusedBuffer, sm[6]=focusedStack, sm[7]=focusedColumns */ + std::vector<std::string> focusedColumns, columns; + for (auto s : util::split(std::string(sm[4]), ',')) + columns.emplace_back(s); + for (auto s : util::split(std::string(sm[5]), ',')) + focusedBuffer.push_back(std::stoi(std::string(s))); + for (auto s : util::split(std::string(sm[6]), ',')) + focusedStack.push_back(std::stoi(std::string(s))); + for (auto s : util::split(std::string(sm[7]), ',')) + focusedColumns.emplace_back(s); + this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns)); } }, { diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index b9a730c7000983231b7c75e76e251f435868a049..b6b5cefccc4c2713b1fa71931276f14cf862eaaa 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -7,11 +7,15 @@ class CNNNetworkImpl : public NeuralNetworkImpl { private : - static inline std::vector<long> focusedBufferIndexes{0,1}; - static inline std::vector<long> focusedStackIndexes{0,1}; static inline 
std::vector<long> windowSizes{2,3,4}; static constexpr unsigned int maxNbLetters = 10; + private : + + std::vector<long> focusedBufferIndexes; + std::vector<long> focusedStackIndexes; + std::vector<std::string> focusedColumns; + torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; @@ -20,7 +24,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl public : - CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); + CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 9633cb4f185ee392fa4299204317224ca692ff84..50f9c0dff0519c971a2ceec14882ad728836fd12 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -1,6 +1,6 @@ #include "CNNNetwork.hpp" -CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) +CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns) { constexpr int embeddingsSize = 64; constexpr int hiddenSize = 512; @@ -10,7 +10,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(nbStackElements); - setColumns({"FORM", "UPOS"}); + setColumns(columns); wordEmbeddings = 
register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));