Skip to content
Snippets Groups Projects
ContextLSTM.cpp 2.16 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "ContextLSTM.hpp"

    #include <cstdint>
    #include <utility>
    
    // Builds the module: stores which columns and which buffer/stack positions
    // are read, stores the rare-value threshold used by addToContext, and
    // registers one LSTM submodule named "lstm".
    // The LSTM input width is columns.size()*embeddingsSize; its output width
    // parameter is outEmbeddingsSize.
    ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold)
      : columns(columns),
        bufferContext(bufferContext),
        stackContext(stackContext),
        unknownValueThreshold(unknownValueThreshold)
    {
      // `columns` here is the constructor parameter (same contents as the member).
      const auto lstmInputSize = columns.size()*embeddingsSize;
      lstm = register_module("lstm", LSTM(lstmInputSize, outEmbeddingsSize, options));
    }
    
    // Size of this module's output, as reported by the LSTM for a sequence made
    // of every buffer element followed by every stack element.
    std::size_t ContextLSTMImpl::getOutputSize()
    {
      const auto nbElements = bufferContext.size()+stackContext.size();
      return lstm->getOutputSize(nbElements);
    }
    
    // Number of input positions this module consumes: one per (element, column)
    // pair, where the elements are the buffer positions plus the stack positions.
    std::size_t ContextLSTMImpl::getInputSize()
    {
      const std::size_t nbElements = stackContext.size() + bufferContext.size();
      return nbElements * columns.size();
    }
    
    // Appends this module's features to every row of `context`.
    //
    // Layout is element-major: for each context element (first the positions in
    // bufferContext, then those in stackContext), one dictionary index per
    // column of `columns`, in order. An element whose index is -1 (including any
    // stack position for which config.hasStack() is false) contributes
    // Dict::nullValueStr for every column.
    //
    // In training mode, when the column is listed in unknownValueColumns and
    // dict.getNbOccs() of the value is <= unknownValueThreshold, a copy of the
    // last row is appended to `context` with that value replaced by
    // Dict::unknownValueStr. The new row then receives all remaining features
    // like every other row.
    //
    // context : rows being built; each row receives the features in the same order.
    // dict    : maps feature strings to indexes (getIndexOrInsert presumably adds
    //           unseen strings — confirm in Dict).
    // config  : configuration the word/stack indexes and feature values are read from.
    void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
    {
      std::vector<long> contextIndexes;
      contextIndexes.reserve(bufferContext.size()+stackContext.size());

      for (int index : bufferContext)
        contextIndexes.emplace_back(config.getRelativeWordIndex(index));

      for (int index : stackContext)
      {
        if (config.hasStack(index))
          contextIndexes.emplace_back(config.getStack(index));
        else
          contextIndexes.emplace_back(-1);
      }

      for (auto index : contextIndexes)
        for (auto & col : columns)
        {
          if (index == -1)
          {
            // getIndexOrInsert is deliberately called once per row, as before:
            // the dict may track occurrence counts.
            for (auto & contextElement : context)
              contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
            continue;
          }

          int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));

          for (auto & contextElement : context)
            contextElement.push_back(dictIndex);

          // Augmentation duplicates the last row; with no rows there is nothing
          // to duplicate, and context.back() would be undefined behavior.
          if (!is_training() || context.empty())
            continue;

          for (auto & targetCol : unknownValueColumns)
            if (col == targetCol && dict.getNbOccs(dictIndex) <= unknownValueThreshold)
            {
              // Copy first, then append, instead of emplace_back(context.back()),
              // whose argument aliases an element of the vector being grown.
              std::vector<long> unknownRow = context.back();
              unknownRow.back() = dict.getIndexOrInsert(Dict::unknownValueStr);
              context.push_back(std::move(unknownRow));
            }
        }
    }
    
    // Runs the LSTM over this module's slice of `input`.
    //
    // The slice starts at firstInputIndex along dimension 1 and spans
    // getInputSize() positions. It is regrouped so that each LSTM time step
    // receives the concatenation of the columns.size() consecutive positions
    // belonging to one element:
    //   (batch, nbElements*nbColumns, dim) -> (batch, nbElements, nbColumns*dim)
    // NOTE(review): assumes `input` is 3-D (batch, positions, embedding dim) and
    // laid out element-major as addToContext writes it — confirm against caller.
    torch::Tensor ContextLSTMImpl::forward(torch::Tensor input)
    {
      auto context = input.narrow(1, firstInputIndex, getInputSize());

      const auto nbColumns = static_cast<std::int64_t>(columns.size());

      // reshape returns a view whenever view() would succeed, and only copies
      // when the sliced tensor's strides cannot be viewed directly.
      context = context.reshape({context.size(0), context.size(1)/nbColumns, nbColumns*context.size(2)});

      return lstm(context);
    }