Newer
Older
// Constructs the network and registers its sub-modules: a word embedding table,
// three dropout layers, a CNN over the context, an optional CNN over a raw-input
// window, one CNN per focused column, and two linear layers.
//
// nbOutputs             output size of the last linear layer (linear2).
// unknownValueThreshold stored as a member; not used further in this constructor.
// bufferContext         forwarded to setBufferContext().
// stackContext          forwarded to setStackContext().
// columns               its size scales the input size given to contextCNN.
// focusedBufferIndexes  forwarded to setBufferFocused(); its size, added to that of
//                       focusedStackIndexes, multiplies each focused CNN's output size.
// focusedStackIndexes   forwarded to setStackFocused(); see focusedBufferIndexes.
// focusedColumns        stored as a member; one CNN named "CNN_<column>" is registered per entry.
// maxNbElements         stored as a member; not used further in this constructor.
// leftWindowRawInput    stored as a member; a negative value disables the raw-input CNN.
// rightWindowRawInput   stored as a member; a negative value disables the raw-input CNN.
//
// NOTE(review): the opening brace of the body and the definitions of embeddingsSize
// and maxNbEmbeddings are not visible in this view — confirm against the repository.
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
// Layer sizes fixed at compile time: width of the hidden linear layer, and the
// filter counts passed to the context CNN and to the focused/raw-input CNNs.
constexpr int hiddenSize = 1024;
constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64;
setBufferContext(bufferContext);
setStackContext(stackContext);
setBufferFocused(focusedBufferIndexes);
setStackFocused(focusedStackIndexes);
// Width of the raw-input window: left side + right side + 1.
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
// A negative window bound turns the raw input off: the size is reset to 0 and
// rawInputCNN is not registered.
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
rawInputSize = 0;
else
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
// Contributes nothing to the first linear layer's input size when the raw input is off.
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
// NOTE(review): the next two lines are not C++; they look like commit-author
// residue from a web view — confirm against the repository.
Franck Dary
committed
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
// The context CNN's input size is one embedding per column.
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
// Accumulates the output sizes of all CNNs; used below as linear1's input size.
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns)
{
std::vector<int> windows{2,3,4};
cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
// Each column's CNN output is counted once per focused buffer/stack index.
totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
}
linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
// Forward pass. Embeds the input indexes, slices the embedded sequence along
// dimension 1 into a raw-input window, a context part and a focused-elements
// part, runs CNNs over them, concatenates the CNN outputs along dimension 1,
// then applies dropout, linear1, relu, dropout and linear2.
//
// input   tensor of indexes given to wordEmbeddings; a 1-D input gets a leading
//         dimension added first.
// returns the output of linear2.
//
// NOTE(review): the declarations of cnnOutputs, curIndex and nbElements, the
// braces of the inner loop and the closing brace of this function are not
// visible in this view — confirm against the repository.
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
// Insert a size-1 dimension at index 0 when the input has a single dimension.
if (input.dim() == 1)
input = input.unsqueeze(0);
auto embeddings = embeddingsDropout(wordEmbeddings(input));
// The context starts right after the rawInputSize leading positions and spans
// getContextSize() positions along dimension 1.
auto context = embeddings.narrow(1, rawInputSize, getContextSize());
// Regroup so that dimension 1 is divided by columns.size() and the last
// dimension holds columns.size() embeddings side by side.
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
// Everything after the raw-input and context positions.
// NOTE(review): context.size(1) here is the size AFTER the view above (already
// divided by columns.size()), not getContextSize() — confirm this offset is
// intended when columns.size() > 1.
auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
if (rawInputSize != 0)
{
// The raw-input window occupies the first left+right+1 positions of dimension 1.
auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
}
// One pass per focused column and per focused buffer/stack position; each step
// takes the next nbElements positions of elementsEmbeddings.
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
curIndex += nbElements;
cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
// Concatenate every CNN output along dimension 1 before the linear layers.
auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
Franck Dary
committed
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<std::vector<long>> context;
context.emplace_back();
if (rawInputSize > 0)
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (auto index : contextIndexes)
for (auto & col : columns)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
Franck Dary
committed
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
std::vector<long> focusedIndexes = extractFocusedIndexes(config);
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
auto & col = focusedColumns[colIndex];
for (auto index : focusedIndexes)
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (col == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "FEATS")
{
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(col, index));
}
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));