Skip to content
Snippets Groups Projects
Commit 8fecf5ee authored by Franck Dary's avatar Franck Dary
Browse files

Added focusedStackIndexes to CNNNetwork

parent f33d5d8c
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
private :
static inline std::vector<long> focusedBufferIndexes{0,1};
static inline std::vector<long> focusedStackIndexes{0,1};
static inline std::vector<long> windowSizes{2,3,4};
static constexpr unsigned int maxNbLetters = 10;
......
......@@ -13,7 +13,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize));
linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
for (auto & windowSize : windowSizes)
{
......@@ -28,7 +28,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
input = input.unsqueeze(0);
auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size());
auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
......@@ -43,6 +43,14 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
for (unsigned int i = 0; i < lettersCNNs.size(); i++)
{
auto input = permuted[focusedBufferIndexes.size()+word];
auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
auto lettersCnnOut = torch::cat(windows, 2);
lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
......@@ -133,6 +141,25 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
}
}
for (auto index : focusedStackIndexes)
{
util::utf8string letters;
if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}
return context;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment