Newer
Older
#include "AppliableTransModule.hpp"
/// @brief Creates a module exposing a fixed number of applicable-transition slots.
/// @param name     Name handed to setName() for this module.
/// @param nbTrans  Number of slots this module reads/writes (also its input and
///                 output size). NOTE(review): a negative value is not rejected here.
AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans)
  : nbTrans(nbTrans)
{
  setName(name);
}
/// @brief Extracts this module's columns from the input and casts them to float.
/// @param input  Tensor whose dimension 1 holds the concatenated context values;
///               presumably dimension 0 is the batch — TODO confirm.
/// @return The slice [firstInputIndex, firstInputIndex + getInputSize()) along
///         dimension 1, converted to torch::kFloat.
torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
{
  const auto width = getInputSize();
  auto ownColumns = input.narrow(1, firstInputIndex, width);
  // The stored values are fed to the network as floats, not as indices.
  return ownColumns.to(torch::kFloat);
}
/// @brief Number of values this module produces: one per transition slot.
/// @return nbTrans, converted to std::size_t.
std::size_t AppliableTransModuleImpl::getOutputSize()
{
  return static_cast<std::size_t>(nbTrans);
}
/// @brief Number of context columns this module consumes (see addToContext).
/// @return nbTrans, converted to std::size_t.
std::size_t AppliableTransModuleImpl::getInputSize()
{
  return static_cast<std::size_t>(nbTrans);
}
/// @brief Appends this module's features to every row of the context.
///
/// Each row receives exactly nbTrans values. Slot i holds the i-th entry of
/// config.getAppliableTransitions() when that entry exists, and 0 otherwise
/// (right-padding). What the entries encode (flags, ids, ...) is defined by
/// Config; presumably one value per transition — verify against Config.
/// @param context  Rows to extend in place; every row gets the same values.
/// @param config   Source of the applicable-transition values.
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  const auto & appliableTrans = config.getAppliableTransitions();
  const int available = static_cast<int>(appliableTrans.size());

  // The padded values are identical for every row, so compute them once.
  std::vector<long> paddedRow;
  for (int slot = 0; slot < nbTrans; ++slot)
  {
    long value = 0;
    if (slot < available)
      value = static_cast<long>(appliableTrans[slot]);
    paddedRow.push_back(value);
  }

  for (auto & row : context)
    row.insert(row.end(), paddedRow.begin(), paddedRow.end());
}
Franck Dary
committed
void AppliableTransModuleImpl::registerEmbeddings()