diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 0fc17c65fa325fe2263640f362dbf821567719c7..fd8d1bca540f96c15c9a7ead5306ea9e4e87ac50 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -114,7 +114,7 @@ void Config::print(FILE * dest) const
 
 void Config::printForDebug(FILE * dest) const
 {
-  static constexpr int windowSize = 5;
+  static constexpr int windowSize = 10;
   static constexpr int lettersWindowSize = 40;
   static constexpr int maxWordLength = 7;
 
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 5c8493e562ed9ea2bf50df7bb6fd5ce2983a391d..cdc55145bd7f20cb593e9e9fc82afd02a0564b9c 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -7,7 +7,7 @@
 
 class NeuralNetworkImpl : public torch::nn::Module
 {
-  private :
+  protected :
 
   int leftBorder{5};
   int rightBorder{5};
@@ -23,7 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  std::vector<long> extractContext(Config & config, Dict & dict) const;
+  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
 };
diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RTLSTMNetwork.hpp
index d30a6e62efe2f3fd76b58bc4c559a458292eeeb0..5d7692523f7661759314b1fb7f1c7a8d7dcbd0b0 100644
--- a/torch_modules/include/RTLSTMNetwork.hpp
+++ b/torch_modules/include/RTLSTMNetwork.hpp
@@ -7,16 +7,23 @@
 class RTLSTMNetworkImpl : public NeuralNetworkImpl
 {
   private :
+  static constexpr long maxNbChilds{8};
+  static inline std::vector<long> focusedBufferIndexes{0,1,2};
+  static inline std::vector<long> focusedStackIndexes{0,1};
+
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  torch::nn::Dropout dropout{nullptr};
-  torch::nn::LSTM lstm{nullptr};
+  torch::nn::LSTM vectorBiLSTM{nullptr};
+  torch::nn::LSTM treeLSTM{nullptr};
+  torch::Tensor S;
+  torch::Tensor nullTree;
 
   public :
 
   RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
+  std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp
index 6cc8f70bab61d2924ed6c3c022e63303e5319fc0..b59892c7770404e64b63ea653b600088b4aa1f33 100644
--- a/torch_modules/src/RTLSTMNetwork.cpp
+++ b/torch_modules/src/RTLSTMNetwork.cpp
@@ -3,31 +3,191 @@
 RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
 {
   constexpr int embeddingsSize = 30;
-  constexpr int lstmOutputSize = 500;
+  constexpr int lstmOutputSize = 128;
+  constexpr int treeEmbeddingsSize = 256;
   constexpr int hiddenSize = 500;
+
   setLeftBorder(leftBorder);
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
   setColumns({"FORM", "UPOS"});
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize));
+  linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-  dropout = register_module("dropout", torch::nn::Dropout(0.3));
-  lstm = register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, lstmOutputSize).batch_first(true)));
+  vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
+  // treeLSTM input size: bidirectional word representation (2*lstmOutputSize) concatenated with one tree embedding.
+  treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
+  S = register_parameter("S", torch::randn(treeEmbeddingsSize));
+  nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
 }
 
