Commit 1065b8b6 authored by Franck Dary

CNNNetwork can now use any columns for focused CNNs

parent b50c6ff3
@@ -146,6 +146,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const
     return "LAS";
   if (colName == "EOS")
     return "Sentences";
+  if (colName == "FEATS")
+    return "UFeats";
   return colName;
 }
...
@@ -48,11 +48,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
     }
   },
   {
-    std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
-    "CNN(leftBorder,rightBorder,nbStack,{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.",
+    std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
+    "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements}) : CNN to capture context.",
     [this,topology](auto sm)
     {
-      std::vector<long> focusedBuffer, focusedStack;
+      std::vector<int> focusedBuffer, focusedStack, maxNbElements;
       std::vector<std::string> focusedColumns, columns;
       for (auto s : util::split(std::string(sm[4]), ','))
         columns.emplace_back(s);
@@ -62,7 +62,11 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         focusedStack.push_back(std::stoi(std::string(s)));
       for (auto s : util::split(std::string(sm[7]), ','))
         focusedColumns.emplace_back(s);
-      this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns));
+      for (auto s : util::split(std::string(sm[8]), ','))
+        maxNbElements.push_back(std::stoi(std::string(s)));
+      if (focusedColumns.size() != maxNbElements.size())
+        util::myThrow("focusedColumns.size() != maxNbElements.size()");
+      this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements));
     }
   },
   {
...
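For illustration only, a topology line that the extended pattern would accept could look like:

CNN(5,5,2,{FORM,UPOS},{0,1,2},{0,1},{FORM,FEATS},{10,10})

The column names and sizes above are hypothetical, not taken from the commit. Read the way the lambda does, sm[1] to sm[3] give leftBorder, rightBorder and nbStack, sm[4] the context columns, the next two groups the focused buffer and stack positions, sm[7] the focused columns, and sm[8] one maxNbElements value per focused column; the added check rejects a configuration in which the last two lists differ in length.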
@@ -8,14 +8,14 @@ class CNNImpl : public torch::nn::Module
 {
   private :
-  std::vector<long> windowSizes;
+  std::vector<int> windowSizes;
   std::vector<torch::nn::Conv2d> CNNs;
   int nbFilters;
   int elementSize;
   public :
-  CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);
+  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize();
...
@@ -8,23 +8,20 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 {
   private :
-  static constexpr unsigned int maxNbLetters = 10;
-  private :
-  std::vector<long> focusedBufferIndexes;
-  std::vector<long> focusedStackIndexes;
+  std::vector<int> focusedBufferIndexes;
+  std::vector<int> focusedStackIndexes;
   std::vector<std::string> focusedColumns;
+  std::vector<int> maxNbElements;
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
-  CNN lettersCNN{nullptr};
+  std::vector<CNN> cnns;
   public :
-  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns);
+  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
...
@@ -13,9 +13,9 @@ class NeuralNetworkImpl : public torch::nn::Module
   protected :
-  int leftBorder{5};
-  int rightBorder{5};
-  int nbStackElements{2};
+  unsigned leftBorder{5};
+  unsigned rightBorder{5};
+  unsigned nbStackElements{2};
   std::vector<std::string> columns{"FORM"};
   protected :
@@ -28,6 +28,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
+  std::vector<long> extractContextIndexes(const Config & config) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
 };
...
#include "CNN.hpp"
#include "CNN.hpp"
CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
{
for (auto & windowSize : windowSizes)
......
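The body of CNNImpl is not shown in this commit, so the following is only a minimal sketch of how such a multi-window convolution module is commonly assembled with libtorch: one Conv2d per window size, spanning the full element embedding, max-pooled over the sequence and concatenated. The module name MultiWindowCNNImpl, the kernel shapes and the pooling strategy are assumptions for illustration, not the repository's actual implementation.

#include <torch/torch.h>
#include <string>
#include <vector>

// Sketch only: one convolution per window size over a sequence of element
// embeddings, each max-pooled over the sequence dimension and concatenated.
struct MultiWindowCNNImpl : torch::nn::Module
{
  std::vector<int> windowSizes;
  std::vector<torch::nn::Conv2d> convs;
  int nbFilters;
  int elementSize;

  MultiWindowCNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
    : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
  {
    for (auto windowSize : this->windowSizes)
      convs.emplace_back(register_module(
        "conv_" + std::to_string(windowSize),
        // 1 input channel; the kernel covers windowSize elements and the whole embedding width.
        torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, {windowSize, elementSize}))));
  }

  // input : {batch, 1, sequenceLength, elementSize}
  torch::Tensor forward(torch::Tensor input)
  {
    std::vector<torch::Tensor> pooled;
    for (auto & conv : convs)
    {
      auto convOut = torch::relu(conv->forward(input)).squeeze(-1); // {batch, nbFilters, L'}
      pooled.emplace_back(std::get<0>(convOut.max(2)));             // max over time -> {batch, nbFilters}
    }
    return torch::cat(pooled, 1); // {batch, nbFilters * windowSizes.size()}
  }

  int getOutputSize()
  {
    return nbFilters * (int)windowSizes.size();
  }
};

Under that assumption, getOutputSize() equals nbFilters times the number of window sizes, which is the per-column figure the CNNNetworkImpl constructor below multiplies by the number of focused buffer and stack positions.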
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns)
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512;
constexpr int nbFilters = 512;
constexpr int nbFiltersLetters = 64;
constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
......@@ -13,9 +13,15 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setColumns(columns);
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize();
for (auto & col : focusedColumns)
{
std::vector<int> windows{2,3,4};
cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
}
linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
@@ -25,113 +31,107 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
     input = input.unsqueeze(0);
   auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
-  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
+  auto curIndex = wordIndexes.size(1);
-  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
-  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  std::vector<torch::Tensor> cnnOutputs;
-  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
-  std::vector<torch::Tensor> cnnOuts;
-  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
-    cnnOuts.emplace_back(lettersCNN(permuted[word]));
-  for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
-    cnnOuts.emplace_back(lettersCNN(permuted[word]));
-  auto lettersCnnOut = torch::cat(cnnOuts, 1);
+  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+  {
+    long nbElements = input[0][curIndex].item<long>();
-  auto contextCnnOut = contextCNN(embeddings);
+    curIndex++;
+    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
+    {
+      cnnOutputs.emplace_back(cnns[i](wordEmbeddings(input.narrow(1, curIndex, nbElements)).unsqueeze(1)));
+      curIndex += nbElements;
+    }
+  }
+  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  cnnOutputs.emplace_back(contextCNN(embeddings));
-  auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);
+  auto totalInput = torch::cat(cnnOutputs, 1);
   return linear2(torch::relu(linear1(totalInput)));
 }
 std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
-  std::stack<int> leftContext;
-  std::stack<std::string> leftForms;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
-    if (config.isToken(index))
-      for (auto & column : columns)
+  std::vector<long> contextIndexes = extractContextIndexes(config);
+  std::vector<long> context;
+  for (auto & col : columns)
+    for (auto index : contextIndexes)
+      if (index == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
+  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
   {
-        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
-        if (column == "FORM")
-          leftForms.push(config.getAsFeature(column, index));
-      }
+    auto & col = focusedColumns[colIndex];
-  std::vector<long> context;
-  std::vector<std::string> forms;
+    context.push_back(maxNbElements[colIndex]);
-  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while (forms.size() < leftBorder-leftForms.size())
-    forms.emplace_back("");
-  while (!leftForms.empty())
+    std::vector<int> focusedIndexes;
+    for (auto relIndex : focusedBufferIndexes)
    {
-    forms.emplace_back(leftForms.top());
-    leftForms.pop();
+      int index = relIndex + leftBorder;
+      if (index < 0 || index >= (int)contextIndexes.size())
+        focusedIndexes.push_back(-1);
+      else
+        focusedIndexes.push_back(contextIndexes[index]);
    }
-  while (!leftContext.empty())
+    for (auto index : focusedStackIndexes)
    {
-    context.emplace_back(leftContext.top());
-    leftContext.pop();
+      if (!config.hasStack(index))
+        focusedIndexes.push_back(-1);
+      else if (!config.has(col, config.getStack(index), 0))
+        focusedIndexes.push_back(-1);
+      else
+        focusedIndexes.push_back(config.getStack(index));
    }
-  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
-    if (config.isToken(index))
-      for (auto & column : columns)
+    for (auto index : focusedIndexes)
    {
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
-        if (column == "FORM")
-          forms.emplace_back(config.getAsFeature(column, index));
+      if (index == -1)
+      {
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        continue;
      }
-  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while ((int)forms.size() < leftBorder+rightBorder+1)
-    forms.emplace_back("");
+      std::vector<std::string> elements;
+      if (col == "FORM")
+      {
+        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
-  for (int i = 0; i < nbStackElements; i++)
-    for (auto & column : columns)
-      if (config.hasStack(i))
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          if (i < (int)asUtf8.size())
+            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
          else
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  for (auto index : focusedBufferIndexes)
-  {
-    util::utf8string letters;
-    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
-      letters = util::splitAsUtf8(forms[leftBorder+index]);
-    for (unsigned int i = 0; i < maxNbLetters; i++)
-    {
-      if (i < letters.size())
-      {
-        std::string sLetter = fmt::format("Letter({})", letters[i]);
-        context.emplace_back(dict.getIndexOrInsert(sLetter));
+            elements.emplace_back(Dict::nullValueStr);
      }
-      else
+      else if (col == "FEATS")
      {
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      }
-    }
-  }
+        auto splited = util::split(config.getAsFeature(col, index).get(), '|');
-  for (auto index : focusedStackIndexes)
-  {
-    util::utf8string letters;
-    if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
-      letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
-    for (unsigned int i = 0; i < maxNbLetters; i++)
-    {
-      if (i < letters.size())
-      {
-        std::string sLetter = fmt::format("Letter({})", letters[i]);
-        context.emplace_back(dict.getIndexOrInsert(sLetter));
+        for (int i = 0; i < maxNbElements[colIndex]; i++)
+          if (i < (int)splited.size())
+            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
          else
+            elements.emplace_back(Dict::nullValueStr);
      }
      else
      {
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        elements.emplace_back(config.getAsFeature(col, index));
      }
+      if ((int)elements.size() != maxNbElements[colIndex])
+        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+      for (auto & element : elements)
+        context.emplace_back(dict.getIndexOrInsert(element));
    }
  }
...
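In other words, after the plain context window, each focused column contributes one entry giving its maxNbElements value, followed by exactly maxNbElements dictionary indexes per focused buffer or stack position. As a hypothetical illustration: with maxNbElements = 10 for FORM, three focused buffer positions and one focused stack position, forward() reads 1 + 4*10 = 41 extra values for that column, and extractContext pads short words with the null value (or emits maxNbElements null values when the position does not exist) so that the count always matches.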
@@ -2,38 +2,50 @@
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
-std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
 {
-  std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
+  std::stack<long> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
    if (config.isToken(index))
-      for (auto & column : columns)
-        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+      leftContext.push(index);
  std::vector<long> context;
-  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (context.size() < leftBorder-leftContext.size())
+    context.emplace_back(-1);
  while (!leftContext.empty())
  {
    context.emplace_back(leftContext.top());
    leftContext.pop();
  }
-  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
    if (config.isToken(index))
-      for (auto & column : columns)
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+      context.emplace_back(index);
-  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(-1);
-  for (int i = 0; i < nbStackElements; i++)
-    for (auto & column : columns)
+  for (unsigned int i = 0; i < nbStackElements; i++)
    if (config.hasStack(i))
-        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      context.emplace_back(config.getStack(i));
    else
+      context.emplace_back(-1);
  return context;
 }
+std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<long> indexes = extractContextIndexes(config);
+  std::vector<long> context;
+  for (auto & col : columns)
+    for (auto index : indexes)
+      if (index == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
+  return context;
+}
...
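As a small worked example using the defaults declared in NeuralNetworkImpl (leftBorder = 5, rightBorder = 5, nbStackElements = 2), extractContextIndexes returns 5 + 1 + 5 + 2 = 13 sentence positions, with -1 standing for positions that do not exist, and extractContext then turns them into 13 dictionary indexes per column, grouped column by column, substituting Dict::nullValueStr wherever the position was -1.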