diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 9b6b3a67033013989b31b560dab167de3fcc08eb..33c483783f9c02060a40057dba26597735c34dac 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool { torch::AutoGradMode useGrad(false); machine.trainMode(false); + machine.splitUnknown(false); machine.setDictsState(Dict::State::Closed); machine.getStrategy().reset(); config.addPredicted(machine.getPredicted()); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 9eb09d038a853625dcbb0b649f02556a06eea94c..5f3ff1c6449e98f666a007f84f4b2b1b4d673726 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -47,6 +47,7 @@ class ReadingMachine bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; void trainMode(bool isTrainMode); + void splitUnknown(bool splitUnknown); void setDictsState(Dict::State state); void saveBest() const; void saveLast() const; diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 38f79c84c0a1e14e7adffe168486d47d1674a944..138c2494791fda62ddf933589d39c6b297260e6e 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode) classifier->getNN()->train(isTrainMode); } +void ReadingMachine::splitUnknown(bool splitUnknown) +{ + classifier->getNN()->setSplitUnknown(splitUnknown); +} + void ReadingMachine::setDictsState(Dict::State state) { for (auto & it : dicts) diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp index 136029cd33d2c3a1825ce23d8902c860242eecf8..3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d 100644 --- a/torch_modules/include/ContextLSTM.hpp +++ b/torch_modules/include/ContextLSTM.hpp @@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; TORCH_MODULE(ContextLSTM); diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp index 436a082a06121a2c62f50da0b5f5ef4b79b99ba8..2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef 100644 --- a/torch_modules/include/DepthLayerTreeEmbedding.hpp +++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp @@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; TORCH_MODULE(DepthLayerTreeEmbedding); diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp index 6ea836a041017fb1fdf6725506b6eb1f561bdb99..fd5d915df6d42d24294e6a75dd42c87d6e81dec1 100644 --- a/torch_modules/include/FocusedColumnLSTM.hpp +++ b/torch_modules/include/FocusedColumnLSTM.hpp @@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; TORCH_MODULE(FocusedColumnLSTM); diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index be25c873978d61edc57ee014695e2110b8cb189b..1237f09e15989dc6534e150e9ca03cfe983f797b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module static torch::Device device; + private : + + bool splitUnknown{false}; + protected : static constexpr int maxNbEmbeddings = 150000; @@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; + bool mustSplitUnknown() const; + void setSplitUnknown(bool splitUnknown); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp index db17d6f0014e615474a462635844e3d5251f3fb0..0e08560836b735f181849571ff0beec8f02bc335 100644 --- a/torch_modules/include/RawInputLSTM.hpp +++ b/torch_modules/include/RawInputLSTM.hpp @@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; TORCH_MODULE(RawInputLSTM); diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp index f90c0edcfedd699d7d9d18626db8d80cd40d385e..85d542ce8510bd0c1d11b2ca6c1f280aeb386d55 100644 --- a/torch_modules/include/SplitTransLSTM.hpp +++ b/torch_modules/include/SplitTransLSTM.hpp @@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; TORCH_MODULE(SplitTransLSTM); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 437bbfa4e82ca29fb35b33924ba1fc3c16cbb126..cc381013aea518aeefe8422b36537283d5d0da94 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -15,7 +15,7 @@ class Submodule void setFirstInputIndex(std::size_t firstInputIndex); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; - virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0; + virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0; }; #endif diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp index 95daa696df15c70797332f0d938c5646111e418d..5da02e729e4425d366ff9c1220c43ae477f5c926 100644 --- a/torch_modules/src/ContextLSTM.cpp +++ b/torch_modules/src/ContextLSTM.cpp @@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize() return columns.size()*(bufferContext.size()+stackContext.size()); } -void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const { std::vector<long> contextIndexes; @@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic for (auto index : contextIndexes) for (auto & col : columns) if (index == -1) + { for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } else { int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); @@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic for (auto & contextElement : context) contextElement.push_back(dictIndex); - if (is_training()) + + if (splitUnknown) for (auto & targetCol : unknownValueColumns) if (col == targetCol) if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index 6e1342a0270a0dea0fcc4d41bef758a050709595..b506f9219fd8284094960907975294fbd3a5b28a 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize() return inputSize; } -void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const { std::vector<long> focusedIndexes; diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp index 4e0da0ebb99e7aab3abe9f18b89301df0a448591..e39af636c817fdc1677cfd9131b85ec7fb1bd3ba 100644 --- a/torch_modules/src/FocusedColumnLSTM.cpp +++ b/torch_modules/src/FocusedColumnLSTM.cpp @@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize() return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; } -void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const { std::vector<long> focusedIndexes; diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index cfa004e58e66b3162231fd4046b499ac62dc5636..a4f5863ce39dab0fd51730e9edb52a8a0a43e2a2 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, context.back().emplace_back(dict.getIndexOrInsert(config.getState())); - contextLSTM->addToContext(context, dict, config); + contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); if (!rawInputLSTM.is_empty()) - rawInputLSTM->addToContext(context, dict, config); + rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); if (!treeEmbedding.is_empty()) - treeEmbedding->addToContext(context, dict, config); + treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); - splitTransLSTM->addToContext(context, dict, config); + splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); for (auto & lstm : focusedLstms) - lstm->addToContext(context, dict, config); + lstm->addToContext(context, dict, config, mustSplitUnknown()); - if (!is_training() && context.size() > 1) - util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size())); + if (!mustSplitUnknown() && context.size() > 1) + util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); return context; } diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 02e8a191bfb4b2bc718b6e815a266bec252fb24b..235c67793305280d0e09a3f1d45593fa727d13a3 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -2,3 +2,13 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); +bool NeuralNetworkImpl::mustSplitUnknown() const +{ + return splitUnknown; +} + +void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown) +{ + this->splitUnknown = splitUnknown; +} + diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp index 2aa8cfd9c2f97c9fd99d6ed08d3e2a0aa75c35d8..c6da426a7807b90bfd52eaf06abe7599c4c517c3 100644 --- a/torch_modules/src/RawInputLSTM.cpp +++ b/torch_modules/src/RawInputLSTM.cpp @@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize() return leftWindow + rightWindow + 1; } -void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const { if (leftWindow < 0 or rightWindow < 0) return; diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp index a83894abf93efda1ed0124fe58057e3ce042be06..99a1b35650e0b60c8c34c22f0a863d1ab1f8c990 100644 --- a/torch_modules/src/SplitTransLSTM.cpp +++ b/torch_modules/src/SplitTransLSTM.cpp @@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize() return maxNbTrans; } -void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const { auto & splitTransitions = config.getAppliableSplitTransitions(); for (auto & contextElement : context) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 03b7f880742df3dd737f555fd96ec1490a6c5a7c..95e98eba3e8162e01a71d8a6a5373a0d96293d30 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); + machine.splitUnknown(true); machine.setDictsState(Dict::State::Open); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); @@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); + machine.splitUnknown(false); machine.setDictsState(Dict::State::Closed); extractExamples(config, debug, dir, epoch, dynamicOracleInterval);