// CNNNetwork.cpp (7.51 KiB) -- committed by Franck Dary. Navigation text from the web blame view removed.
    #include "CNNNetwork.hpp"
    
    
    /// Builds and registers every sub-module of the network.
    ///
    /// Modules are registered in this order: rawInputCNN (only when both raw
    /// windows are non-negative), word_embeddings, embeddings_dropout,
    /// cnn_dropout, hidden_dropout, contextCNN, one "CNN_<column>" per focused
    /// column, linear1, linear2.
    ///
    /// @param nbOutputs            Output width of the last linear layer.
    /// @param leftBorder           Forwarded to setLeftBorder().
    /// @param rightBorder          Forwarded to setRightBorder().
    /// @param nbStackElements      Forwarded to setNbStackElements().
    /// @param columns              Forwarded to setColumns(); its size scales the
    ///                             input width given to contextCNN.
    /// @param focusedBufferIndexes Stored; counted when sizing linear1.
    /// @param focusedStackIndexes  Stored; counted when sizing linear1.
    /// @param focusedColumns       Stored; one CNN is registered per entry.
    /// @param maxNbElements        Stored as a member.
    /// @param leftWindowRawInput   Stored; a negative value disables the raw-input CNN.
    /// @param rightWindowRawInput  Stored; a negative value disables the raw-input CNN.
    CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
    {
      constexpr int embeddingsSize = 64;
      constexpr int hiddenSize = 1024;
      constexpr int nbFiltersContext = 512;
      constexpr int nbFiltersFocused = 64;

      setLeftBorder(leftBorder);
      setRightBorder(rightBorder);
      setNbStackElements(nbStackElements);
      setColumns(columns);

      // The raw-input section is only present when both windows are non-negative;
      // otherwise its size is forced to 0 and no CNN is created for it.
      rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
      if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
        rawInputSize = 0;
      else
        rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
      int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();

      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
      embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
      cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
      hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));

      // Each context position carries one embedding per column.
      contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));

      // linear1 receives the concatenation of every CNN output, so its input
      // width is accumulated while the CNNs are created.
      int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;

      for (auto & col : focusedColumns)
      {
        std::vector<int> windows{2,3,4};
        cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
        // The same CNN is applied once per focused buffer/stack index.
        totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
      }

      linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
      linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
    }
    
    /// Runs the network on a batch of index sequences.
    ///
    /// Along dimension 1 the input is sliced into three consecutive sections:
    ///   1. rawInputSize raw-letter entries (absent when rawInputSize is 0),
    ///   2. columns.size()*(1+leftBorder+rightBorder) context entries,
    ///   3. the remaining entries, consumed as maxNbElements[i] entries per
    ///      focused index for every focused column i.
    ///
    /// @param input Tensor of embedding indexes; a 1-D tensor is treated as a
    ///              batch of size one.
    /// @return The output of linear2 applied to the concatenated CNN outputs.
    torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
    {
      if (input.dim() == 1)
        input = input.unsqueeze(0);

      auto embeddings = embeddingsDropout(wordEmbeddings(input));

      // Length of the context section, measured in input entries. It is computed
      // once, before the reshape below, because the reshape divides dimension 1
      // by columns.size(); using the reshaped size to locate the focused elements
      // would make them overlap the context whenever there are several columns.
      const long contextSize = static_cast<long>(columns.size()) * (1 + leftBorder + rightBorder);

      auto context = embeddings.narrow(1, rawInputSize, contextSize);
      // Group the embeddings of all columns of one position into a single row.
      context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});

      const long elementsStart = rawInputSize + contextSize;
      auto elementsEmbeddings = embeddings.narrow(1, elementsStart, input.size(1) - elementsStart);

      std::vector<torch::Tensor> cnnOutputs;

      if (rawInputSize != 0)
      {
        auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
        cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
      }

      // Offset of the next focused element inside elementsEmbeddings; kept as
      // long because it accumulates long values.
      long curIndex = 0;

      for (unsigned int i = 0; i < focusedColumns.size(); i++)
      {
        long nbElements = maxNbElements[i];

        for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
        {
          auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
          curIndex += nbElements;
          cnnOutputs.emplace_back(cnns[i](cnnInput));
        }
      }

      cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));

      auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));

      return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
    }
    
    /// Converts the current configuration into the sequence of dictionary indexes
    /// fed to the network.
    ///
    /// The returned vector is built in three consecutive sections:
    ///   1. when rawInputSize > 0: leftWindowRawInput entries for the characters
    ///      before the current character index, then rightWindowRawInput+1
    ///      entries starting at the current character index;
    ///   2. one entry per (context index, column) pair;
    ///   3. for every focused column, maxNbElements[colIndex] entries per focused
    ///      buffer index and per focused stack index.
    /// Missing characters or words are encoded with Dict::nullValueStr.
    ///
    /// @param config Configuration the features are read from.
    /// @param dict   Dictionary queried through getIndexOrInsert(), so it may gain
    ///               entries as a side effect.
    /// @return The dictionary indexes, in the order described above.
    /// Calls util::myThrow when a focused word does not produce exactly
    /// maxNbElements[colIndex] elements.
    std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
    {
      // NOTE(review): the test is ">=" while the message says ">" -- confirm
      // which bound is intended.
      if (dict.size() >= maxNbEmbeddings)
        util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));

      std::vector<long> contextIndexes = extractContextIndexes(config);
      std::vector<long> context;

      // Section 1: raw letters around the current character.
      if (rawInputSize > 0)
      {
        for (int i = 0; i < leftWindowRawInput; i++)
          if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
            context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
          else
            context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));

        for (int i = 0; i <= rightWindowRawInput; i++)
          if (config.hasCharacter(config.getCharacterIndex()+i))
            context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
          else
            context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      }

      // Section 2: one value per context word and per column; -1 marks a missing word.
      for (auto index : contextIndexes)
        for (auto & col : columns)
          if (index == -1)
            context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
          else
          {
            int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
            // Forms and lemmas seen fewer than unknownValueThreshold times are
            // replaced by the unknown value.
            if (col == "FORM" || col == "LEMMA")
              if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
                dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);

            context.push_back(dictIndex);
          }

      // Section 3: focused words, decomposed into elements column by column.
      for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
      {
        auto & col = focusedColumns[colIndex];

        // Resolve every focused position to a word index, or -1 when unavailable.
        std::vector<int> focusedIndexes;
        for (auto relIndex : focusedBufferIndexes)
        {
          // presumably contextIndexes starts leftBorder words before the current
          // one, hence the shift -- TODO confirm against extractContextIndexes().
          int index = relIndex + leftBorder;
          if (index < 0 || index >= (int)contextIndexes.size())
            focusedIndexes.push_back(-1);
          else
            focusedIndexes.push_back(contextIndexes[index]);
        }
        for (auto index : focusedStackIndexes)
        {
          if (!config.hasStack(index))
            focusedIndexes.push_back(-1);
          else if (!config.has(col, config.getStack(index), 0))
            focusedIndexes.push_back(-1);
          else
            focusedIndexes.push_back(config.getStack(index));
        }

        for (auto index : focusedIndexes)
        {
          // NOTE(review): this guard was reconstructed from the -1 sentinel pushed
          // above and the 'continue' below; a missing word is padded with nulls.
          if (index == -1)
          {
            for (int i = 0; i < maxNbElements[colIndex]; i++)
              context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
            continue;
          }

          std::vector<std::string> elements;
          if (col == "FORM")
          {
            // One element per UTF-8 symbol of the form, truncated or null-padded.
            auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());

            for (int i = 0; i < maxNbElements[colIndex]; i++)
              if (i < (int)asUtf8.size())
                elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
              else
                elements.emplace_back(Dict::nullValueStr);
          }
          else if (col == "FEATS")
          {
            // One element per '|'-separated piece, truncated or null-padded.
            auto splited = util::split(config.getAsFeature(col, index).get(), '|');

            for (int i = 0; i < maxNbElements[colIndex]; i++)
              if (i < (int)splited.size())
                elements.emplace_back(fmt::format("FEATS({})", splited[i]));
              else
                elements.emplace_back(Dict::nullValueStr);
          }
          else if (col == "ID")
          {
            // When none of the three predicates holds, no element is added and
            // the size check below fails unless maxNbElements[colIndex] is 0.
            if (config.isTokenPredicted(index))
              elements.emplace_back("ID(TOKEN)");
            else if (config.isMultiwordPredicted(index))
              elements.emplace_back("ID(MULTIWORD)");
            else if (config.isEmptyNodePredicted(index))
              elements.emplace_back("ID(EMPTYNODE)");
          }
          else
          {
            // NOTE(review): reconstructed as the fallback branch; any other column
            // contributes its feature value as a single element.
            elements.emplace_back(config.getAsFeature(col, index));
          }

          if ((int)elements.size() != maxNbElements[colIndex])
            util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));

          for (auto & element : elements)
            context.emplace_back(dict.getIndexOrInsert(element));
        }
      }

      return context;
    }