#include "AppliableTransModule.hpp"

/// @brief Builds a module exposing `nbTrans` appliable-transition values.
/// @param name    Name forwarded to setName().
/// @param nbTrans Number of transition slots this module reads and writes.
AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans)
  : nbTrans(nbTrans)
{
  setName(name);
}

/// @brief Extracts this module's slice of the input and casts it to float.
///
/// Takes getInputSize() consecutive entries along dimension 1, starting at
/// `firstInputIndex`, and returns them as a torch::kFloat tensor.
/// NOTE(review): assumes `input` has at least 2 dimensions and that
/// `firstInputIndex` was set by the caller beforehand — confirm.
torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
{
  auto transSlice = input.narrow(1, firstInputIndex, getInputSize());
  return transSlice.to(torch::kFloat);
}

/// @brief Output width: one value per transition slot.
std::size_t AppliableTransModuleImpl::getOutputSize()
{
  return static_cast<std::size_t>(nbTrans);
}

/// @brief Input width consumed by forward(): one value per transition slot.
std::size_t AppliableTransModuleImpl::getInputSize()
{
  return static_cast<std::size_t>(nbTrans);
}

/// @brief Appends exactly `nbTrans` values to every element of `context`.
///
/// Slot t receives config.getAppliableTransitions()[t] when that index exists.
/// Otherwise it receives 0, so every context row grows by the same amount.
/// @param context Rows to extend in place.
/// @param config  Source of the appliable-transition values.
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & appliableTrans = config.getAppliableTransitions();
  // Nothing in the loop touches appliableTrans, so its size can be read once.
  int const nbKnown = static_cast<int>(appliableTrans.size());

  for (auto & contextElement : context)
    for (int t = 0; t < nbTrans; ++t)
    {
      if (t >= nbKnown)
      {
        // Pad missing transitions so the row length is always nbTrans.
        contextElement.emplace_back(0);
        continue;
      }
      contextElement.emplace_back(appliableTrans[t]);
    }
}

/// @brief This module owns no embeddings, so there is nothing to register.
void AppliableTransModuleImpl::registerEmbeddings()
{
}