TestNetworkImpl::TestNetworkImpl(int nbOutputs)
{
  // Seed the dictionary so the special tokens get the reserved indices 0-2.
  getOrAddDictValue(Config::String("_null_"));
  getOrAddDictValue(Config::String("_unknown_"));
  getOrAddDictValue(Config::String("_S_"));
  // 200000-word vocabulary, 100-dimensional embeddings.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
  linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
}
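// Hypothetical usage sketch (not from the source): assuming the usual
// TORCH_MODULE(TestNetwork) holder accompanies this Impl class, the module
// would be created and invoked along these lines:
//
//   TestNetwork net(/*nbOutputs=*/10);
//   torch::Tensor probs = net->forward(config); // shape {10}, sums to 1
//   auto predicted = probs.argmax(0).item<int64_t>();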
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  constexpr int windowSize = 5;
  int wordIndex = config.getWordIndex();
  // Grow the context window left and right of the current word, stopping at
  // the edges of the Config or after windowSize words on either side.
  int startIndex = wordIndex;
  while (config.has(0, startIndex-1, 0) and wordIndex-startIndex < windowSize)
    startIndex--;
  int endIndex = wordIndex;
  while (config.has(0, endIndex+1, 0) and endIndex-wordIndex < windowSize)
    endIndex++;
  // Look up the dictionary index of each FORM in the window. int64_t matches
  // at::kLong in the from_blob call below; std::size_t is not guaranteed to.
  std::vector<int64_t> words;
  for (int i = startIndex; i <= endIndex; ++i)
  {
    if (!config.has(0, i, 0))
      util::myThrow(fmt::format("Config does not have line {}", i));
    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
  }
  if (words.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
  // from_blob wraps the vector's memory without copying; words only needs to
  // stay alive for the duration of this call, which it does.
  auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
  // Classify the embedding of the focus word only.
  return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
}
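// Worked example of the windowing above (illustrative values, not from the
// source): with windowSize == 5 and wordIndex == 7 in a 20-line Config,
// startIndex walks back to 2 and endIndex forward to 12, so `words` holds
// the 11 surrounding FORMs and wordsAsEmb[wordIndex-startIndex], i.e.
// wordsAsEmb[5], is the embedding of the focus word that gets classified.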
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  if (s.get().empty())
    return dict[Config::String("_null_")];
  auto found = dict.find(s);
  if (found == dict.end())
    return dict[s] = dict.size();
  return found->second;
}
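// Note on `dict[s] = dict.size()` above: since C++17 the right-hand side of
// an assignment is sequenced before the left, so dict.size() is read before
// operator[] inserts the new key and each new string gets the next
// consecutive index. Pre-C++17 this order was unspecified. Minimal
// illustration, assuming dict is a std::map-like container:
//
//   std::map<std::string, std::size_t> d;
//   d["a"] = d.size(); // d["a"] == 0
//   d["b"] = d.size(); // d["b"] == 1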