// NOTE(review): the two words "Newer" / "Older" that appeared here were
// residue from a web diff/version viewer, not C++ source; removed.
// Builds the network: one shared embedding table (used for both word-level and
// letter-level indexes), a CNN over the word-context window, a CNN over the
// letters of each focused buffer/stack element, and a two-layer MLP producing
// nbOutputs scores.
// NOTE(review): the `columns` parameter is never used in this constructor and
// the `columns` member read by forward()/extractContext() is not in the
// initializer list — presumably it is set via a base-class setter elsewhere;
// confirm it is initialized before forward() is called.
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns)
{
// Architecture hyper-parameters, fixed at construction time.
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512;
constexpr int nbFilters = 512;
constexpr int nbFiltersLetters = 64;
// Window geometry and stack depth are stored through setters (presumably
// declared in a base class — not visible in this file).
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
// Single embedding table of 50000 entries shared by words and letters.
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
// Context CNN: filter widths {2,3,4}; per-position input size is
// 2*embeddingsSize (rationale for the factor 2 is not visible here — verify).
contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
// Letters CNN: filter widths {2,3,4,5} over single-letter embeddings.
lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
// Classifier input = context features + one letters-CNN output per focused
// buffer/stack element.
linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
// Scores each output action for the given extracted context (see extractContext
// for the row layout). Accepts a single example (1-D) or a batch (2-D).
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  // Row layout: first the column features of the context window
  // (leftBorder + current word + rightBorder positions, columns.size() values
  // each), then maxNbLetters letter indexes per focused buffer/stack element.
  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
  // {batch, nbPositions, nbColumns*embDim}, plus a channel dim for the CNN.
  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
  // {batch, 1, nbFocusedWords, maxNbLetters, embDim}.
  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
  // Bring the focused-word axis first so permuted[word] selects one word's
  // {batch, 1, maxNbLetters, embDim} letter grid.
  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
  // NOTE(review): reconstructed statement — the extracted source used
  // contextCnnOut below without ever computing it; this is the only
  // plausible producer. TODO confirm against version history.
  auto contextCnnOut = contextCNN(embeddings);
  // BUGFIX: the extracted source contained two nested loops that both
  // declared `word` (a diff-merge artifact, with the inner loop shadowing the
  // outer) and never declared cnnOuts. Run the letters CNN once per focused
  // element — buffer elements first, then stack elements, matching the order
  // extractContext emits them.
  std::vector<torch::Tensor> cnnOuts;
  for (unsigned int word = 0; word < focusedBufferIndexes.size()+focusedStackIndexes.size(); word++)
    cnnOuts.emplace_back(lettersCNN(permuted[word]));
  auto lettersCnnOut = torch::cat(cnnOuts, 1);
  // Concatenate context-level and letter-level features, then classify.
  auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);
  return linear2(torch::relu(linear1(totalInput)));
}
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::stack<int> leftContext;
std::stack<std::string> leftForms;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
if (config.isToken(index))
for (auto & column : columns)
{
leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
if (column == "FORM")
leftForms.push(config.getAsFeature(column, index));
}
std::vector<long> context;
std::vector<std::string> forms;
while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while (forms.size() < leftBorder-leftForms.size())
forms.emplace_back("");
while (!leftForms.empty())
{
forms.emplace_back(leftForms.top());
leftForms.pop();
}
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
if (config.isToken(index))
for (auto & column : columns)
{
context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
if (column == "FORM")
forms.emplace_back(config.getAsFeature(column, index));
}
while (context.size() < columns.size()*(leftBorder+rightBorder+1))
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while ((int)forms.size() < leftBorder+rightBorder+1)
forms.emplace_back("");
for (int i = 0; i < nbStackElements; i++)
for (auto & column : columns)
if (config.hasStack(i))
context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
else
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (auto index : focusedBufferIndexes)
{
util::utf8string letters;
if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
letters = util::splitAsUtf8(forms[leftBorder+index]);
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}
for (auto index : focusedStackIndexes)
{
util::utf8string letters;
if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}