diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 1d81309a5ffc34c8664784d68775b68e78c832ac..44a8edc0f8cbc8d69ced8c273f63e9824af35ecf 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -25,10 +25,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool config.printForDebug(stderr); auto dictState = machine.getDict(config.getState()).getState(); - auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())); + auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); machine.getDict(config.getState()).setState(dictState); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 2edac4993051841372c293c07d55a6aeee56088c..0cd54b8f087034863e1fd8dbb2e07089be221a56 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -33,7 +33,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; - std::vector<long> extractContext(Config & config, Dict & dict) const override; + std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; #endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 1ca0919cc3118a2ef5b01c0a466c97ed3c3bd6a5..34bf14b632cd912e1d0743fc667a14ed49e667c2 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -27,7 +27,7 @@ class NeuralNetworkImpl : public torch::nn::Module public : virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual std::vector<long> extractContext(Config & config, Dict & dict) const; + virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const; std::vector<long> extractContextIndexes(const Config & config) const; int getContextSize() const; void setColumns(const std::vector<std::string> & columns); diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp index 7d350b38fb36a0b31b55ab89b335eb6de62c4124..b996def57a5e540d738bd9db0874a5d8511d2983 100644 --- a/torch_modules/include/RLTNetwork.hpp +++ b/torch_modules/include/RLTNetwork.hpp @@ -23,7 +23,7 @@ class RLTNetworkImpl : public NeuralNetworkImpl RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; - std::vector<long> extractContext(Config & config, Dict & dict) const override; + std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; #endif diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 5e9696eba7062b67c1b36dccb4dc29dd1fb8f7c5..9f9d6a1598a33f0be2f321dc00542b19562f2c58 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -74,118 +74,129 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } -std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const { if (dict.size() >= maxNbEmbeddings) util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); std::vector<long> contextIndexes = extractContextIndexes(config); - std::vector<long> context; + std::vector<std::vector<long>> context; + context.emplace_back(); if (rawInputSize > 0) { for (int i = 0; i < leftWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); else - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); for (int i = 0; i <= rightWindowRawInput; i++) if (config.hasCharacter(config.getCharacterIndex()+i)) - - context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); + context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i)))); else - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); } for (auto index : contextIndexes) for (auto & col : columns) if (index == -1) - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); else { int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - if (col == "FORM" || col == "LEMMA") - if (dict.getNbOccs(dictIndex) < unknownValueThreshold) - dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr); - context.push_back(dictIndex); - } + for (auto & contextElement : context) + contextElement.push_back(dictIndex); - for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) - { - auto & col = focusedColumns[colIndex]; + if (is_training()) + if (col == "FORM" || col == "LEMMA") + if (dict.getNbOccs(dictIndex) < unknownValueThreshold) + { + context.emplace_back(context.back()); + context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); + } + } - std::vector<int> focusedIndexes; - for (auto relIndex : focusedBufferIndexes) + for (auto & contextElement : context) + for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) { - int index = relIndex + leftBorder; - if (index < 0 || index >= (int)contextIndexes.size()) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(contextIndexes[index]); - } - for (auto index : focusedStackIndexes) - { - if (!config.hasStack(index)) - focusedIndexes.push_back(-1); - else if (!config.has(col, config.getStack(index), 0)) - focusedIndexes.push_back(-1); - else - focusedIndexes.push_back(config.getStack(index)); - } + auto & col = focusedColumns[colIndex]; - for (auto index : focusedIndexes) - { - if (index == -1) + std::vector<int> focusedIndexes; + for (auto relIndex : focusedBufferIndexes) { - for (int i = 0; i < maxNbElements[colIndex]; i++) - context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); - continue; + int index = relIndex + leftBorder; + if (index < 0 || index >= (int)contextIndexes.size()) + focusedIndexes.push_back(-1); + else + focusedIndexes.push_back(contextIndexes[index]); } - - std::vector<std::string> elements; - if (col == "FORM") + for (auto index : focusedStackIndexes) { - auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); - else - elements.emplace_back(Dict::nullValueStr); + if (!config.hasStack(index)) + focusedIndexes.push_back(-1); + else if (!config.has(col, config.getStack(index), 0)) + focusedIndexes.push_back(-1); + else + focusedIndexes.push_back(config.getStack(index)); } - else if (col == "FEATS") - { - auto splited = util::split(config.getAsFeature(col, index).get(), '|'); - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)splited.size()) - elements.emplace_back(fmt::format("FEATS({})", splited[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "ID") + for (auto index : focusedIndexes) { - if (config.isTokenPredicted(index)) - elements.emplace_back("ID(TOKEN)"); - else if (config.isMultiwordPredicted(index)) - elements.emplace_back("ID(MULTIWORD)"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("ID(EMPTYNODE)"); + if (index == -1) + { + for (int i = 0; i < maxNbElements[colIndex]; i++) + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + continue; + } + + std::vector<std::string> elements; + if (col == "FORM") + { + auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); + + for (int i = 0; i < maxNbElements[colIndex]; i++) + if (i < (int)asUtf8.size()) + elements.emplace_back(fmt::format("Letter({})", asUtf8[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (col == "FEATS") + { + auto splited = util::split(config.getAsFeature(col, index).get(), '|'); + + for (int i = 0; i < maxNbElements[colIndex]; i++) + if (i < (int)splited.size()) + elements.emplace_back(fmt::format("FEATS({})", splited[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (col == "ID") + { + if (config.isTokenPredicted(index)) + elements.emplace_back("ID(TOKEN)"); + else if (config.isMultiwordPredicted(index)) + elements.emplace_back("ID(MULTIWORD)"); + else if (config.isEmptyNodePredicted(index)) + elements.emplace_back("ID(EMPTYNODE)"); + } + else + { + elements.emplace_back(config.getAsFeature(col, index)); + } + + if ((int)elements.size() != maxNbElements[colIndex]) + util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); + + for (auto & element : elements) + contextElement.emplace_back(dict.getIndexOrInsert(element)); } - else - { - elements.emplace_back(config.getAsFeature(col, index)); - } - - if ((int)elements.size() != maxNbElements[colIndex]) - util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); - - for (auto & element : elements) - context.emplace_back(dict.getIndexOrInsert(element)); } - } + + if (!is_training() && context.size() > 1) + util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size())); return context; } diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 37c206cf1baf43d2d3d230529e88f1e7271a48ca..fef5519623d91e48bbac5dead3dcd575216c4121 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -35,7 +35,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config return context; } -std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const { std::vector<long> indexes = extractContextIndexes(config); std::vector<long> context; @@ -47,7 +47,7 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict else context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index))); - return context; + return {context}; } int NeuralNetworkImpl::getContextSize() const diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp index 85223e776bc2595a594c69fb2fa7abe9c1320b92..38fe64203cbdb0f9546e96cd1b6ac758265af364 100644 --- a/torch_modules/src/RLTNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -79,7 +79,7 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input) return linear2(torch::relu(linear1(representation))); } -std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const +std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const { std::vector<long> contextIndexes; std::stack<int> leftContext; @@ -183,6 +183,6 @@ std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) c else context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l))); - return context; + return {context}; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 29637014808715bfbbd8f0e546f1998ce61d53e0..501af8e8d87456e2fb8699210972ebf0a130c460 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -48,20 +48,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: util::myThrow("No transition appliable !"); } + std::vector<std::vector<long>> context; + try { - auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); - contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); + context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + for (auto & element : context) + contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device)); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); - auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)); + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); gold[0] = goldIndex; - classes.emplace_back(gold); + for (auto & element : context) + classes.emplace_back(gold); transition->apply(config); config.addToHistory(transition->getName());