// CNNNetwork.cpp — implementation of CNNNetworkImpl.
// Author (from repository history): Franck Dary.
#include "CNNNetwork.hpp"

/// Builds the CNN classifier and registers all of its sub-modules.
///
/// The input layout consumed by forward() (and produced by extractContext()) is:
///   [raw letters, if enabled] [context: one entry per (token, column)] [focused elements].
///
/// @param nbOutputs             output size of the final linear layer.
/// @param leftBorder            stored with setLeftBorder().
/// @param rightBorder           stored with setRightBorder().
/// @param nbStackElements       stored with setNbStackElements().
/// @param columns               stored with setColumns(); each context token concatenates one
///                              embedding per column.
/// @param focusedBufferIndexes  buffer positions whose focused columns each get a CNN pass.
/// @param focusedStackIndexes   stack positions whose focused columns each get a CNN pass.
/// @param focusedColumns        one dedicated CNN ("CNN_<col>") is created per entry.
/// @param maxNbElements         number of elements per focused column (stored for forward/extractContext).
/// @param leftWindowRawInput    characters to the left of the current one; negative disables raw input.
/// @param rightWindowRawInput   characters to the right of the current one; negative disables raw input.
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 512;
  constexpr int nbFiltersContext = 512;
  constexpr int nbFiltersFocused = 64;
  // Convolution window widths shared by every CNN of this network.
  const std::vector<int> windowSizes{2,3,4};

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns(columns);

  // A negative window on either side disables the raw character input entirely.
  rawInputSize =  leftWindowRawInput + rightWindowRawInput + 1;
  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
    rawInputSize = 0;
  else
    rawInputCNN = register_module("rawInputCNN", CNN(windowSizes, nbFiltersFocused, embeddingsSize));
  int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();

  // NOTE(review): the vocabulary size is fixed at 50000 here; dictionary indexes presumably must stay below it.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  // Each context token is fed as the concatenation of the embeddings of all its columns.
  contextCNN = register_module("contextCNN", CNN(windowSizes, nbFiltersContext, columns.size()*embeddingsSize));
  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
  for (auto & col : focusedColumns)
  {
    cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windowSizes, nbFiltersFocused, embeddingsSize)));
    // The same column CNN is applied once per focused buffer position and once per focused stack position.
    totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
  }
  linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

/// Scores a batch of contexts built by extractContext().
///
/// @param input  dictionary indexes laid out as
///               [raw letters (rawInputSize)] [context (columns.size() entries per token)]
///               [focused elements]; a 1-D input is treated as a batch of one.
///               NOTE(review): presumably an integer (index) tensor — confirm with callers.
/// @return output of linear2 for each batch row (nbOutputs values per row).
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  auto embeddings = wordEmbeddings(input);

  // The context holds one entry per (token, column) pair. The view below regroups it so that
  // each token's column embeddings are concatenated into a single feature vector.
  const long contextSize = static_cast<long>(columns.size())*(1+leftBorder+rightBorder);
  auto context = embeddings.narrow(1, rawInputSize, contextSize);
  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});

  // Focused elements start right after the flattened context entries. This offset must use
  // contextSize, not context.size(1): after the view, context.size(1) only counts tokens, and
  // with more than one column it would point back inside the context.
  const long elementsStart = rawInputSize + contextSize;
  auto elementsEmbeddings = embeddings.narrow(1, elementsStart, input.size(1)-elementsStart);

  std::vector<torch::Tensor> cnnOutputs;
  if (rawInputSize != 0)
  {
    auto rawLetters = embeddings.narrow(1, 0, rawInputSize);
    cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
  }

  // For every focused column, one slice of maxNbElements[i] entries per focused buffer/stack position.
  long curIndex = 0;
  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    long nbElements = maxNbElements[i];
    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
    {
      auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
      curIndex += nbElements;
      cnnOutputs.emplace_back(cnns[i](cnnInput));
    }
  }

  cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));

  auto totalInput = torch::cat(cnnOutputs, 1);

  return linear2(torch::relu(linear1(totalInput)));
}

/// Converts the current configuration into the flat list of dictionary indexes read by forward().
///
/// Emitted layout, in order:
///  1. If raw input is enabled (rawInputSize > 0): one entry per character position from
///     getCharacterIndex()-leftWindowRawInput to getCharacterIndex()+rightWindowRawInput,
///     either "Letter(<c>)" or Dict::nullValueStr when the character does not exist.
///  2. For each index returned by extractContextIndexes(), one entry per column
///     (Dict::nullValueStr when the index is -1).
///  3. For each focused column, and for each focused buffer position followed by each focused
///     stack position: exactly maxNbElements[colIndex] entries. These are letters for "FORM",
///     '|'-separated features for "FEATS", and the raw value otherwise. Entries are padded with
///     Dict::nullValueStr, and are all null when the position does not exist.
///
/// Strings not yet known to @p dict are inserted through getIndexOrInsert().
/// Calls util::myThrow when a focused value does not produce exactly maxNbElements[colIndex] elements.
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  std::vector<long> contextIndexes = extractContextIndexes(config);
  std::vector<long> context;

  if (rawInputSize > 0)
  {
    // Character window centred on the current character, left side first.
    for (int offset = -leftWindowRawInput; offset <= rightWindowRawInput; offset++)
    {
      auto charIndex = config.getCharacterIndex()+offset;
      if (config.hasCharacter(charIndex))
        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(charIndex))));
      else
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
    }
  }

  for (auto index : contextIndexes)
    for (auto & col : columns)
      if (index == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));

  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
  {
    auto & col = focusedColumns[colIndex];

    // Resolve every focused position to a configuration index, or -1 when it is unavailable.
    std::vector<int> focusedIndexes;
    for (auto relIndex : focusedBufferIndexes)
    {
      // relIndex is shifted by leftBorder to address contextIndexes.
      int index = relIndex + leftBorder;
      if (index < 0 || index >= (int)contextIndexes.size())
        focusedIndexes.push_back(-1);
      else
        focusedIndexes.push_back(contextIndexes[index]);
    }
    for (auto index : focusedStackIndexes)
    {
      if (!config.hasStack(index))
        focusedIndexes.push_back(-1);
      else if (!config.has(col, config.getStack(index), 0))
        focusedIndexes.push_back(-1);
      else
        focusedIndexes.push_back(config.getStack(index));
    }

    for (auto index : focusedIndexes)
    {
      // Unavailable position: fill the whole slot with null values so every slot keeps a fixed size.
      if (index == -1)
      {
        for (int i = 0; i < maxNbElements[colIndex]; i++)
          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
        continue;
      }

      std::vector<std::string> elements;
      if (col == "FORM")
      {
        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());

        for (int i = 0; i < maxNbElements[colIndex]; i++)
          if (i < (int)asUtf8.size())
            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else if (col == "FEATS")
      {
        auto splited = util::split(config.getAsFeature(col, index).get(), '|');

        for (int i = 0; i < maxNbElements[colIndex]; i++)
          if (i < (int)splited.size())
            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else
      {
        // Any other column contributes its whole value as a single element.
        elements.emplace_back(config.getAsFeature(col, index));
      }

      if ((int)elements.size() != maxNbElements[colIndex])
        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));

      for (auto & element : elements)
        context.emplace_back(dict.getIndexOrInsert(element));
    }
  }

  return context;
}