#include "TestNetwork.hpp"

TestNetworkImpl::TestNetworkImpl(int nbOutputs)
{
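  // Seed the dictionary with its special entries before any real word is added:
  // "_null_" (returned for empty values, see getOrAddDictValue), "_unknown_" and "_S_".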
  getOrAddDictValue(Config::String("_null_"));
  getOrAddDictValue(Config::String("_unknown_"));
  getOrAddDictValue(Config::String("_S_"));

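  // A 200000-entry embedding table of dimension 100, followed by a linear layer
  // projecting a single 100-dimensional embedding onto the nbOutputs classes.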
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
  linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
}

torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  constexpr int windowSize = 5;
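  // Expand a window of at most windowSize words on each side of the current word,
  // stopping early at the boundaries of the configuration.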
  int wordIndex = config.getWordIndex();
  int startIndex = wordIndex;
  while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
    startIndex--;
  int endIndex = wordIndex;
  while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
    endIndex++;

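  // Map each surface form ("FORM" column) in the window to its dictionary index,
  // adding unseen forms on the fly.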
  std::vector<int64_t> words; // int64_t so the buffer matches at::kLong in from_blob below
  for (int i = startIndex; i <= endIndex; ++i)
  {
    if (!config.has(0, i, 0))
      util::myThrow(fmt::format("Config does not have line {}", i));

    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
  }

  if (words.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
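  // from_blob wraps the vector's buffer without copying; `words` stays alive on
  // this stack frame until the embedding lookup has read it.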
  auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {static_cast<int64_t>(words.size())}, at::kLong));

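  // The prediction uses only the centre word's embedding, projected by the linear
  // layer and normalised with a softmax over the nbOutputs classes.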
  return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
}

std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  // Empty values all map to the shared "_null_" entry created in the constructor.
  if (s.get().empty())
    return dict[Config::String("_null_")];

  auto found = dict.find(s);

  if (found != dict.end())
    return found->second;

  // New word: assign the next free index. Compute the size before inserting so the
  // result does not depend on the evaluation order of `dict[s] = dict.size()`.
  std::size_t index = dict.size();
  dict[s] = index;
  return index;
}
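
// Minimal usage sketch (hypothetical): assumes the header wraps TestNetworkImpl
// with TORCH_MODULE(TestNetwork), as is conventional in libtorch, and that a
// Config object `config` has already been built by the surrounding parser:
//
//   TestNetwork net(nbClasses);
//   torch::Tensor scores = net->forward(config); // shape {nbOutputs}
//   auto prediction = scores.argmax().item<int64_t>();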