Skip to content
Snippets Groups Projects
SplitTransLSTM.cpp 1.03 KiB
#include "SplitTransLSTM.hpp"
#include "Transition.hpp"

SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options)
  : maxNbTrans(maxNbTrans)
{
  // Build the recurrent layer (embeddingsSize -> outEmbeddingsSize, configured by `options`),
  // then register it on this module under the name "lstm" and keep the returned handle.
  auto recurrentLayer = LSTM(embeddingsSize, outEmbeddingsSize, options);
  lstm = register_module("lstm", recurrentLayer);
}

torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
{
  // Take this submodule's window along dimension 1: getInputSize() entries starting at
  // firstInputIndex (assumed to be set by the owning network — not visible in this file).
  auto ownSlice = input.narrow(1, firstInputIndex, getInputSize());
  return lstm(ownSlice);
}

std::size_t SplitTransLSTMImpl::getOutputSize()
{
  // The output width is computed by the wrapped LSTM, given the number of
  // elements it consumes (maxNbTrans — presumably the sequence length; confirm in LSTM.hpp).
  const auto nbElements = maxNbTrans;
  return lstm->getOutputSize(nbElements);
}

std::size_t SplitTransLSTMImpl::getInputSize()
{
  // One input column per transition slot; addToContext always writes exactly
  // maxNbTrans values, padding unused slots with the null value.
  return static_cast<std::size_t>(maxNbTrans);
}

void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  // Appends exactly maxNbTrans dictionary indices to the last context row:
  // the names of the first appliable split transitions, then the null value
  // for every remaining slot.
  auto & splitTransitions = config.getAppliableSplitTransitions();
  const int nbAppliable = static_cast<int>(splitTransitions.size());

  for (int slot = 0; slot < maxNbTrans; ++slot)
  {
    if (slot >= nbAppliable)
    {
      // No transition for this slot: pad so the row width stays fixed.
      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
      continue;
    }

    context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[slot]->getName()));
  }
}