// SplitTransLSTM.cpp (authored by Franck Dary)
#include "SplitTransLSTM.hpp"
#include "Transition.hpp"
// Constructs the module.
// - maxNbTrans is stored in the member of the same name.
// - embeddingsSize, outEmbeddingsSize and options are forwarded unchanged to
//   the LSTM constructor; the resulting LSTM is registered as the submodule
//   named "lstm" and kept in the lstm member.
SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
{
  auto wrappedLstm = LSTM(embeddingsSize, outEmbeddingsSize, options);

  lstm = register_module("lstm", wrappedLstm);
}
// Applies the wrapped LSTM to this module's slice of the input.
// The slice is taken along dimension 1: getInputSize() entries starting at
// firstInputIndex. Returns whatever the wrapped LSTM produces for that slice.
torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
{
  const auto sliceLength = getInputSize();
  auto ownSlice = input.narrow(1, firstInputIndex, sliceLength);

  return lstm(ownSlice);
}
// Returns the output size reported by the wrapped LSTM when queried with
// maxNbTrans.
std::size_t SplitTransLSTMImpl::getOutputSize()
{
  const std::size_t outputSize = lstm->getOutputSize(maxNbTrans);

  return outputSize;
}
// Returns the number of input entries this module consumes, which is
// maxNbTrans (the same count of values addToContext appends).
std::size_t SplitTransLSTMImpl::getInputSize()
{
  return static_cast<std::size_t>(maxNbTrans);
}
// Appends exactly maxNbTrans dictionary indexes to the last row of context.
// Slot i receives the index of the name of the i-th appliable split
// transition of config when one exists; the remaining slots receive the index
// of Dict::nullValueStr. Indexes come from dict.getIndexOrInsert, so dict may
// gain new entries as a side effect.
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  auto & appliable = config.getAppliableSplitTransitions();
  const int nbAppliable = static_cast<int>(appliable.size());

  for (int slot = 0; slot < maxNbTrans; ++slot)
  {
    const bool hasTransition = slot < nbAppliable;
    const auto index = hasTransition
      ? dict.getIndexOrInsert(appliable[slot]->getName())
      : dict.getIndexOrInsert(Dict::nullValueStr);

    // context.back() is evaluated per iteration, so nothing is touched when
    // the loop body never runs.
    context.back().emplace_back(index);
  }
}