#include "ConcatWordsNetwork.hpp"
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
{
constexpr int embeddingsSize = 30;
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
auto params = wordEmbeddings->parameters();
_sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
params = linear->parameters();
_denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
}

std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
{
  return _denseParameters;
}

std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // input holds word indices, dim = {batch, sequence};
  // wordsAsEmb dim = {batch, sequence, embeddingsSize}
  auto wordsAsEmb = wordEmbeddings(input);

  // Concatenate the context embeddings into one vector per example:
  // {batch, sequence*embeddingsSize}, or just {sequence*embeddingsSize}
  // for a single unbatched example (wordsAsEmb.dim() == 2).
  auto reshaped = wordsAsEmb.dim() == 3 ?
    torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) :
    torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});

  // Softmax over the output classes: dimension 1 when batched, 0 otherwise.
  auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0);

  return res;
}
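
For context, here is a minimal training sketch showing why the parameters are split into dense and sparse lists: torch::optim::Adam cannot consume the sparse gradients produced by the embedding layer, so the embeddings go to torch::optim::SparseAdam while the linear layer uses a regular Adam. This is not from the original code; it assumes the header declares the usual TORCH_MODULE(ConcatWordsNetwork) wrapper, that getContextSize() returns 5, and the learning rates, class count, and random data are placeholders.

#include <torch/torch.h>
#include "ConcatWordsNetwork.hpp"

int main()
{
  int nbOutputs = 10; // hypothetical number of output classes
  ConcatWordsNetwork network(nbOutputs);

  // One optimizer per parameter family: Adam for the dense linear layer,
  // SparseAdam for the sparse embedding table.
  torch::optim::Adam denseOptimizer(network->denseParameters(),
    torch::optim::AdamOptions(2e-3));
  torch::optim::SparseAdam sparseOptimizer(network->sparseParameters(),
    torch::optim::SparseAdamOptions(2e-3));

  // A batch of 2 examples, each a context of 5 word indices (assumed context size).
  auto input = torch::randint(0, 200000, {2, 5}, torch::kLong);
  auto gold = torch::randint(0, nbOutputs, {2}, torch::kLong);

  // forward() already applies softmax, so we take the log before nll_loss,
  // which expects log-probabilities.
  auto output = network->forward(input);
  auto loss = torch::nll_loss(torch::log(output), gold);

  denseOptimizer.zero_grad();
  sparseOptimizer.zero_grad();
  loss.backward();
  denseOptimizer.step();
  sparseOptimizer.step();

  return 0;
}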