// NOTE(review): removed stray "Newer"/"Older" lines — diff-viewer extraction
// artifacts, not code; they prevented this file from compiling.
#include "CNNNetwork.hpp"
/// Build the CNN network.
/// @param nbOutputs       size of the output (classification) layer.
/// @param leftBorder      number of context words left of the focused word.
/// @param rightBorder     number of context words right of the focused word.
/// @param nbStackElements number of stack elements fed as features.
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  // Network hyper-parameters. A single embedding table is shared by word-level
  // and letter-level features, hence one vocabulary bound.
  constexpr int maxVocabSize = 50000;
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 512;
  constexpr int nbFilters = 512;       // per window size, word-context CNN
  constexpr int nbFiltersLetters = 64; // per window size, per-word letter CNN

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

  wordEmbeddings = register_module("word_embeddings",
    torch::nn::Embedding(torch::nn::EmbeddingOptions(maxVocabSize, embeddingsSize)));

  // linear1 consumes the concatenation of one pooled vector per context window
  // size (nbFilters each) plus one pooled vector per (window size, focused
  // word) pair from the letter CNNs (nbFiltersLetters each) — see forward().
  linear1 = register_module("linear1",
    torch::nn::Linear(nbFilters*windowSizes.size()
        + nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()),
      hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));

  // One convolution per window size: the word-context CNN sees rows of width
  // 2*embeddingsSize (the two configured columns concatenated per word), the
  // letter CNN sees rows of width embeddingsSize (one letter embedding each).
  // Padding of windowSize-1 rows keeps every alignment of the window in play.
  for (auto & windowSize : windowSizes)
  {
    CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters,
        torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
    lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters,
        torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
  }
}
/// Forward pass.
/// @param input batch of feature vectors as laid out by extractContext():
///              first columns.size()*(1+leftBorder+rightBorder) word/tag dict
///              indexes, then maxNbLetters letter indexes for each focused
///              buffer word followed by each focused stack word. A 1-D tensor
///              is treated as a batch of one.
/// @return raw scores (no softmax), shape {batch, nbOutputs}.
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder),
    maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));

  // {batch, 1, nbContextWords, columns*embDim}: every column of a word is
  // concatenated along the last axis; the inserted dim is the Conv2d channel.
  auto embeddings = wordEmbeddings(wordIndexes)
    .view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(),
           (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()})
    .unsqueeze(1);
  // {batch, 1, nbFocusedWords, maxNbLetters, embDim}, permuted so the focused
  // word comes first: permuted[w] is a ready-made Conv2d input for word w.
  auto lettersEmbeddings = wordEmbeddings(wordLetters)
    .view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters,
           wordEmbeddings->options.embedding_dim()})
    .unsqueeze(1);
  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});

  // Letter-level CNNs: one max-pooled vector per (focused word, window size).
  // Buffer words occupy permuted[0..nbBuffer), stack words follow, so a single
  // loop over the total count reproduces the original concatenation order.
  // (Previously two identical loops; also the inner temporary shadowed the
  // `input` parameter — both fixed here.)
  std::vector<torch::Tensor> windows;
  auto nbFocusedWords = focusedBufferIndexes.size() + focusedStackIndexes.size();
  for (unsigned int word = 0; word < nbFocusedWords; word++)
    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
    {
      auto convOut = torch::relu(lettersCNNs[i](permuted[word]).squeeze(-1));
      auto pooled = torch::max_pool1d(convOut, convOut.size(2)); // max over letter positions
      windows.emplace_back(pooled);
    }
  auto lettersCnnOut = torch::cat(windows, 2);
  lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});

  // Word-context CNNs over the whole context window, max-pooled per filter.
  windows.clear();
  for (unsigned int i = 0; i < CNNs.size(); i++)
  {
    auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
    auto pooled = torch::max_pool1d(convOut, convOut.size(2)); // max over word positions
    windows.emplace_back(pooled);
  }
  auto cnnOut = torch::cat(windows, 2);
  cnnOut = cnnOut.view({cnnOut.size(0), -1});

  auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1);
  return linear2(torch::relu(linear1(totalInput)));
}
/// Build the flat feature vector consumed by forward(): dict indexes for the
/// context words (left border, current word, right border) column by column,
/// then the stack elements, then maxNbLetters letter indexes for each focused
/// buffer word and each focused stack word. Missing slots are padded with
/// Dict::nullValueStr indexes.
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
// Walk leftwards from the word before the current one; stacks reverse the
// collected items back into left-to-right order below. leftContext gets one
// entry per (word, column) pair; leftForms only one per word (its FORM).
std::stack<int> leftContext;
std::stack<std::string> leftForms;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
if (config.isToken(index))
for (auto & column : columns)
{
leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
if (column == "FORM")
leftForms.push(config.getAsFeature(column, index));
}
std::vector<long> context;
std::vector<std::string> forms;
// Left-pad when fewer than leftBorder words exist to the left.
// NOTE(review): leftContext holds columns.size() entries per word, so the
// expected pad count is columns.size()*leftBorder - leftContext.size(); this
// expression appears to under-pad whenever 0 < words < leftBorder and
// columns.size() > 1 (contrast with the per-word forms padding below) — confirm.
while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while (forms.size() < leftBorder-leftForms.size())
forms.emplace_back("");
// Unwind the stacks: items come back out in left-to-right sentence order.
while (!leftForms.empty())
{
forms.emplace_back(leftForms.top());
leftForms.pop();
}
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
// Current word plus the right border, in natural order.
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
if (config.isToken(index))
for (auto & column : columns)
{
context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
if (column == "FORM")
forms.emplace_back(config.getAsFeature(column, index));
}
// Right-pad if the sentence ends before the right border is filled.
while (context.size() < columns.size()*(leftBorder+rightBorder+1))
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while ((int)forms.size() < leftBorder+rightBorder+1)
forms.emplace_back("");
// One entry per (stack element, column); null index when the stack is shallower.
for (int i = 0; i < nbStackElements; i++)
for (auto & column : columns)
if (config.hasStack(i))
context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
else
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
// Letter features for focused buffer words: index is relative to the current
// word, so leftBorder+index addresses the matching slot in `forms`.
for (auto index : focusedBufferIndexes)
{
util::utf8string letters;
if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
letters = util::splitAsUtf8(forms[leftBorder+index]);
// Exactly maxNbLetters entries per word: real letters first, then null padding.
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}
// Letter features for focused stack words, same maxNbLetters layout.
for (auto index : focusedStackIndexes)
{
util::utf8string letters;
if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}