#include "SplitTransModule.hpp"
#include "Transition.hpp"

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

/// @brief Build a module that encodes the currently appliable split transitions
///        through an LSTM sub-module registered under the name "lstm".
/// @param maxNbTrans        Number of transition slots fed to the network; must be >= 0.
/// @param embeddingsSize    First size argument forwarded to LSTM (presumably its input size — confirm in LSTM).
/// @param outEmbeddingsSize Second size argument forwarded to LSTM (presumably its output size — confirm in LSTM).
/// @param options           Options forwarded unchanged to the LSTM sub-module.
/// @throws std::invalid_argument if maxNbTrans is negative.
SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
{
  // getInputSize() converts maxNbTrans to std::size_t; a negative value would wrap to a
  // huge slice length and only blow up later inside forward(). Reject it up front.
  if (maxNbTrans < 0)
    throw std::invalid_argument("SplitTransModule: maxNbTrans must be non-negative, got " + std::to_string(maxNbTrans));

  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}

/// @brief Run the wrapped LSTM on this module's slice of the input.
///
/// Selects getInputSize() consecutive entries along dimension 1 of @p input, starting at
/// firstInputIndex (declared outside this file — presumably the offset assigned to this
/// module when the input columns are laid out; confirm), and forwards that slice.
torch::Tensor SplitTransModule::forward(torch::Tensor input)
{
  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
}

/// @brief Output width of this module, as reported by the wrapped LSTM for maxNbTrans inputs.
std::size_t SplitTransModule::getOutputSize()
{
  return myModule->getOutputSize(maxNbTrans);
}

/// @brief Number of input columns consumed by forward(): one per transition slot.
std::size_t SplitTransModule::getInputSize()
{
  // Safe: the constructor guarantees maxNbTrans >= 0.
  return static_cast<std::size_t>(maxNbTrans);
}

/// @brief Append exactly maxNbTrans dictionary indices to every context element.
///
/// Slot i holds the index of the name of the i-th appliable split transition of @p config;
/// slots beyond the number of appliable transitions are padded with the index of
/// Dict::nullValueStr, so every element grows by the same fixed amount (presumably matching
/// the getInputSize() columns that forward() later slices).
///
/// @param context Context rows to extend in place.
/// @param dict    Dictionary used to map names to indices (entries may be inserted).
/// @param config  Configuration queried for its appliable split transitions.
void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool /*unused*/) const
{
  const auto & splitTransitions = config.getAppliableSplitTransitions();
  const int nbAppliable = static_cast<int>(splitTransitions.size());

  for (auto & contextElement : context)
    for (int i = 0; i < maxNbTrans; i++)
    {
      // The lookups are intentionally repeated for every context element rather than hoisted:
      // getIndexOrInsert mutates the dictionary (and may keep per-call bookkeeping — confirm in
      // Dict), so the number and order of calls are preserved exactly.
      if (i < nbAppliable)
        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
      else
        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
    }
}