-
Franck Dary authored
RLTNetwork.cpp 7.32 KiB
#include "RLTNetwork.hpp"
/// Builds the network: word embeddings -> BiLSTM over the context window ->
/// unidirectional tree LSTM composing subtrees -> 2-layer MLP over the
/// focused trees.
/// @param nbOutputs        size of the final linear layer (number of scores produced)
/// @param leftBorder       number of context words kept before the current word
/// @param rightBorder      number of context words kept after the current word
/// @param nbStackElements  forwarded to setNbStackElements()
RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  // Architecture hyper-parameters.
  constexpr int wordEmbeddingDim = 30;
  constexpr int bilstmHiddenDim = 128;  // per direction; the BiLSTM is bidirectional
  constexpr int treeDim = 256;          // hidden size of the tree LSTM = size of one tree vector
  constexpr int mlpHiddenDim = 500;
  constexpr int vocabularySize = 50000; // NOTE(review): fixed table size; assumes dict indexes stay below it — confirm
  // Each tree-LSTM step consumes a BiLSTM word vector (2 * bilstmHiddenDim)
  // concatenated with a tree-sized vector.
  constexpr int treeLstmInputDim = treeDim + 2 * bilstmHiddenDim;

  // The setters run first: the sizes below read columns, focusedBufferIndexes
  // and focusedStackIndexes (presumably filled by these setters — confirm in the header).
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

  const auto nbFocusedTrees = focusedBufferIndexes.size() + focusedStackIndexes.size();
  const auto bilstmInputDim = wordEmbeddingDim * columns.size(); // one embedding per column, concatenated

  // Registration order is kept identical so the parameter ordering does not change.
  wordEmbeddings = register_module("word_embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, wordEmbeddingDim)));
  linear1 = register_module("linear1", torch::nn::Linear(treeDim * nbFocusedTrees, mlpHiddenDim));
  linear2 = register_module("linear2", torch::nn::Linear(mlpHiddenDim, nbOutputs));
  vectorBiLSTM = register_module("vector_lstm",
      torch::nn::LSTM(torch::nn::LSTMOptions(bilstmInputDim, bilstmHiddenDim).batch_first(true).bidirectional(true)));
  treeLSTM = register_module("tree_lstm",
      torch::nn::LSTM(torch::nn::LSTMOptions(treeLstmInputDim, treeDim).batch_first(true).bidirectional(false)));

  // Learned vectors used by forward(): S opens every tree-LSTM sequence,
  // nullTree stands in for any subtree that has no computed representation.
  S = register_parameter("S", torch::randn(treeDim));
  nullTree = register_parameter("null_tree", torch::randn(treeDim));
}
/// Scores a batch of encoded contexts.
/// @param input  integer tensor, one row per example (a 1-D row is treated as a
///               batch of one). Each row is the concatenation, in order, of:
///               focused positions, tree computation order (window size),
///               child positions (maxNbChilds per window slot) and dict indexes
///               (columns.size() per window slot); -1 marks padding/absence.
/// @return       output of linear2(relu(linear1(concatenated focused trees))).
torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // Slice the four segments of each row, advancing a running column offset.
  const auto windowSize = leftBorder+rightBorder+1;
  int64_t offset = 0;
  auto focusedIndexes = input.narrow(1, offset, focusedBufferIndexes.size()+focusedStackIndexes.size());
  offset += focusedIndexes.size(1);
  auto computeOrder = input.narrow(1, offset, windowSize);
  offset += computeOrder.size(1);
  auto childsFlat = input.narrow(1, offset, maxNbChilds*windowSize);
  offset += childsFlat.size(1);
  auto wordIndexes = input.narrow(1, offset, columns.size()*windowSize);
  // (batch, window slot, child slot)
  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});

  // Embed every (slot, column) index, then merge the columns of a slot into
  // one vector: (batch, slots*columns, emb) -> (batch, slots, emb*columns).
  auto embedded = wordEmbeddings(wordIndexes);
  const int nbColumns = static_cast<int>(columns.size());
  auto slotEmbeddings = torch::reshape(embedded,
      {embedded.size(0), static_cast<int>(embedded.size(1))/nbColumns, static_cast<int>(embedded.size(2))*nbColumns});
  auto vectorRepresentations = vectorBiLSTM(slotEmbeddings).output;

  // Representation of the subtree rooted at `position`, or the learned null
  // tree when that position has none (including the -1 "absent" marker).
  auto treeOrNull = [this](const std::map<int, torch::Tensor> & trees, int position) -> torch::Tensor
  {
    if (position == -1)
      return nullTree;
    auto found = trees.find(position);
    return found == trees.end() ? nullTree : found->second;
  };

  // Build subtree representations bottom-up, following the computation order
  // given in the input (children are expected to precede their head).
  std::vector<std::map<int, torch::Tensor>> treeRepresentations(computeOrder.size(0));
  for (int64_t b = 0; b < computeOrder.size(0); ++b)
  {
    auto & trees = treeRepresentations[b];
    const auto order = computeOrder[b];
    const auto rowChilds = childs[b];
    const auto rowVectors = vectorRepresentations[b];
    for (int64_t step = 0; step < order.size(0); ++step)
    {
      const int node = order[step].item<int>();
      if (node == -1)
        break; // the rest of the order is padding

      // Sequence fed to the tree LSTM: head vector with S, then head vector
      // with each child's subtree representation, in child-slot order.
      const auto headVector = rowVectors[node];
      std::vector<torch::Tensor> sequence;
      sequence.emplace_back(torch::cat({headVector, S}, 0));
      for (unsigned int slot = 0; slot < maxNbChilds; ++slot)
      {
        const int child = rowChilds[node][slot].item<int>();
        if (child == -1)
          break;
        sequence.emplace_back(torch::cat({headVector, treeOrNull(trees, child)}, 0));
      }

      // Keep the output of the last time step as the subtree vector.
      auto lstmInput = torch::stack(sequence, 0).unsqueeze(0);
      auto lastStep = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
      trees[node] = lastStep;
    }
  }

  // One row per example: concatenation of the focused subtree vectors.
  std::vector<torch::Tensor> rows;
  for (int64_t b = 0; b < focusedIndexes.size(0); ++b)
  {
    const auto focusedRow = focusedIndexes[b];
    std::vector<torch::Tensor> focusedTrees;
    for (int64_t slot = 0; slot < focusedRow.size(0); ++slot)
      focusedTrees.emplace_back(treeOrNull(treeRepresentations[b], focusedRow[slot].item<int>()));
    rows.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
  }

  auto representation = torch::cat(rows, 0);
  auto hidden = torch::relu(linear1(representation));
  return linear2(hidden);
}
/// Encodes the current configuration as a single integer row for forward().
///
/// The window ("context") holds leftBorder tokens before the current word,
/// the current word and the following tokens, padded with -1 up to
/// leftBorder+rightBorder+1 slots. The returned row is, in order:
///   1. for each focused buffer element: its window position (leftBorder+c);
///   2. for each focused stack element: its window position, or -1 if it is
///      absent or outside the window;
///   3. the post-order in which subtrees must be computed (window positions),
///      padded with -1 to the window size;
///   4. for every window slot: up to maxNbChilds child window positions, -1 padded;
///   5. for every window slot and every column: the dict index of the value
///      (Dict::nullValueStr for padding slots).
/// @param config  configuration to encode (read only through its accessors)
/// @param dict    dictionary; unknown values are inserted via getIndexOrInsert
/// @return        a vector containing exactly one row
std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  const auto contextSize = leftBorder+rightBorder+1;

  // Window word indexes (-1 = padding): left context, then current word and right context.
  std::vector<long> contextIndexes;
  std::stack<int> leftContext;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
    if (config.isToken(index))
      leftContext.push(index);
  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
    contextIndexes.emplace_back(-1);
  while (!leftContext.empty())
  {
    contextIndexes.emplace_back(leftContext.top());
    leftContext.pop();
  }
  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < contextSize; ++index)
    if (config.isToken(index))
      contextIndexes.emplace_back(index);
  while ((int)contextIndexes.size() < contextSize)
    contextIndexes.emplace_back(-1);

  // Word index -> window position, for real words only. The position must come
  // from the loop counter, not from the map size: padding (-1) slots repeat and
  // are not keys, so the map size does not track the position.
  std::map<long, long> indexInContext;
  for (long position = 0; position < static_cast<long>(contextIndexes.size()); ++position)
    if (contextIndexes[position] != -1)
      indexInContext.emplace(contextIndexes[position], position);

  // Head (word index) of each slot, kept only when the head is inside the window.
  std::vector<long> headOf;
  headOf.reserve(contextIndexes.size());
  for (auto l : contextIndexes)
  {
    long headIndex = -1;
    if (l != -1)
    {
      const auto & head = config.getAsFeature(Config::headColName, l);
      if (!(util::isEmpty(head) or head == "_"))
      {
        const long parsedHead = std::stoi(head);
        if (indexInContext.count(parsedHead))
          headIndex = parsedHead;
      }
    }
    headOf.push_back(headIndex);
  }

  // Children (word indexes) of each slot, in left-to-right window order.
  std::vector<std::vector<long>> childs(headOf.size());
  for (unsigned int i = 0; i < headOf.size(); i++)
    if (headOf[i] != -1)
      childs[indexInContext.at(headOf[i])].push_back(contextIndexes[i]);

  // Post-order traversal from each focused root: every subtree is listed after
  // all of its children, and each window position at most once.
  std::vector<long> treeComputationOrder;
  std::vector<bool> visited(contextIndexes.size(), false);
  std::function<void(long)> depthFirst;
  depthFirst = [&depthFirst, &indexInContext, &treeComputationOrder, &visited, &childs](long root)
  {
    auto rootEntry = indexInContext.find(root);
    if (rootEntry == indexInContext.end())
      return;
    const long position = rootEntry->second;
    if (visited[position])
      return;
    // Marked before visiting the children so that a cyclic or self-referencing
    // head annotation cannot recurse without bound.
    visited[position] = true;
    for (auto child : childs[position])
      depthFirst(child);
    treeComputationOrder.push_back(position);
  };
  for (auto & l : focusedBufferIndexes)
    if (contextIndexes[leftBorder+l] != -1)
      depthFirst(contextIndexes[leftBorder+l]);
  for (auto & l : focusedStackIndexes)
    if (config.hasStack(l))
      depthFirst(config.getStack(l));

  std::vector<long> context;
  // 1. focused buffer positions
  for (auto & c : focusedBufferIndexes)
    context.push_back(leftBorder+c);
  // 2. focused stack positions (-1 when absent or outside the window)
  for (auto & c : focusedStackIndexes)
    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
      context.push_back(indexInContext.at(config.getStack(c)));
    else
      context.push_back(-1);
  // 3. tree computation order, padded to the window size
  context.insert(context.end(), treeComputationOrder.begin(), treeComputationOrder.end());
  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
    context.push_back(-1);
  // 4. child positions, maxNbChilds per slot (extra children are dropped)
  for (auto & c : childs)
    for (unsigned int i = 0; i < maxNbChilds; i++)
      context.push_back(i < c.size() ? indexInContext.at(c[i]) : -1);
  // 5. dict indexes of every column of every slot
  for (auto & l : contextIndexes)
    for (auto & col : columns)
      if (l == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));

  return {context};
}