// LSTMNetwork.cpp — LSTM-based scoring network: per-input-group LSTMs
// (raw characters, word context, focused columns) feeding a two-layer MLP.
#include "LSTMNetwork.hpp"

// Constructs the network: an optional bidirectional LSTM over a raw character
// window, a bidirectional LSTM over the word-level context, and one LSTM per
// focused column, all concatenated into a two-layer MLP producing nbOutputs
// scores.
//   bufferContext/stackContext        : positions fed to the context LSTM
//   focusedBufferIndexes/stackIndexes : positions whose focused columns are
//                                       expanded into element sequences
//   focusedColumns + maxNbElements    : which columns, and how many elements
//                                       (e.g. letters) are kept per value
//   leftWindowRawInput/rightWindow... : raw character window; a negative value
//                                       on either side disables it entirely
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
  // Fixed architecture hyperparameters.
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 1024;
  constexpr int contextLSTMSize = 512;
  constexpr int focusedLSTMSize = 64;
  constexpr int rawInputLSTMSize = 16;

  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns(columns);
  setBufferFocused(focusedBufferIndexes);
  setStackFocused(focusedStackIndexes);

  // Raw window = left chars + current char + right chars; a negative window
  // on either side means "no raw input" (size 0, LSTM never registered).
  rawInputSize =  leftWindowRawInput + rightWindowRawInput + 1;
  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
    rawInputSize = 0;
  else
    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true)));

  // forward() flattens the entire raw-input LSTM output (every timestep,
  // both directions), hence timesteps * hidden * 2.
  int rawInputLSTMOutputSize = 0;
  if (rawInputSize > 0)
    rawInputLSTMOutputSize = (rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1));

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
  // Context LSTM sees one timestep per word, each the concatenation of that
  // word's per-column embeddings.
  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true)));

  // Bidirectional factor is 4, not 2: forward() keeps the first AND last
  // timesteps (2 timesteps * 2 directions); unidirectional keeps only the
  // last timestep (factor 1).
  int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize;

  for (auto & col : focusedColumns)
  {
    // One LSTM per focused column, applied to every focused position, so the
    // output size is multiplied by the number of (buffer + stack) positions.
    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true))));
    totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size());
  }

  // Classifier head over the concatenation of all LSTM outputs.
  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

// Forward pass. Each row of `input` is a sequence of dictionary indexes laid
// out as: [ raw characters (rawInputSize) | context words (getContextSize())
// | focused elements, grouped by column then by position ] — the same layout
// produced by extractContext(). Returns unnormalized scores
// of shape (batch, nbOutputs).
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto embeddings = embeddingsDropout(wordEmbeddings(input));

  // Context slice starts right after the raw-input characters.
  auto context = embeddings.narrow(1, rawInputSize, getContextSize());
  // Regroup so each timestep is one word = concatenation of its per-column
  // embeddings (matches the contextLSTM input size set in the constructor).
  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});

  // Everything after the context belongs to the focused columns.
  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));

  std::vector<torch::Tensor> lstmOutputs;

  if (rawInputSize != 0)
  {
    auto rawLetters = embeddings.narrow(1, 0, rawInputSize);
    auto lstmOut = rawInputLSTM(rawLetters).output;
    // Keep the full output (all timesteps, both directions), flattened.
    lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1}));
  }

  // Walk the focused slice: for column i, each focused position owns
  // maxNbElements[i] consecutive embeddings.
  auto curIndex = 0;
  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    long nbElements = maxNbElements[i];
    for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
    {
      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements);
      curIndex += nbElements;
      auto lstmOut = lstms[i](lstmInput).output;

      // Bidirectional: concatenate first and last timesteps (each already
      // holds both directions); unidirectional: last timestep only.
      if (lstms[i]->options.bidirectional())
        lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
      else
        lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));
    }
  }

  // Same first/last extraction for the context LSTM.
  auto lstmOut = contextLSTM(context).output;
  if (contextLSTM->options.bidirectional())
    lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
   else
     lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));

  auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));

  // Two-layer classifier head.
  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}

