Skip to content
Snippets Groups Projects
TestNetwork.cpp 1.59 KiB
Newer Older
#include "TestNetwork.hpp"

#include <cstdint>

TestNetworkImpl::TestNetworkImpl()
{
  // Register the special tokens before any real word, keeping this exact
  // order: "_null_", then "_unknown_", then "_S_".
  for (const auto & specialToken : {Config::String("_null_"), Config::String("_unknown_"), Config::String("_S_")})
    getOrAddDictValue(specialToken);

  // Word embedding table: 200000 rows of 100-dimensional vectors.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
}

torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  // Embeds the FORM of each word in a window of at most windowSize lines on
  // each side of the current word index. The window stops early at the first
  // line the config does not have.
  constexpr int windowSize = 5;
  const int wordIndex = config.getWordIndex();

  int startIndex = wordIndex;
  while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
    startIndex--;
  int endIndex = wordIndex;
  while (config.has(0,endIndex+1,0) and endIndex-wordIndex < windowSize)
    endIndex++;

  // Store the indices directly as int64 so they match the at::kLong dtype
  // declared below, instead of reinterpreting std::size_t memory.
  std::vector<std::int64_t> words;
  words.reserve(static_cast<std::size_t>(endIndex - startIndex + 1));
  for (int i = startIndex; i <= endIndex; ++i)
  {
    // Only the current word's line is not checked by the window loops above.
    if (!config.has(0, i, 0))
      util::myThrow(fmt::format("Config do not have line {}", i));

    // NOTE(review): getOrAddDictValue grows the dictionary without bound, while
    // the embedding table created in the constructor has 200000 rows; an index
    // past that would fail inside the embedding lookup. Confirm the limit.
    words.emplace_back(static_cast<std::int64_t>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
  }

  // Defensive: the window always contains wordIndex, so this should not trigger.
  if (words.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));

  // from_blob does not copy or own `words`. clone() gives the index tensor its
  // own storage, because the embedding op keeps its indices for the backward
  // pass, which runs after this local vector has been destroyed.
  auto indices = torch::from_blob(words.data(), {1, static_cast<std::int64_t>(words.size())}, at::kLong).clone();
  return wordEmbeddings(indices);
}

std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  // Returns the dictionary index of `s`, inserting `s` first if it is unknown.
  // Empty strings have no entry of their own: they share the index of the
  // "_null_" token. Using the regular lookup/insert path here (rather than
  // dict[...]) means "_null_" would get a fresh index if it were missing,
  // instead of a value-initialized 0 that could collide with another word.
  if (s.get().empty())
    return getOrAddDictValue(Config::String("_null_"));

  auto found = dict.find(s);
  if (found != dict.end())
    return found->second;

  // A new word receives the current number of entries as its index. The size
  // is read in its own statement, before insertion, so the result does not
  // depend on the evaluation order of `dict[s] = dict.size()` (which is only
  // right-hand-side-first since C++17).
  const std::size_t newIndex = dict.size();
  dict[s] = newIndex;
  return newIndex;
}