diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 3e0cc65fc0df6144db896c14610337567e8abd6e..ae638c18762df39e898aad0a9733975ec47d89aa 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -48,25 +48,25 @@ void Classifier::initNeuralNetwork(const std::string & topology) } }, { - std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), - "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", + std::regex("CNN\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), + "CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", [this,topology](auto sm) { std::vector<int> focusedBuffer, focusedStack, maxNbElements; std::vector<std::string> focusedColumns, columns; - for (auto s : util::split(std::string(sm[4]), ',')) - columns.emplace_back(s); for (auto s : util::split(std::string(sm[5]), ',')) - focusedBuffer.push_back(std::stoi(std::string(s))); + columns.emplace_back(s); for (auto s : util::split(std::string(sm[6]), ',')) - focusedStack.push_back(std::stoi(std::string(s))); + focusedBuffer.push_back(std::stoi(std::string(s))); for (auto s : util::split(std::string(sm[7]), ',')) - focusedColumns.emplace_back(s); + focusedStack.push_back(std::stoi(std::string(s))); for (auto s : util::split(std::string(sm[8]), ',')) + focusedColumns.emplace_back(s); + for (auto s : util::split(std::string(sm[9]), ',')) maxNbElements.push_back(std::stoi(std::string(s))); if (focusedColumns.size() != maxNbElements.size()) util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[9]), std::stoi(sm[10]))); + this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11]))); } }, { diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 1f60c73c5a3e9ee2d69a8be85d7be10a8b450952..2edac4993051841372c293c07d55a6aeee56088c 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -9,8 +9,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl private : static constexpr int maxNbEmbeddings = 50000; - static constexpr int unknownValueThreshold = 0; + int unknownValueThreshold; std::vector<int> focusedBufferIndexes; std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; @@ -31,7 +31,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl public : - CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); + CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 130c0928fe2cc44c495823b0699c04d8672ace0f..5e9696eba7062b67c1b36dccb4dc29dd1fb8f7c5 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -1,6 +1,6 @@ #include "CNNNetwork.hpp" -CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) +CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { constexpr int embeddingsSize = 64; constexpr int hiddenSize = 1024;