Skip to content
Snippets Groups Projects
Commit 7e465832 authored by Franck Dary's avatar Franck Dary
Browse files

unknownValueThreshold is now an argument of CNNNetwork constructor

parent aa16b73b
No related branches found
No related tags found
No related merge requests found
......@@ -48,25 +48,25 @@ void Classifier::initNeuralNetwork(const std::string & topology)
}
},
{
std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
std::regex("CNN\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
[this,topology](auto sm)
{
std::vector<int> focusedBuffer, focusedStack, maxNbElements;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(std::string(sm[4]), ','))
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[5]), ','))
focusedBuffer.push_back(std::stoi(std::string(s)));
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[6]), ','))
focusedStack.push_back(std::stoi(std::string(s)));
focusedBuffer.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[7]), ','))
focusedColumns.emplace_back(s);
focusedStack.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[8]), ','))
focusedColumns.emplace_back(s);
for (auto s : util::split(std::string(sm[9]), ','))
maxNbElements.push_back(std::stoi(std::string(s)));
if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[9]), std::stoi(sm[10])));
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
}
},
{
......
......@@ -9,8 +9,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
private :
static constexpr int maxNbEmbeddings = 50000;
static constexpr int unknownValueThreshold = 0;
int unknownValueThreshold;
std::vector<int> focusedBufferIndexes;
std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns;
......@@ -31,7 +31,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
public :
CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;
};
......
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 1024;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment