From 28045459af02c15661908a8588fc0eff576ff17b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 17 Apr 2020 19:35:13 +0200
Subject: [PATCH] split unknown only when extracting train dataset

---
 decoder/src/Decoder.cpp                           |  1 +
 reading_machine/include/ReadingMachine.hpp        |  1 +
 reading_machine/src/ReadingMachine.cpp            |  5 +++++
 torch_modules/include/ContextLSTM.hpp             |  2 +-
 torch_modules/include/DepthLayerTreeEmbedding.hpp |  2 +-
 torch_modules/include/FocusedColumnLSTM.hpp       |  2 +-
 torch_modules/include/NeuralNetwork.hpp           |  6 ++++++
 torch_modules/include/RawInputLSTM.hpp            |  2 +-
 torch_modules/include/SplitTransLSTM.hpp          |  2 +-
 torch_modules/include/Submodule.hpp               |  2 +-
 torch_modules/src/ContextLSTM.cpp                 |  7 +++++--
 torch_modules/src/DepthLayerTreeEmbedding.cpp     |  2 +-
 torch_modules/src/FocusedColumnLSTM.cpp           |  2 +-
 torch_modules/src/LSTMNetwork.cpp                 | 14 +++++++-------
 torch_modules/src/NeuralNetwork.cpp               | 10 ++++++++++
 torch_modules/src/RawInputLSTM.cpp                |  2 +-
 torch_modules/src/SplitTransLSTM.cpp              |  2 +-
 trainer/src/Trainer.cpp                           |  2 ++
 18 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 9b6b3a6..33c4837 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 9eb09d0..5f3ff1c 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,6 +47,7 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void splitUnknown(bool splitUnknown);
   void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 38f79c8..138c249 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
   classifier->getNN()->train(isTrainMode);
 }
 
+void ReadingMachine::splitUnknown(bool splitUnknown)
+{
+  classifier->getNN()->setSplitUnknown(splitUnknown);
+}
+
 void ReadingMachine::setDictsState(Dict::State state)
 {
   for (auto & it : dicts)
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp
index 136029c..3e3bbac 100644
--- a/torch_modules/include/ContextLSTM.hpp
+++ b/torch_modules/include/ContextLSTM.hpp
@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 
 TORCH_MODULE(ContextLSTM);
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
index 436a082..2a8f7e8 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 
 TORCH_MODULE(DepthLayerTreeEmbedding);
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp
index 6ea836a..fd5d915 100644
--- a/torch_modules/include/FocusedColumnLSTM.hpp
+++ b/torch_modules/include/FocusedColumnLSTM.hpp
@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 
 TORCH_MODULE(FocusedColumnLSTM);
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index be25c87..1237f09 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   static torch::Device device;
 
+  private :
+
+  bool splitUnknown{false};
+
   protected :
 
   static constexpr int maxNbEmbeddings = 150000;
@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
+  bool mustSplitUnknown() const;
+  void setSplitUnknown(bool splitUnknown);
 };
 
 TORCH_MODULE(NeuralNetwork);
diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp
index db17d6f..0e08560 100644
--- a/torch_modules/include/RawInputLSTM.hpp
+++ b/torch_modules/include/RawInputLSTM.hpp
@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 
 TORCH_MODULE(RawInputLSTM);
diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp
index f90c0ed..85d542c 100644
--- a/torch_modules/include/SplitTransLSTM.hpp
+++ b/torch_modules/include/SplitTransLSTM.hpp
@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 
 TORCH_MODULE(SplitTransLSTM);
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 437bbfa..cc38101 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -15,7 +15,7 @@ class Submodule
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
 };
 
 #endif
diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp
index 95daa69..5da02e7 100644
--- a/torch_modules/src/ContextLSTM.cpp
+++ b/torch_modules/src/ContextLSTM.cpp
@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
 
-void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
 {
   std::vector<long> contextIndexes;
 
@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
+      {
         for (auto & contextElement : context)
           contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
@@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
-        if (is_training())
+
+        if (splitUnknown)
           for (auto & targetCol : unknownValueColumns)
             if (col == targetCol)
               if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index 6e1342a..b506f92 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp
index 4e0da0e..e39af63 100644
--- a/torch_modules/src/FocusedColumnLSTM.cpp
+++ b/torch_modules/src/FocusedColumnLSTM.cpp
@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index cfa004e..a4f5863 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
 
-  contextLSTM->addToContext(context, dict, config);
+  contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!rawInputLSTM.is_empty())
-    rawInputLSTM->addToContext(context, dict, config);
+    rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!treeEmbedding.is_empty())
-    treeEmbedding->addToContext(context, dict, config);
+    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
 
-  splitTransLSTM->addToContext(context, dict, config);
+  splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   for (auto & lstm : focusedLstms)
-    lstm->addToContext(context, dict, config);
+    lstm->addToContext(context, dict, config, mustSplitUnknown());
 
-  if (!is_training() && context.size() > 1)
-    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+  if (!mustSplitUnknown() && context.size() > 1)
+    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
 
   return context;
 }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 02e8a19..235c677 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,3 +2,13 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
+bool NeuralNetworkImpl::mustSplitUnknown() const
+{
+  return splitUnknown;
+}
+
+void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
+{
+  this->splitUnknown = splitUnknown;
+}
+
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp
index 2aa8cfd..c6da426 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputLSTM.cpp
@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
index a83894a..99a1b35 100644
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 03b7f88..95e98eb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(true);
   machine.setDictsState(Dict::State::Open);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
-- 
GitLab