#include "RLTNetwork.hpp"

// RLT (recursive LSTM tree) network.
//
// Layout of one input row, as produced by extractContext() and consumed by
// forward(). W = leftBorder+rightBorder+1 is the size of the context window,
// F = focusedBufferIndexes.size()+focusedStackIndexes.size(), C = maxNbChilds:
//   [0, F)              context position of each focused element (buffer
//                       elements first, then stack elements), -1 if absent;
//   [F, F+W)            context positions in the order their subtrees must be
//                       computed (children before heads), padded with -1;
//   [F+W, F+W+W*C)      for every context position, the context positions of
//                       at most C children, padded with -1;
//   [F+W+W*C, ...)      dictionary indexes, W words x columns.size() columns,
//                       word-major.

/// Builds the layers and learned vectors of the network.
/// @param nbOutputs        size of the output layer.
/// @param leftBorder       number of tokens kept before the current word.
/// @param rightBorder      number of tokens kept after the current word.
/// @param nbStackElements  forwarded to setNbStackElements().
RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int embeddingsSize = 30;
  constexpr int lstmOutputSize = 128;
  constexpr int treeEmbeddingsSize = 256;
  constexpr int hiddenSize = 500;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

  // NOTE(review): the vocabulary size is fixed at 50000; a dictionary index at
  // or above this bound makes the embedding lookup fail.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  // The MLP input is one tree embedding per focused element.
  linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
  // One BiLSTM step per word; its input concatenates the embeddings of all columns.
  vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
  // Tree composition LSTM; each step receives [head word vector ; S or child tree].
  treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
  // S opens every tree sequence; nullTree replaces subtrees that are not available.
  S = register_parameter("S", torch::randn(treeEmbeddingsSize));
  nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
}

/// Scores a batch of rows laid out as described at the top of this file.
/// A 1-D input is treated as a batch of one row.
/// @return (batch, nbOutputs) scores.
torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  const int windowSize = leftBorder+rightBorder+1;
  auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
  auto computeOrder = input.narrow(1, focusedIndexes.size(1), windowSize);
  auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*windowSize);
  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
  auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*windowSize);

  // (batch, W*columns, emb) -> (batch, W, emb*columns): one vector per word.
  auto baseEmbeddings = wordEmbeddings(wordIndexes);
  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;

  // treeRepresentations[batch] maps a context position to the embedding of the
  // subtree rooted at that position.
  std::vector<std::map<int, torch::Tensor>> treeRepresentations;
  for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
  {
    treeRepresentations.emplace_back();
    auto & trees = treeRepresentations.back();
    for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
    {
      int index = computeOrder[batch][i].item<int>();
      if (index == -1)
        break;

      // Tree LSTM sequence: [head ; S], then [head ; child subtree] per child.
      std::vector<torch::Tensor> inputVector;
      inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
      for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
      {
        int child = childs[batch][index][childIndex].item<int>();
        if (child == -1)
          break;
        auto childTree = trees.find(child);
        inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], childTree != trees.end() ? childTree->second : nullTree}, 0));
      }

      auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
      // The output of the last step is the embedding of the subtree.
      auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
      trees[index] = lstmOut;
    }
  }

  std::vector<torch::Tensor> focusedTrees;
  std::vector<torch::Tensor> representations;
  for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
  {
    auto & trees = treeRepresentations[batch];
    focusedTrees.clear();
    for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
    {
      int index = focusedIndexes[batch][i].item<int>();
      // -1 (absent element) is never a key, so it falls back to nullTree too.
      auto tree = trees.find(index);
      focusedTrees.emplace_back(tree != trees.end() ? tree->second : nullTree);
    }
    representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
  }

  auto representation = torch::cat(representations, 0);
  return linear2(torch::relu(linear1(representation)));
}

/// Serializes the current configuration into one input row for forward()
/// (layout described at the top of this file).
std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  const int windowSize = leftBorder+rightBorder+1;

  // Context window: up to leftBorder tokens before the current word, then the
  // current word and the following tokens, padded with -1 on both sides.
  std::vector<long> contextIndexes;
  std::stack<int> leftContext;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
    if (config.isToken(index))
      leftContext.push(index);

  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
    contextIndexes.emplace_back(-1);
  while (!leftContext.empty())
  {
    contextIndexes.emplace_back(leftContext.top());
    leftContext.pop();
  }

  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < windowSize; ++index)
    if (config.isToken(index))
      contextIndexes.emplace_back(index);

  while ((int)contextIndexes.size() < windowSize)
    contextIndexes.emplace_back(-1);

  // Position of every real word inside the window. Padding (-1) is not
  // mapped: mapping it with emplace(l, size()) silently shifted every later
  // position whenever several -1 preceded the first word, because emplace
  // fails on the duplicate key and size() stops matching the position.
  std::map<long, long> indexInContext;
  for (unsigned int position = 0; position < contextIndexes.size(); position++)
    if (contextIndexes[position] != -1)
      indexInContext.emplace(contextIndexes[position], position);

  // Head of each window word, or -1 when unknown or outside the window.
  std::vector<long> headOf;
  for (auto & l : contextIndexes)
  {
    if (l == -1)
    {
      headOf.push_back(-1);
      continue;
    }
    auto & head = config.getAsFeature(Config::headColName, l);
    if (util::isEmpty(head) or head == "_")
    {
      headOf.push_back(-1);
      continue;
    }
    long headIndex = std::stoi(head);
    headOf.push_back(indexInContext.count(headIndex) ? headIndex : -1);
  }

  // Children (word indexes) of every window position, in window order.
  std::vector<std::vector<long>> childs(headOf.size());
  for (unsigned int i = 0; i < headOf.size(); i++)
    if (headOf[i] != -1)
      childs[indexInContext.at(headOf[i])].push_back(contextIndexes[i]);

  // Post-order traversal from the focused elements: a subtree's position is
  // appended only after the positions of all its children.
  std::vector<long> treeComputationOrder;
  std::vector<bool> treeIsComputed(contextIndexes.size(), false);
  std::function<void(long)> depthFirst;
  depthFirst = [&depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
  {
    auto rootPosition = indexInContext.find(root);
    if (rootPosition == indexInContext.end())
      return;
    long position = rootPosition->second;
    if (treeIsComputed[position])
      return;
    // Mark before descending so that a self-headed word or a head cycle cannot
    // recurse forever (identical result for well-formed trees).
    treeIsComputed[position] = true;
    for (auto child : childs[position])
      depthFirst(child);
    treeComputationOrder.push_back(position);
  };

  for (auto & l : focusedBufferIndexes)
  {
    int position = leftBorder+l;
    if (position >= 0 && position < windowSize && contextIndexes[position] != -1)
      depthFirst(contextIndexes[position]);
  }
  for (auto & l : focusedStackIndexes)
    if (config.hasStack(l))
      depthFirst(config.getStack(l));

  // Serialize: focused positions, computation order, children, word columns.
  std::vector<long> context;
  for (auto & c : focusedBufferIndexes)
    context.push_back(leftBorder+c);
  for (auto & c : focusedStackIndexes)
    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
      context.push_back(indexInContext.at(config.getStack(c)));
    else
      context.push_back(-1);

  for (auto & c : treeComputationOrder)
    context.push_back(c);
  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
    context.push_back(-1);

  for (auto & c : childs)
    for (unsigned int i = 0; i < maxNbChilds; i++)
      if (i < c.size())
        context.push_back(indexInContext.at(c[i]));
      else
        context.push_back(-1);

  for (auto & l : contextIndexes)
    for (auto & col : columns)
      if (l == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));

  return {context};
}