// Converts the current parser configuration into dictionary indexes, in the
// exact layout forward() expects: raw characters, then context words (one
// index per column per position), then focused elements (grouped per column,
// then per position, maxNbElements values each).
// During training, each rare FORM/LEMMA value (<= unknownValueThreshold
// occurrences) additionally spawns a variant of the context where that value
// is replaced by the unknown token, so the model learns unknown handling.
// Returns one vector per variant; exactly one outside of training.
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  // Warn once the dictionary reaches the embedding matrix capacity.
  // (Message fixed to ">=" to match the condition.)
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

  std::vector<long> contextIndexes = extractContextIndexes(config);
  std::vector<std::vector<long>> context;
  context.emplace_back();

  if (rawInputSize > 0)
  {
    // Characters left of the current position, oldest first; out-of-range
    // positions are padded with the null token.
    for (int i = 0; i < leftWindowRawInput; i++)
      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));

    // Current character plus the right window (hence <=).
    for (int i = 0; i <= rightWindowRawInput; i++)
      if (config.hasCharacter(config.getCharacterIndex()+i))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
  }

  // Context words: one dictionary index per (position, column) pair; index -1
  // marks an absent position, padded with the null token in every variant.
  for (auto index : contextIndexes)
    for (auto & col : columns)
      if (index == -1)
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
      {
        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));

        for (auto & contextElement : context)
          contextElement.push_back(dictIndex);

        // Rare FORM/LEMMA values spawn a training variant with the value just
        // inserted replaced by the unknown token.
        if (is_training())
          if (col == "FORM" || col == "LEMMA")
            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
            {
              // Copy into a local first: emplace_back(context.back()) would
              // hand the vector a reference into its own storage, which can
              // dangle if growth triggers a reallocation.
              std::vector<long> variant(context.back());
              variant.back() = dict.getIndexOrInsert(Dict::unknownValueStr);
              context.emplace_back(std::move(variant));
            }
      }

  std::vector<long> focusedIndexes = extractFocusedIndexes(config);

  // Focused columns: each (column, position) slot contributes exactly
  // maxNbElements[colIndex] indexes, appended to every variant.
  for (auto & contextElement : context)
    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
    {
      auto & col = focusedColumns[colIndex];

      for (auto index : focusedIndexes)
      {
        // Absent position: pad the whole slot with null tokens.
        if (index == -1)
        {
          for (int i = 0; i < maxNbElements[colIndex]; i++)
            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
          continue;
        }

        // Decompose the feature value into exactly maxNbElements[colIndex]
        // string elements; the decomposition is column-dependent.
        std::vector<std::string> elements;
        if (col == "FORM")
        {
          // Letters of the form (UTF-8 aware), null-padded on the right.
          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());

          for (int i = 0; i < maxNbElements[colIndex]; i++)
            if (i < (int)asUtf8.size())
              elements.emplace_back(fmt::format("{}", asUtf8[i]));
            else
              elements.emplace_back(Dict::nullValueStr);
        }
        else if (col == "FEATS")
        {
          // '|'-separated morphological features, null-padded on the right.
          auto splited = util::split(config.getAsFeature(col, index).get(), '|');

          for (int i = 0; i < maxNbElements[colIndex]; i++)
            if (i < (int)splited.size())
              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
            else
              elements.emplace_back(Dict::nullValueStr);
        }
        else if (col == "ID")
        {
          // Predicted node type; relies on maxNbElements[colIndex] == 1 and
          // on one of the three predicates holding (checked below).
          if (config.isTokenPredicted(index))
            elements.emplace_back("ID(TOKEN)");
          else if (config.isMultiwordPredicted(index))
            elements.emplace_back("ID(MULTIWORD)");
          else if (config.isEmptyNodePredicted(index))
            elements.emplace_back("ID(EMPTYNODE)");
        }
        else
        {
          elements.emplace_back(config.getAsFeature(col, index));
        }

        if ((int)elements.size() != maxNbElements[colIndex])
          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));

        for (auto & element : elements)
          contextElement.emplace_back(dict.getIndexOrInsert(element));
      }
    }

  // Variants are a training-only mechanism; anything else is a logic error.
  if (!is_training() && context.size() > 1)
    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));

  return context;
}