Skip to content
Snippets Groups Projects
TestNetwork.cpp 1.74 KiB
Newer Older
  • Learn to ignore specific revisions
    #include "TestNetwork.hpp"
    
    
    // Build the network: seed the dictionary with its special entries, then
    // register the embedding table and the linear output layer.
    // nbOutputs is the output size of the linear layer.
    TestNetworkImpl::TestNetworkImpl(int nbOutputs)
    {
      constexpr int nbEmbeddings = 200000;
      constexpr int embeddingsSize = 100;
      static constexpr const char * specialValues[] = {"_null_", "_unknown_", "_S_"};

      // The entries are inserted in table order, so the order of the table
      // decides which dictionary index each special value receives.
      for (const char * value : specialValues)
        getOrAddDictValue(Config::String(value));

      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(nbEmbeddings, embeddingsSize));
      linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
    }
    
    // Compute the output distribution for the current word of config.
    //
    // A window of at most windowSize lines on each side of the current word is
    // collected, every "FORM" value in it is mapped to a dictionary index
    // (unknown values are added to the dictionary), and the whole window is
    // embedded. Only the embedding of the current word is then fed to the
    // linear layer, whose output is normalised with a softmax over dimension 0.
    //
    // Throws (through util::myThrow) when a line of the window is reported
    // missing by config.has, or when the window is empty.
    torch::Tensor TestNetworkImpl::forward(const Config & config)
    {
      // Maximum number of lines taken on each side of the current word.
      constexpr int windowSize = 5;

      const int wordIndex = config.getWordIndex();

      // Grow the window to the left while lines exist and the limit is not reached.
      int startIndex = wordIndex;
      while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
        startIndex--;

      // Same thing to the right.
      int endIndex = wordIndex;
      while (config.has(0,endIndex+1,0) and endIndex-wordIndex < windowSize)
        endIndex++;

      // torch::from_blob reinterprets this buffer as at::kLong (signed 64-bit),
      // so the indexes are stored in a signed 64-bit type instead of std::size_t,
      // whose width and signedness are not guaranteed to match.
      static_assert(sizeof(long long) == 8, "at::kLong expects 64-bit elements");
      std::vector<long long> words;
      words.reserve(static_cast<std::size_t>(endIndex-startIndex+1));

      for (int i = startIndex; i <= endIndex; ++i)
      {
        if (!config.has(0, i, 0))
          util::myThrow(fmt::format("Config do not have line {}", i)); // fmt uses {} placeholders, not %d

        words.emplace_back(static_cast<long long>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
      }

      // Defensive guard: startIndex <= wordIndex <= endIndex, so the loop above
      // always runs at least once and this should never trigger.
      if (words.empty())
        util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));

      // from_blob does not copy: words must stay alive until the embedding lookup is done.
      auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {static_cast<long int>(words.size())}, at::kLong));

      // Position of the current word inside the window is wordIndex-startIndex.
      return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
    }
    
    // Return the dictionary index associated with s.
    //
    // When s is not in the dictionary yet, it is inserted with the next free
    // index, which is the size of the dictionary before the insertion.
    // An empty string is looked up as the special "_null_" entry.
    std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
    {
      // Route empty strings through the regular lookup, so that "_null_" is
      // properly inserted (instead of default-initialised) if it is missing.
      const Config::String key = s.get().empty() ? Config::String("_null_") : s;

      const auto found = dict.find(key);
      if (found != dict.end())
        return found->second;

      // Read the size BEFORE inserting: in `dict[s] = dict.size()` the relative
      // order of the inserting operator[] and of size() is unspecified before
      // C++17, so the new entry could receive size+1 instead of size.
      const std::size_t index = dict.size();
      dict.emplace(key, index);

      return index;
    }