#include "CNNNetwork.hpp"

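// Convolutional classifier over a window of words around the current word.
// Word-level and letter-level convolutions with several window sizes are
// max-pooled, concatenated and passed through a two-layer MLP.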
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 512;
  constexpr int nbFilters = 512;
  constexpr int nbFiltersLetters = 64;

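  // Context geometry: leftBorder words to the left and rightBorder words to
  // the right of the current word, plus the top nbStackElements stack
  // elements. For every context word, the FORM and UPOS columns are used.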
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

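  // A single embedding table is shared by word forms, POS tags and letters.
  // linear1 consumes the concatenation of all max-pooled feature maps:
  // nbFilters per context window size, plus nbFiltersLetters per window size
  // for each focused buffer word.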
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
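  // One context convolution and one letter convolution per window size. The
  // kernels span the full embedding width, and a vertical padding of
  // windowSize-1 keeps every window position valid even for short inputs.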
  for (auto & windowSize : windowSizes)
  {
    CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
    lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
  }
}

torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
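  // Promote a single example to a batch of size one.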
  if (input.dim() == 1)
    input = input.unsqueeze(0);

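  // Split the flat input produced by extractContext(): first columns.size()
  // feature indexes per context word, then maxNbLetters letter indexes per
  // focused buffer word.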
  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size());

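  // Reshape for Conv2d: {batch, 1, nbWords, columns*embDim} for the context,
  // and one {batch, 1, maxNbLetters, embDim} plane per focused word for the
  // letters.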
  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);

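  // Letter-level branch: make the focused-word dimension the leading one so
  // each word's letter matrix can be convolved and max-pooled on its own.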
  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
  std::vector<torch::Tensor> windows;
  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
    {
      auto lettersOfWord = permuted[word];
      auto convOut = torch::relu(lettersCNNs[i](lettersOfWord).squeeze(-1));
      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
      windows.emplace_back(pooled);
    }
  auto lettersCnnOut = torch::cat(windows, 2);
  lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});

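  // Word-level branch: convolve the whole context window with each filter
  // size and max-pool over positions.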
  windows.clear();
  for (unsigned int i = 0; i < CNNs.size(); i++)
  {
    auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
    windows.emplace_back(pooled);
  }

  auto cnnOut = torch::cat(windows, 2);
  cnnOut = cnnOut.view({cnnOut.size(0), -1});

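  // Concatenate both branches and classify with the two-layer MLP.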
  auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1);

  return linear2(torch::relu(linear1(totalInput)));
}

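// Builds the flat index vector consumed by forward(): context word features,
// stack element features, then the letters of the focused buffer words.
// Missing positions are padded with Dict::nullValueStr.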
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
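  // Walk leftwards from the current word, stacking feature indexes (and raw
  // forms, needed later for the letter features) so that popping restores
  // left-to-right order.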
  std::stack<int> leftContext;
  std::stack<std::string> leftForms;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
    if (config.isToken(index))
      for (auto & column : columns)
      {
        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
        if (column == "FORM")
          leftForms.push(config.getAsFeature(column, index));
      }

  std::vector<long> context;
  std::vector<std::string> forms;

  // Pad the left side with null values: leftContext already holds
  // columns.size() entries per collected word, so pad up to a total of
  // columns.size()*leftBorder entries.
  while (context.size()+leftContext.size() < columns.size()*leftBorder)
    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  while (forms.size()+leftForms.size() < (size_t)leftBorder)
    forms.emplace_back("");
  while (!leftForms.empty())
  {
    forms.emplace_back(leftForms.top());
    leftForms.pop();
  }
  while (!leftContext.empty())
  {
    context.emplace_back(leftContext.top());
    leftContext.pop();
  }

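  // Current word and right context.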
  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
    if (config.isToken(index))
      for (auto & column : columns)
      {
        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
        if (column == "FORM")
          forms.emplace_back(config.getAsFeature(column, index));
      }

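  // Pad the right side up to the full window, for indexes and forms alike.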
  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
  while ((int)forms.size() < leftBorder+rightBorder+1)
    forms.emplace_back("");

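  // Features of the top nbStackElements stack elements; null when the stack
  // is not deep enough.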
  for (int i = 0; i < nbStackElements; i++)
    for (auto & column : columns)
      if (config.hasStack(i))
        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
      else
        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));

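  // Exactly maxNbLetters letter indexes per focused buffer word, padded with
  // the null value; words outside the window yield only padding.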
  for (auto index : focusedBufferIndexes)
  {
    util::utf8string letters;
    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
      letters = util::splitAsUtf8(forms[leftBorder+index]);
    for (unsigned int i = 0; i < maxNbLetters; i++)
    {
      if (i < letters.size())
      {
        std::string sLetter = fmt::format("Letter({})", letters[i]);
        context.emplace_back(dict.getIndexOrInsert(sLetter));
      }
      else
      {
        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
      }
    }
  }

  return context;
}