
AppliableTransModule.cpp

#include "AppliableTransModule.hpp"

// Module exposing, for each configuration, one value per transition telling
// whether that transition is currently appliable.
AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans) : nbTrans(nbTrans)
{
  setName(name);
}

// Slice the nbTrans appliability values out of each row of the batched
// context, starting at firstInputIndex, and cast them to float for the
// downstream network.
torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
{
  return input.narrow(1, firstInputIndex, getInputSize()).to(torch::kFloat);
}

std::size_t AppliableTransModuleImpl::getOutputSize()
{
  return nbTrans;
}

std::size_t AppliableTransModuleImpl::getInputSize()
{
  return nbTrans;
}

// Append one value per transition to every context element: the value
// reported by config.getAppliableTransitions(), or 0 when the transition
// index is beyond the values available in the Config.
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & appliableTrans = config.getAppliableTransitions();
  for (auto & contextElement : context)
    for (int i = 0; i < nbTrans; i++)
      if (i < (int)appliableTrans.size())
        contextElement.emplace_back(appliableTrans[i]);
      else
        contextElement.emplace_back(0);
}

// This module holds no embeddings, so there is nothing to register.
void AppliableTransModuleImpl::registerEmbeddings()
{
}
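
For illustration, a minimal standalone sketch of the same narrow-and-cast that forward() performs, using only libtorch. The batch size, context width, firstInputIndex value and the assumption that the appliable-transition values are 0/1 flags are illustrative choices for this sketch, not values taken from the project.

// Standalone sketch (assumed shapes and offsets) of the slicing done in
// AppliableTransModuleImpl::forward.
#include <torch/torch.h>
#include <iostream>

int main()
{
  const int batchSize = 2;
  const int contextWidth = 10;   // total number of long values per context row (assumed)
  const int nbTrans = 4;         // number of transitions (assumed)
  const int firstInputIndex = 6; // offset of this module's values in each row (assumed)

  // A batched context: each row is a vector of long features, the last
  // nbTrans of which are the per-transition values appended by addToContext
  // (here assumed to be 0/1 appliability flags).
  auto context = torch::zeros({batchSize, contextWidth}, torch::kLong);
  context[0][firstInputIndex + 1] = 1; // transition 1 appliable in row 0
  context[1][firstInputIndex + 3] = 1; // transition 3 appliable in row 1

  // Same operation as forward(): take the slice of width nbTrans starting
  // at firstInputIndex along dimension 1, then cast to float.
  auto out = context.narrow(1, firstInputIndex, nbTrans).to(torch::kFloat);
  std::cout << out << std::endl; // shape [2, 4], one flag per transition
}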