#include "LSTMNetwork.hpp" // assumed header name; expected to declare LSTMNetworkImpl, its submodules and maxNbEmbeddings

LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
{
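  // LSTMOptions is a 5-field tuple; from the arguments passed here it reads as
  // {?, bidirectional, numLayers, dropout, ?}. Only the last field is toggled
  // below for the "All" variant, which plausibly exposes the full output
  // sequence rather than just the final state (inferred from usage, not
  // confirmed by this file).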
  LSTMImpl::LSTMOptions lstmOptions{true, bilstm, numLayers, lstmDropout, false};
  auto lstmOptionsAll = lstmOptions;
  std::get<4>(lstmOptionsAll) = true;
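
  // Bookkeeping for the flat input rows produced by extractContext():
  // currentInputSize is the next free index in a row (index 0 holds the dict
  // index of the parser state), and currentOutputSize accumulates the width of
  // the concatenated feature vector fed to linear1 (it starts at embeddingsSize
  // because the state embedding itself is part of that vector).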
  int currentOutputSize = embeddingsSize;
  int currentInputSize = 1;
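
  // Encodes the parser's buffer/stack context columns.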
  contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
  contextLSTM->setFirstInputIndex(currentInputSize);
  currentOutputSize += contextLSTM->getOutputSize();
  currentInputSize += contextLSTM->getInputSize();
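
  // The raw-input LSTM is optional: it is only built when both windows around
  // the current position are non-negative.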
  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
  {
    hasRawInputLSTM = true;
    rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
    rawInputLSTM->setFirstInputIndex(currentInputSize);
    currentOutputSize += rawInputLSTM->getOutputSize();
    currentInputSize += rawInputLSTM->getInputSize();
  }
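
  // Encodes the applicable split transitions (at most
  // Config::maxNbAppliableSplitTransitions of them).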
  splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
  splitTransLSTM->setFirstInputIndex(currentInputSize);
  currentOutputSize += splitTransLSTM->getOutputSize();
  currentInputSize += splitTransLSTM->getInputSize();
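
  // One LSTM per focused column, each reading up to maxNbElements[i] elements
  // of that column at the focused buffer/stack positions.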
  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
    focusedLstms.back()->setFirstInputIndex(currentInputSize);
    currentOutputSize += focusedLstms.back()->getOutputSize();
    currentInputSize += focusedLstms.back()->getInputSize();
  }
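
  // Single embedding table shared by all submodules (maxNbEmbeddings is a
  // class constant, presumably defined in the header).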
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
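
  // Two-layer classifier over the concatenated features, producing nbOutputs scores.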
  linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
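  // Accept a single example by promoting it to a batch of size 1.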
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  auto embeddings = embeddingsDropout(wordEmbeddings(input));
  // Column 0 of each row is the parser-state index; its embedding is fed to
  // the classifier directly, followed by every submodule's output.
  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
  outputs.emplace_back(contextLSTM(embeddings));
  if (hasRawInputLSTM)
    outputs.emplace_back(rawInputLSTM(embeddings));
  outputs.emplace_back(splitTransLSTM(embeddings));
  for (auto & lstm : focusedLstms)
    outputs.emplace_back(lstm(embeddings));
  // The concatenated width matches the currentOutputSize computed in the constructor.
  auto totalInput = torch::cat(outputs, 1);
  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
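  // Builds one or more flat rows of dict indices; each submodule appends the
  // indices it will later read (starting at its firstInputIndex). More than
  // one row can be produced during training, when modules emit variants.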
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

  std::vector<std::vector<long>> context;
  context.emplace_back();
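  // Index 0: dict index of the current parser state (consumed by forward()).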
  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));

  contextLSTM->addToContext(context, dict, config);

  if (hasRawInputLSTM)
    rawInputLSTM->addToContext(context, dict, config);

  splitTransLSTM->addToContext(context, dict, config);

  for (auto & lstm : focusedLstms)
    lstm->addToContext(context, dict, config);
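
  // Multiple context variants are only meaningful during training; at decode
  // time exactly one row is expected.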
  if (!is_training() && context.size() > 1)
    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

  return context;
}