Skip to content
Snippets Groups Projects
TestNetwork.cpp 1.59 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "TestNetwork.hpp"
    
    TestNetworkImpl::TestNetworkImpl()
    {
      // Seed the dictionary with the reserved special tokens before any real
      // word is seen, so they always occupy the lowest indices.
      for (auto && special : {"_null_", "_unknown_", "_S_"})
        getOrAddDictValue(Config::String(special));

      // Embedding table: up to 200000 vocabulary entries, 100-dim vectors.
      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
    }
    
    torch::Tensor TestNetworkImpl::forward(const Config & config)
    {
      // Collect a window of at most `windowSize` words on each side of the
      // current word, map each FORM to a dictionary index, and return the
      // corresponding embeddings as a (1, N) batch.
      constexpr int windowSize = 5;

      // Walk left while previous lines exist and we stay inside the window.
      int startIndex = config.getWordIndex();
      while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize)
        startIndex--;
      // Walk right symmetrically.
      int endIndex = config.getWordIndex();
      while (config.has(0,endIndex+1,0) and -config.getWordIndex()+endIndex < windowSize)
        endIndex++;

      // long int so the buffer layout matches at::kLong (int64_t) below; the
      // previous std::size_t relied on an unsigned/signed reinterpretation.
      std::vector<long int> words;
      words.reserve(endIndex - startIndex + 1);
      for (int i = startIndex; i <= endIndex; ++i)
      {
        if (!config.has(0, i, 0))
          util::myThrow(fmt::format("Config does not have line {}", i)); // fmt uses {}, not %d

        words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
      }

      if (words.empty())
        util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));

      // from_blob does not copy: `words` must outlive the embedding lookup,
      // which it does because the lookup completes before this return.
      return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong));
    }
    
    std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
    {
      // Empty strings all map to the reserved "_null_" entry (seeded in the
      // constructor, so this lookup never inserts a new element).
      if (s.get().empty())
        return dict[Config::String("_null_")];

      // Compute the candidate id BEFORE touching the map: the previous
      // `dict[s] = dict.size()` had unspecified evaluation order pre-C++17
      // (operator[] may insert before size() is read, giving an off-by-one id).
      const std::size_t nextId = s.get().empty() ? 0 : dict.size();

      // Single lookup: insert s -> nextId if absent, otherwise return the
      // existing mapping. Replaces the original find() + operator[] pair.
      return dict.emplace(s, nextId).first->second;
    }