#include "TestNetwork.hpp"

#include <cstddef>
#include <cstdint>
#include <vector>

namespace
{
/// Number of rows in the word embedding table. Dictionary indices handed to the
/// embedding must stay strictly below this value.
constexpr std::int64_t maxNbWords = 200000;
/// Dimension of each word embedding vector; also the input size of the linear layer.
constexpr std::int64_t wordEmbeddingSize = 100;
}

/// Builds the network: a word embedding table followed by a linear layer.
/// @param nbOutputs number of output scores produced by forward().
TestNetworkImpl::TestNetworkImpl(int nbOutputs)
{
  // Register the special tokens first. NOTE(review): assumes `dict` starts empty,
  // in which case they receive indices 0, 1 and 2 — confirm against the header.
  // "_null_" is the index used for empty forms, "_unknown_" the fallback once the
  // dictionary is full (see getOrAddDictValue); the role of "_S_" is not visible here.
  getOrAddDictValue(Config::String("_null_"));
  getOrAddDictValue(Config::String("_unknown_"));
  getOrAddDictValue(Config::String("_S_"));

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(maxNbWords, wordEmbeddingSize));
  linear = register_module("linear", torch::nn::Linear(wordEmbeddingSize, nbOutputs));
}

/// Scores the current word of `config`.
/// Embeds the FORM of every line in a window of up to `windowSize` lines on each side
/// of the current word, then applies the linear layer and a softmax to the current
/// word's embedding only.
/// @return a 1-D tensor of `nbOutputs` probabilities.
/// Calls util::myThrow if a line inside the window is missing from `config`.
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  // Maximum number of lines taken on each side of the current word.
  constexpr int windowSize = 5;

  const int wordIndex = config.getWordIndex();

  // Grow the window leftwards, then rightwards, while neighbouring lines exist.
  int startIndex = wordIndex;
  while (config.has(0, startIndex-1, 0) and wordIndex-startIndex < windowSize)
    startIndex--;
  int endIndex = wordIndex;
  while (config.has(0, endIndex+1, 0) and endIndex-wordIndex < windowSize)
    endIndex++;

  // Dictionary index of each line's FORM. Stored as int64_t so the buffer really
  // matches the at::kLong dtype given to from_blob below.
  std::vector<std::int64_t> words;
  words.reserve(static_cast<std::size_t>(endIndex - startIndex + 1));
  for (int i = startIndex; i <= endIndex; ++i)
  {
    if (!config.has(0, i, 0))
      util::myThrow(fmt::format("Config do not have line {}", i));

    words.emplace_back(static_cast<std::int64_t>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
  }

  if (words.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));

  // from_blob neither copies nor owns `words`; `words` lives until the end of this
  // function, which covers every use of the blob tensor.
  auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {static_cast<std::int64_t>(words.size())}, at::kLong));

  // Only the current word's embedding (row wordIndex-startIndex) is scored.
  return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
}

/// Returns the dictionary index of `s`, inserting it with the next free index if unseen.
/// Empty strings map to the "_null_" index. Once the dictionary holds `maxNbWords`
/// entries, unseen strings map to the "_unknown_" index, so every returned index
/// stays inside the embedding table.
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  if (s.get().empty())
    return dict[Config::String("_null_")];

  const auto found = dict.find(s);
  if (found != dict.end())
    return found->second;

  // Without this guard, new words would receive indices >= maxNbWords and the
  // embedding lookup in forward() would be out of range.
  if (dict.size() >= static_cast<std::size_t>(maxNbWords))
    return dict[Config::String("_unknown_")];

  // Take the index before inserting, so it never depends on evaluation order.
  const std::size_t newIndex = dict.size();
  dict[s] = newIndex;
  return newIndex;
}