#include "AppliableTransModule.hpp"

/// @brief Builds the module.
/// @param name    Name handed to setName() for this module.
/// @param nbTrans Stored in the nbTrans member; it is the value later reported
///                by getInputSize() and getOutputSize().
AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans)
  : nbTrans(nbTrans)
{
  this->setName(name);
}

/// @brief Extracts this module's slice of the input and converts it to float.
/// @param input Tensor that is sliced along dimension 1.
///              NOTE(review): presumably dimension 0 is the batch dimension - confirm against callers.
/// @return The getInputSize() entries of dimension 1 starting at firstInputIndex,
///         converted to torch::kFloat.
torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
{
  const auto sliceLength = getInputSize();
  auto slice = input.narrow(1, firstInputIndex, sliceLength);

  return slice.to(torch::kFloat);
}

/// @brief Size of the output of this module.
/// @return The value of the nbTrans member, converted to std::size_t.
std::size_t AppliableTransModuleImpl::getOutputSize()
{
  return static_cast<std::size_t>(nbTrans);
}

/// @brief Number of input values this module consumes.
/// @return The value of the nbTrans member, converted to std::size_t.
std::size_t AppliableTransModuleImpl::getInputSize()
{
  return static_cast<std::size_t>(nbTrans);
}

/// @brief Appends exactly nbTrans values to every element of the context.
///
/// The values are the leading entries of config.getAppliableTransitions();
/// when that sequence holds fewer than nbTrans entries, the remaining
/// positions are filled with 0.
///
/// @param context Each inner vector is extended in place by nbTrans values.
/// @param config  Configuration queried for its appliable transitions.
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  const auto & transitions = config.getAppliableTransitions();
  const int nbAvailable = static_cast<int>(transitions.size());

  for (auto & element : context)
  {
    int index = 0;

    // Copy the values that are actually available, up to nbTrans of them.
    for (; index < nbTrans && index < nbAvailable; ++index)
      element.emplace_back(transitions[index]);

    // Pad with zeros so that exactly nbTrans values are always appended.
    for (; index < nbTrans; ++index)
      element.emplace_back(0);
  }
}

/// @brief No-op: this implementation performs no embedding registration.
void AppliableTransModuleImpl::registerEmbeddings()
{
}