// RTLSTMNetwork.cpp (author: Franck Dary)
#include "RTLSTMNetwork.hpp"

// Builds the recursive tree-LSTM network.
//
// nbOutputs       : output size of the final linear layer (linear2).
// leftBorder      : number of context-window slots before the current word.
// rightBorder     : number of context-window slots after the current word.
// nbStackElements : forwarded to setNbStackElements().
//
// Each window slot is embedded on the FORM and UPOS columns, the per-slot
// embeddings are encoded by a bidirectional LSTM, and those vectors are then
// composed bottom-up over the dependency tree by treeLSTM.
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int embeddingsSize = 30;
  constexpr int lstmOutputSize = 128;
  constexpr int treeEmbeddingsSize = 256;
  constexpr int hiddenSize = 500;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  // Columns must be set before the LSTM sizes below, which depend on columns.size().
  setColumns({"FORM", "UPOS"});

  // NOTE(review): dictionary indexes are assumed to stay below 50000 — confirm against Dict.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  // linear1 consumes one tree embedding per focused buffer/stack element, concatenated.
  linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
  // Input per slot: the embeddings of all columns, concatenated.
  vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
  // Input per step: a BiLSTM vector (2*lstmOutputSize) concatenated with either S or a child's tree embedding.
  treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
  // Learned vector fed as the first tree-LSTM step of every node.
  S = register_parameter("S", torch::randn(treeEmbeddingsSize));
  // Learned stand-in for missing/uncomputed trees.
  nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
}

// Scores one configuration.
//
// `input` is a flat (unbatched) vector made of four consecutive segments:
//   1. focused node positions (buffer then stack), -1 when absent;
//   2. tree computation order: window positions in bottom-up order, -1 padded;
//   3. child window positions, maxNbChilds per window slot, -1 padded;
//   4. dictionary indexes, columns.size() per window slot.
// Throws (via util::myThrow) if, after squeezing, the input is not 1-D.
// Returns the output of linear2 for that single example.
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
  input = input.squeeze();
  if (input.dim() != 1)
    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));

  const int64_t windowSize = leftBorder+rightBorder+1;
  const int64_t nbColumns = static_cast<int64_t>(columns.size());

  // Split the flat input into its four segments.
  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
  auto computeOrder = input.narrow(0, focusedIndexes.size(0), windowSize);
  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*windowSize);
  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), nbColumns*windowSize);

  // One embedding per (slot, column); concatenate the columns of each slot.
  auto baseEmbeddings = wordEmbeddings(wordIndexes);
  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0)/nbColumns, baseEmbeddings.size(1)*nbColumns}).unsqueeze(0);
  // squeeze(0) drops only the batch dimension; a plain squeeze() would also drop
  // the sequence dimension when the window holds a single slot, turning
  // vectorRepresentations[index] into a scalar.
  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze(0);

  // Compose trees bottom-up: every node sees S first, then each computed child tree.
  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
  for (int64_t i = 0; i < computeOrder.size(0); i++)
  {
    int index = computeOrder[i].item<int>();
    if (index == -1)
      break;
    std::vector<torch::Tensor> inputVector;
    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
    for (int64_t childIndex = 0; childIndex < childs.size(1); childIndex++)
    {
      int child = childs[index][childIndex].item<int>();
      if (child == -1)
        break;
      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
    }
    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
    // batch_first: output is (1, steps, treeEmbeddingsSize); keep the last step.
    treeRepresentations[index] = treeLSTM(lstmInput).output[0][-1];
  }

  // Gather the trees of the focused elements (nullTree when absent).
  std::vector<torch::Tensor> focusedTrees;
  focusedTrees.reserve(focusedIndexes.size(0));
  for (int64_t i = 0; i < focusedIndexes.size(0); i++)
  {
    int index = focusedIndexes[i].item<int>();
    focusedTrees.emplace_back(index == -1 ? nullTree : treeRepresentations[index]);
  }

  auto representation = torch::cat(focusedTrees, 0);
  return linear2(torch::relu(linear1(representation)));
}

std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  std::vector<long> contextIndexes;
  std::stack<int> leftContext;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
    if (config.isToken(index))
      leftContext.push(index);

  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
    contextIndexes.emplace_back(-1);
  while (!leftContext.empty())
  {
    contextIndexes.emplace_back(leftContext.top());
    leftContext.pop();
  }

  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
    if (config.isToken(index))
      contextIndexes.emplace_back(index);

  while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
    contextIndexes.emplace_back(-1);

  std::map<long, long> indexInContext;
  for (auto & l : contextIndexes)
    indexInContext.emplace(std::make_pair(l, indexInContext.size()));

  std::vector<long> headOf;
  for (auto & l : contextIndexes)
  {
    if (l == -1)
      headOf.push_back(-1);
    else
    {
      auto & head = config.getAsFeature(Config::headColName, l);
      if (util::isEmpty(head) or head == "_")
        headOf.push_back(-1);
      else if  (indexInContext.count(std::stoi(head)))
        headOf.push_back(std::stoi(head));
      else
        headOf.push_back(-1);
    }
  }

  std::vector<std::vector<long>> childs(headOf.size());
  for (unsigned int i = 0; i < headOf.size(); i++)
    if (headOf[i] != -1)
      childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);

  std::vector<long> treeComputationOrder;
  std::vector<bool> treeIsComputed(contextIndexes.size(), false);

  std::function<void(long)> depthFirst;
  depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
  {
    if (!indexInContext.count(root))
      return;

    if (treeIsComputed[indexInContext[root]])
      return;

    for (auto child : childs[indexInContext[root]])
      depthFirst(child);

    treeIsComputed[indexInContext[root]] = true;
    treeComputationOrder.push_back(indexInContext[root]);
  };

  for (auto & l : focusedBufferIndexes)
    if (contextIndexes[leftBorder+l] != -1)
      depthFirst(contextIndexes[leftBorder+l]);

  for (auto & l : focusedStackIndexes)
    if (config.hasStack(l))
      depthFirst(config.getStack(l));

  std::vector<long> context;
  
  for (auto & c : focusedBufferIndexes)
    context.push_back(leftBorder+c);
  for (auto & c : focusedStackIndexes)
    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
      context.push_back(indexInContext[config.getStack(c)]);
    else
      context.push_back(-1);
  for (auto & c : treeComputationOrder)
    context.push_back(c);
  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
    context.push_back(-1);
  for (auto & c : childs)
  {
    for (unsigned int i = 0; i < maxNbChilds; i++)
      if (i < c.size())
        context.push_back(indexInContext[c[i]]);
      else
        context.push_back(-1);
  }
  for (auto & l : contextIndexes)
    for (auto & col : columns)
      if (l == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));

  return context;