#include "TestNetwork.hpp"

#include <cstdint>
#include <vector>

namespace
{
// Number of rows of the word embedding table. Ids handed out by
// getOrAddDictValue must stay below this value for the embedding lookup to
// succeed.
constexpr std::int64_t maxNbEmbeddings = 200000;
// Size of each word embedding vector.
constexpr std::int64_t embeddingsSize = 100;
}

/// Builds the network.
///
/// Registers the special dictionary entries "_null_", "_unknown_" and "_S_"
/// first, so that they receive the lowest ids (presumably 0, 1, 2, assuming
/// `dict` starts empty). Then registers the word embedding table.
TestNetworkImpl::TestNetworkImpl()
{
  getOrAddDictValue(Config::String("_null_"));
  getOrAddDictValue(Config::String("_unknown_"));
  getOrAddDictValue(Config::String("_S_"));

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(maxNbEmbeddings, embeddingsSize));
}

/// Embeds the FORM of the words surrounding the current word of @p config.
///
/// The context window starts at config.getWordIndex(). It extends up to
/// `windowSize` lines to the left and up to `windowSize` lines to the right,
/// and stops early wherever config.has(0, index, 0) is false.
///
/// Each line's FORM (config.getLastNotEmptyConst("FORM", i)) is mapped to an
/// id with getOrAddDictValue. Forms not seen before are added to the
/// dictionary as a side effect.
///
/// @param config configuration whose current word is the window centre.
/// @return the embedding lookup of a {1, nbWords} index tensor, i.e. a tensor
///         of shape {1, nbWords, embeddingsSize}.
/// @throws whatever util::myThrow throws, if a line of the window is missing
///         or the window ends up empty.
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  constexpr int windowSize = 5;

  // Extend the window to the left while lines exist, up to windowSize words.
  int startIndex = config.getWordIndex();
  while (config.has(0, startIndex - 1, 0) and config.getWordIndex() - startIndex < windowSize)
    startIndex--;

  // Extend the window to the right while lines exist, up to windowSize words.
  int endIndex = config.getWordIndex();
  while (config.has(0, endIndex + 1, 0) and endIndex - config.getWordIndex() < windowSize)
    endIndex++;

  // Embedding indices must be int64 (at::kLong), so store them as such rather
  // than as std::size_t reinterpreted as signed.
  std::vector<std::int64_t> wordIds;
  if (endIndex >= startIndex)
    wordIds.reserve(static_cast<std::size_t>(endIndex - startIndex + 1));

  for (int i = startIndex; i <= endIndex; ++i)
  {
    if (!config.has(0, i, 0))
      util::myThrow(fmt::format("Config do not have line {}", i));

    wordIds.emplace_back(static_cast<std::int64_t>(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i))));
  }

  if (wordIds.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));

  // torch::from_blob neither copies nor owns wordIds' buffer. Autograd may
  // keep a reference to the index tensor for a backward pass that can run
  // after this function returns, so clone it into tensor-owned storage.
  const torch::Tensor indices = torch::from_blob(wordIds.data(), {1, static_cast<std::int64_t>(wordIds.size())}, at::kLong).clone();

  return wordEmbeddings(indices);
}

/// Returns the id associated with @p s in the dictionary.
///
/// If @p s is not present yet, it is inserted with the next free id, which is
/// the dictionary size before the insertion. An empty string maps to the id
/// of "_null_".
///
/// @param s string to look up (and insert if missing).
/// @return the id of @p s.
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  if (s.get().empty())
    return dict[Config::String("_null_")];

  auto found = dict.find(s);
  if (found != dict.end())
    return found->second;

  // Compute the id before inserting, so the new entry gets the pre-insertion
  // size under any language standard. (In `dict[s] = dict.size()` that order
  // is only guaranteed from C++17 on.)
  const std::size_t newId = dict.size();
  dict[s] = newId;
  return newId;
}