Skip to content
Snippets Groups Projects
Commit 69e871aa authored by Franck Dary's avatar Franck Dary
Browse files

CNNNetwork takes new parameters

parent 38571427
No related branches found
No related tags found
No related merge requests found
......@@ -48,11 +48,21 @@ void Classifier::initNeuralNetwork(const std::string & topology)
}
},
{
std::regex("CNN\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"CNN(leftBorder,rightBorder,nbStack) : CNN to capture context.",
std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
"CNN(leftBorder,rightBorder,nbStack,{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.",
[this,topology](auto sm)
{
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
std::vector<long> focusedBuffer, focusedStack;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(std::string(sm[4]), ','))
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[5]), ','))
focusedBuffer.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[6]), ','))
focusedStack.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[7]), ','))
focusedColumns.emplace_back(s);
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns));
}
},
{
......
......@@ -7,11 +7,15 @@ class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
static inline std::vector<long> focusedBufferIndexes{0,1};
static inline std::vector<long> focusedStackIndexes{0,1};
static inline std::vector<long> windowSizes{2,3,4};
static constexpr unsigned int maxNbLetters = 10;
private :
std::vector<long> focusedBufferIndexes;
std::vector<long> focusedStackIndexes;
std::vector<std::string> focusedColumns;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
......@@ -20,7 +24,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
public :
CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns);
torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;
};
......
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512;
......@@ -10,7 +10,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
setColumns({"FORM", "UPOS"});
setColumns(columns);
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment