#include "SplitTransLSTM.hpp"
#include "Transition.hpp"

/// @brief Builds the module and registers its inner LSTM under the name "lstm".
/// @param maxNbTrans Number of split-transition slots this module reads/writes
///        per context row (also used as the LSTM sequence length for getOutputSize).
/// @param embeddingsSize Input size passed to the inner LSTM.
/// @param outEmbeddingsSize Output size passed to the inner LSTM.
/// @param options Options forwarded unchanged to the inner LSTM.
SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
{
  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}

/// @brief Runs the inner LSTM on this module's slice of the input.
///
/// Takes getInputSize() (== maxNbTrans) consecutive entries along dimension 1,
/// starting at firstInputIndex, and feeds them to the LSTM.
/// NOTE(review): firstInputIndex is not set in this file; presumably it is
/// assigned by the owning model when columns are laid out — confirm in the header.
torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
{
  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
}

/// @brief Size of the tensor produced by forward(), as reported by the inner
///        LSTM for a sequence of maxNbTrans elements.
std::size_t SplitTransLSTMImpl::getOutputSize()
{
  return lstm->getOutputSize(maxNbTrans);
}

/// @brief Number of context columns consumed by forward(); always maxNbTrans,
///        matching the number of values appended by addToContext().
std::size_t SplitTransLSTMImpl::getInputSize()
{
  return maxNbTrans;
}

/// @brief Appends exactly maxNbTrans dictionary indices to the last context row.
///
/// Slot i holds the index of the name of the i-th appliable split transition of
/// @p config. Slots beyond the number of available transitions are padded with
/// the index of Dict::nullValueStr, so the row width always matches getInputSize().
/// Indices are obtained through dict.getIndexOrInsert, so @p dict may be modified.
///
/// @pre @p context is non-empty (its last row is extended via back()).
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  const auto & splitTransitions = config.getAppliableSplitTransitions();
  auto & row = context.back();

  for (int i = 0; i < maxNbTrans; i++)
  {
    // i is non-negative here, so widening it avoids narrowing size() to int.
    if (static_cast<std::size_t>(i) < splitTransitions.size())
      row.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
    else
      row.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  }
}