Newer
Older
#include "SplitTransModule.hpp"
#include "Transition.hpp"
SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
{
myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
/// Runs the sub-module on this module's own slice of the input.
///
/// The slice is the getInputSize() consecutive entries of dimension 1 that
/// start at firstInputIndex; only that slice is passed to myModule.
torch::Tensor SplitTransModule::forward(torch::Tensor input)
{
  const auto sliceStart = firstInputIndex;
  const auto sliceLength = getInputSize();
  return myModule->forward(input.narrow(1, sliceStart, sliceLength));
}
/// Returns the output size reported by the sub-module when it is given
/// maxNbTrans input slots.
std::size_t SplitTransModule::getOutputSize()
{
  const auto nbSlots = maxNbTrans;
  return myModule->getOutputSize(nbSlots);
}
/// Returns how many input entries this module consumes.
///
/// There is one entry per split-transition slot, which matches the
/// maxNbTrans values appended to each context element by addToContext.
std::size_t SplitTransModule::getInputSize()
{
  return static_cast<std::size_t>(maxNbTrans);
}
/// Appends exactly maxNbTrans dictionary indices to every context element.
///
/// Slot i holds the index of the name of the i-th applicable split transition
/// of `config`. When fewer than maxNbTrans transitions are applicable, the
/// remaining slots hold the index of Dict::nullValueStr. Transitions beyond
/// maxNbTrans are ignored.
///
/// @param context Context rows; each one is extended in place.
/// @param dict    Dictionary queried through getIndexOrInsert (presumably it
///                inserts unseen names — TODO confirm against Dict).
/// @param config  Configuration providing the applicable split transitions.
/// The trailing bool parameter is unused here.
void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
  const auto & splitTransitions = config.getAppliableSplitTransitions();
  const int nbAppliable = static_cast<int>(splitTransitions.size());

  // NOTE(review): the lookups are deliberately repeated for every context
  // element. Computing them once would change how many times getIndexOrInsert
  // is called, which matters if Dict tracks call counts.
  for (auto & contextElement : context)
  {
    for (int slot = 0; slot < maxNbTrans; ++slot)
    {
      if (slot >= nbAppliable)
      {
        // Pad the unused slots with the null value.
        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
        continue;
      }
      contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[slot]->getName()));
    }
  }
}