Skip to content
Snippets Groups Projects
LSTMNetwork.cpp 4.21 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "LSTMNetwork.hpp"
    
    
    // Builds the network: one context LSTM, an optional raw-input LSTM, a
    // split-transition LSTM, one focused LSTM per entry of focusedColumns, a
    // shared embedding table and a two-layer feed-forward classifier.
    //
    // Each sub-module is told at which column of the input its own slice starts
    // (setFirstInputIndex), and the sum of all sub-module output sizes (plus one
    // embedding for input column 0) determines the input width of linear1.
    //
    // The raw-input LSTM is only created when both leftWindowRawInput and
    // rightWindowRawInput are non-negative.
    //
    // Throws (through util::myThrow) when maxNbElements has fewer entries than
    // focusedColumns, since maxNbElements[i] is read for every focused column.
    LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
    {
      // Guard against reading past the end of maxNbElements in the loop below.
      if (maxNbElements.size() < focusedColumns.size())
        util::myThrow(fmt::format("maxNbElements.size()={} < focusedColumns.size()={}", maxNbElements.size(), focusedColumns.size()));

      LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
      // Same options, with the fifth field switched on for the raw-input and
      // split-transition LSTMs.
      auto lstmOptionsAll = lstmOptions;
      std::get<4>(lstmOptionsAll) = true;

      // forward() feeds the embedding of input column 0 straight to the
      // classifier, hence the initial output size and the input offset of 1.
      int currentOutputSize = embeddingsSize;
      int currentInputSize = 1;

      contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
      contextLSTM->setFirstInputIndex(currentInputSize);
      currentOutputSize += contextLSTM->getOutputSize();
      currentInputSize += contextLSTM->getInputSize();

      if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
      {
        hasRawInputLSTM = true;
        rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
        rawInputLSTM->setFirstInputIndex(currentInputSize);
        currentOutputSize += rawInputLSTM->getOutputSize();
        currentInputSize += rawInputLSTM->getInputSize();
      }

      splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
      splitTransLSTM->setFirstInputIndex(currentInputSize);
      currentOutputSize += splitTransLSTM->getOutputSize();
      currentInputSize += splitTransLSTM->getInputSize();

      for (std::size_t i = 0; i < focusedColumns.size(); i++)
      {
        focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
        focusedLstms.back()->setFirstInputIndex(currentInputSize);
        currentOutputSize += focusedLstms.back()->getOutputSize();
        currentInputSize += focusedLstms.back()->getInputSize();
      }

      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
      embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
      hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));

      linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
      linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
    }
    
    // Runs the network on a tensor of dictionary indexes and returns the output
    // of linear2. A 1-dimensional input is first promoted to a batch of one.
    //
    // The embedded input is handed to every sub-module; their outputs, preceded
    // by the embedding of input column 0, are concatenated along dimension 1 and
    // passed through linear1 -> ReLU -> dropout -> linear2.
    torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
    {
      if (input.dim() == 1)
        input = input.unsqueeze(0);

      auto embedded = embeddingsDropout(wordEmbeddings(input));

      std::vector<torch::Tensor> features;

      // Embedding of the first input column, with the singleton dimension removed.
      features.emplace_back(embedded.narrow(1,0,1).squeeze(1));

      features.emplace_back(contextLSTM(embedded));

      if (hasRawInputLSTM)
        features.emplace_back(rawInputLSTM(embedded));

      features.emplace_back(splitTransLSTM(embedded));

      for (auto & focused : focusedLstms)
        features.emplace_back(focused(embedded));

      auto concatenated = torch::cat(features, 1);
      auto hidden = torch::relu(linear1(concatenated));

      return linear2(hiddenDropout(hidden));
    }
    
    // Builds the index rows to feed to forward() for the given configuration.
    //
    // The result starts as a single row whose first element is the dictionary
    // index of the current state; each sub-module then appends its own part via
    // addToContext, in the same order as the constructor assigned input offsets.
    //
    // Emits a warning when the dictionary has reached the size of the embedding
    // table, and throws (through util::myThrow) when more than one row is
    // produced while the module is not in training mode.
    std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
    {
      // The message uses ">=" so that it matches the condition actually tested.
      if (dict.size() >= maxNbEmbeddings)
        util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

      std::vector<std::vector<long>> context;
      context.emplace_back();

      context.back().emplace_back(dict.getIndexOrInsert(config.getState()));

      contextLSTM->addToContext(context, dict, config);
      if (hasRawInputLSTM)
        rawInputLSTM->addToContext(context, dict, config);
      splitTransLSTM->addToContext(context, dict, config);
      for (auto & lstm : focusedLstms)
        lstm->addToContext(context, dict, config);

      if (!is_training() && context.size() > 1)
        util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

      return context;
    }