#include "TestNetwork.hpp"
TestNetworkImpl::TestNetworkImpl()
{
  // Reserve the first dictionary slots for the special tokens. Insertion
  // order matters: it fixes the embedding row each special token maps to.
  for (auto specialToken : {"_null_", "_unknown_", "_S_"})
    getOrAddDictValue(Config::String(specialToken));
  // Embedding table: 200000 vocabulary entries, 100-dimensional vectors.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
}
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
  // Embeds the words inside a window of at most `windowSize` positions on
  // each side of the configuration's current word, clamped to the lines the
  // Config actually contains. Returns a {1, windowLength} batch of embeddings.
  constexpr int windowSize = 5;

  // Walk left from the current word while lines exist, at most windowSize steps.
  int startIndex = config.getWordIndex();
  while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize)
    startIndex--;

  // Walk right symmetrically.
  int endIndex = config.getWordIndex();
  while (config.has(0,endIndex+1,0) and endIndex-config.getWordIndex() < windowSize)
    endIndex++;

  // int64_t matches at::kLong's layout exactly; std::size_t is not guaranteed
  // to be 64-bit on every platform, which would corrupt the from_blob view.
  std::vector<int64_t> words;
  words.reserve(endIndex - startIndex + 1);
  for (int i = startIndex; i <= endIndex; ++i)
  {
    if (!config.has(0, i, 0))
      // fmt::format uses {} placeholders; the original "%d" was emitted
      // verbatim and the line number silently dropped.
      util::myThrow(fmt::format("Config does not have line {}", i));
    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
  }
  if (words.empty())
    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));
  // from_blob does not copy the buffer; `words` outlives the call, so the
  // non-owning view is safe here.
  return wordEmbeddings(torch::from_blob(words.data(), {1, static_cast<int64_t>(words.size())}, at::kLong));
}
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
  // Maps a token to its dictionary index, assigning the next free index to
  // previously unseen tokens. Empty tokens share the "_null_" entry, which
  // the constructor inserts first (so it is always present at index 0).
  if (s.get().empty())
    return dict[Config::String("_null_")];
  auto found = dict.find(s);
  if (found != dict.end())
    return found->second;
  // Compute the new index BEFORE inserting: the original
  // `dict[s] = dict.size()` has unspecified evaluation order prior to C++17 —
  // operator[] may default-insert the entry before size() is read, producing
  // an off-by-one index on some compilers.
  std::size_t newIndex = dict.size();
  dict.emplace(std::move(s), newIndex);
  return newIndex;
}