+// Scores a single (unbatched) flat context as produced by extractContext, laid out as
+// [focused positions | tree computation order | children table | word feature indexes].
+// Throws through util::myThrow when the squeezed input is not one-dimensional.
 torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
 {
-  // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
-  if (wordsAsEmb.dim() == 2)
-    wordsAsEmb = torch::unsqueeze(wordsAsEmb, 0);
-  auto lstmOut = lstm(wordsAsEmb).output;
-  // reshaped dim = {sequence, batch, embeddings}
-  auto reshaped = lstmOut.permute({1,0,2});
-  auto res = linear2(torch::relu(linear1(reshaped[-1])));
-
-  return res;
+  input = input.squeeze();
+  if (input.dim() != 1)
+    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
+
+  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
+  auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
+  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
+  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
+  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
+  auto baseEmbeddings = wordEmbeddings(wordIndexes);
+  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
+  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
+  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
+  // computeOrder is padded with -1 and lists children before their head (see extractContext).
+  for (unsigned int i = 0; i < computeOrder.size(0); i++)
+  {
+    int index = computeOrder[i].item<int>();
+    if (index == -1)
+      break;
+    std::vector<torch::Tensor> inputVector;
+    // Sequence fed to treeLSTM: the word paired with S first, then the word paired with each child tree.
+    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
+    for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+    {
+      int child = childs[index][childIndex].item<int>();
+      if (child == -1)
+        break;
+      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
+    }
+    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
+    auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
+    treeRepresentations[index] = lstmOut;
+  }
+
+  std::vector<torch::Tensor> focusedTrees;
+  for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
+  {
+    int index = focusedIndexes[i].item<int>();
+    if (index == -1)
+      focusedTrees.emplace_back(nullTree);
+    else
+      focusedTrees.emplace_back(treeRepresentations[index]);
+  }
+
+  auto representation = torch::cat(focusedTrees, 0);
+  return linear2(torch::relu(linear1(representation)));
+}
+
+// Flattens the configuration into the layout read back by forward: focused positions,
+// tree computation order, fixed-width children table, then dict indexes of each column.
+std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<long> contextIndexes;
+  // Gather up to leftBorder tokens preceding the current word; missing slots become -1.
+  std::stack<int> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (config.isToken(index))
+      leftContext.push(index);
+
+  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
+    contextIndexes.emplace_back(-1);
+  while (!leftContext.empty())
+  {
+    contextIndexes.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
+    if (config.isToken(index))
+      contextIndexes.emplace_back(index);
+
+  while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
+    contextIndexes.emplace_back(-1);
+
+  // Map each word index to its slot in contextIndexes. The slot must come from the loop
+  // counter: emplace ignores duplicate keys, so deriving it from the map size drifts as
+  // soon as several -1 padding entries precede the first word.
+  std::map<long, long> indexInContext;
+  for (unsigned int i = 0; i < contextIndexes.size(); i++)
+    if (contextIndexes[i] != -1)
+      indexInContext.emplace(contextIndexes[i], i);
+
+  std::vector<long> headOf;
+  for (auto & l : contextIndexes)
+  {
+    if (l == -1)
+      headOf.push_back(-1);
+    else
+    {
+      auto & head = config.getAsFeature(Config::headColName, l);
+      if (util::isEmpty(head) or head == "_")
+        headOf.push_back(-1);
+      else if (indexInContext.count(std::stoi(head)))
+        headOf.push_back(std::stoi(head));
+      else
+        headOf.push_back(-1);
+    }
+  }
+
+  std::vector<std::vector<long>> childs(headOf.size());
+  for (unsigned int i = 0; i < headOf.size(); i++)
+    if (headOf[i] != -1)
+      childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);
+
+  std::vector<long> treeComputationOrder;
+  std::vector<bool> treeIsComputed(contextIndexes.size(), false);
+
+  // Post-order traversal: every child subtree is scheduled before its head.
+  std::function<void(long)> depthFirst;
+  depthFirst = [&depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
+  {
+    if (!indexInContext.count(root))
+      return;
+
+    if (treeIsComputed[indexInContext[root]])
+      return;
+
+    for (auto child : childs[indexInContext[root]])
+      depthFirst(child);
+
+    treeIsComputed[indexInContext[root]] = true;
+    treeComputationOrder.push_back(indexInContext[root]);
+  };
+
+  for (auto & l : focusedBufferIndexes)
+    if (contextIndexes[leftBorder+l] != -1)
+      depthFirst(contextIndexes[leftBorder+l]);
+
+  for (auto & l : focusedStackIndexes)
+    if (config.hasStack(l))
+      depthFirst(config.getStack(l));
+
+  std::vector<long> context;
+
+  // Serialize in the order forward slices the input.
+  for (auto & c : focusedBufferIndexes)
+    context.push_back(leftBorder+c);
+  for (auto & c : focusedStackIndexes)
+    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
+      context.push_back(indexInContext[config.getStack(c)]);
+    else
+      context.push_back(-1);
+  for (auto & c : treeComputationOrder)
+    context.push_back(c);
+  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
+    context.push_back(-1);
+  for (auto & c : childs)
+  {
+    for (unsigned int i = 0; i < maxNbChilds; i++)
+      if (i < c.size())
+        context.push_back(indexInContext[c[i]]);
+      else
+        context.push_back(-1);
+  }
+  for (auto & l : contextIndexes)
+    for (auto & col : columns)
+      if (l == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
+
+  return context;
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 6e889171c5f1fa461f333605446fb8545d287270..5a8c30230b0f56b1aeb98b04879d40cf3d51ab20 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{50};
+  int batchSize{1};
   int nbExamples{0};
 
   public :
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 590504ef6fc8c75d31a6e597188c9648e099ee96..a74078cd8f6e7a567687e248b9f924b3e4b6535a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -63,7 +63,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
 float Trainer::epoch(bool printAdvancement)
 {
-  constexpr int printInterval = 2000;
+  constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
@@ -81,6 +81,8 @@ float Trainer::epoch(bool printAdvancement)
     auto labels = batch.target.squeeze();
 
     auto prediction = machine.getClassifier()->getNN()(data);
+    if (prediction.dim() == 1)
+      prediction = prediction.unsqueeze(0);
 
     labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
 