Skip to content
Snippets Groups Projects
Commit ba5742bd authored by Franck Dary's avatar Franck Dary
Browse files

Added rawInput window parameters to CNNNetwork

parent 182b62de
No related branches found
No related tags found
No related merge requests found
...@@ -48,8 +48,8 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -48,8 +48,8 @@ void Classifier::initNeuralNetwork(const std::string & topology)
} }
}, },
{ {
std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"), std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements}) : CNN to capture context.", "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
[this,topology](auto sm) [this,topology](auto sm)
{ {
std::vector<int> focusedBuffer, focusedStack, maxNbElements; std::vector<int> focusedBuffer, focusedStack, maxNbElements;
...@@ -66,7 +66,7 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -66,7 +66,7 @@ void Classifier::initNeuralNetwork(const std::string & topology)
maxNbElements.push_back(std::stoi(std::string(s))); maxNbElements.push_back(std::stoi(std::string(s)));
if (focusedColumns.size() != maxNbElements.size()) if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()"); util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements)); this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[9]), std::stoi(sm[10])));
} }
}, },
{ {
......
...@@ -12,8 +12,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -12,8 +12,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl
std::vector<int> focusedStackIndexes; std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns; std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements; std::vector<int> maxNbElements;
int leftWindowRawInput{5}; int leftWindowRawInput;
int rightWindowRawInput{5}; int rightWindowRawInput;
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
...@@ -24,7 +25,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -24,7 +25,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
public : public :
CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements); CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override; std::vector<long> extractContext(Config & config, Dict & dict) const override;
}; };
......
#include "CNNNetwork.hpp" #include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements) CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{ {
constexpr int embeddingsSize = 64; constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512; constexpr int hiddenSize = 512;
...@@ -12,10 +12,16 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -12,10 +12,16 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setNbStackElements(nbStackElements); setNbStackElements(nbStackElements);
setColumns(columns); setColumns(columns);
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
rawInputSize = 0;
else
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize();
for (auto & col : focusedColumns) for (auto & col : focusedColumns)
{ {
std::vector<int> windows{2,3,4}; std::vector<int> windows{2,3,4};
...@@ -33,16 +39,18 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -33,16 +39,18 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
auto embeddings = wordEmbeddings(input); auto embeddings = wordEmbeddings(input);
auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder));
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1), input.size(1)-(rawLetters.size(1)+context.size(1))); auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
std::vector<torch::Tensor> cnnOutputs; std::vector<torch::Tensor> cnnOutputs;
if (rawInputSize != 0)
{
auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
}
auto curIndex = 0; auto curIndex = 0;
for (unsigned int i = 0; i < focusedColumns.size(); i++) for (unsigned int i = 0; i < focusedColumns.size(); i++)
...@@ -68,6 +76,8 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c ...@@ -68,6 +76,8 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<long> context; std::vector<long> context;
if (rawInputSize > 0)
{
for (int i = 0; i < leftWindowRawInput; i++) for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
...@@ -80,6 +90,7 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c ...@@ -80,6 +90,7 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
else else
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
for (auto index : contextIndexes) for (auto index : contextIndexes)
for (auto & col : columns) for (auto & col : columns)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment