Skip to content
Snippets Groups Projects
AppliableTransModule.cpp 922 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "AppliableTransModule.hpp"
    
    // Constructs the module: stores the number of transitions it handles
    // (nbTrans) and registers the given name through setName().
    AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans)
      : nbTrans(nbTrans)
    {
      setName(name);
    }
    
    // Extracts this module's slice of the input and returns it as floats.
    // The slice is taken along dimension 1, starts at firstInputIndex and is
    // getInputSize() elements long; the result is converted to torch::kFloat.
    torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
    {
      torch::Tensor moduleSlice = input.narrow(1, firstInputIndex, getInputSize());
      return moduleSlice.to(torch::kFloat);
    }
    
    // Size of the output produced by this module: one value per transition
    // (nbTrans).
    std::size_t AppliableTransModuleImpl::getOutputSize()
    {
      return static_cast<std::size_t>(nbTrans);
    }
    
    // Number of input columns consumed by this module: one per transition
    // (nbTrans). This is also the length of the slice taken in forward().
    std::size_t AppliableTransModuleImpl::getInputSize()
    {
      return static_cast<std::size_t>(nbTrans);
    }
    
    // Appends exactly nbTrans values to every element of `context`.
    // The values are the first entries of config.getAppliableTransitions();
    // when that list holds fewer than nbTrans entries, the remaining slots
    // are filled with 0 so each element always grows by the same amount.
    void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
    {
      auto & transitions = config.getAppliableTransitions();
      const int nbKnown = static_cast<int>(transitions.size());

      for (auto & row : context)
      {
        int slot = 0;

        // Copy the values that are actually available.
        for (; slot < nbTrans && slot < nbKnown; ++slot)
          row.emplace_back(transitions[slot]);

        // Pad the remaining slots with zeros.
        for (; slot < nbTrans; ++slot)
          row.emplace_back(0);
      }
    }
    
    
    void AppliableTransModuleImpl::registerEmbeddings()