diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 8bc9d3a8d38ef5ca99de202ba4973b243f783a41..353c333aa8b9c8a76c0b55de28e9d35b7b5ceb24 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -48,6 +48,7 @@ class Dict void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const; std::size_t size() const; int getNbOccs(int index) const; + void removeRareElements(); }; #endif diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index b75457ed71a5db07ba47b3702521dc84178d81d3..4546702b80a1c2fddb76d75a0bc301039f3c6059 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -173,3 +173,24 @@ int Dict::getNbOccs(int index) const return nbOccs[index]; } +void Dict::removeRareElements() +{ + int minNbOcc = std::numeric_limits<int>::max(); + for (int nbOcc : nbOccs) + if (nbOcc < minNbOcc) + minNbOcc = nbOcc; + + std::unordered_map<std::string, int> newElementsToIndexes; + std::vector<int> newNbOccs; + + for (auto & it : elementsToIndexes) + if (nbOccs[it.second] > minNbOcc) + { + newElementsToIndexes.emplace(it.first, newElementsToIndexes.size()); + newNbOccs.emplace_back(nbOccs[it.second]); + } + + elementsToIndexes = newElementsToIndexes; + nbOccs = newNbOccs; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 352a9e8ad96542f7a9985a3a196d9b1f29d1b567..f91a10a589fc6217e68237612ff28b42d8805b67 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -9,7 +9,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool { torch::AutoGradMode useGrad(false); machine.trainMode(false); - machine.splitUnknown(false); machine.setDictsState(Dict::State::Closed); machine.getStrategy().reset(); config.addPredicted(machine.getPredicted()); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 5f3ff1c6449e98f666a007f84f4b2b1b4d673726..9eb09d038a853625dcbb0b649f02556a06eea94c 100644 --- 
a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -47,7 +47,6 @@ class ReadingMachine bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; void trainMode(bool isTrainMode); - void splitUnknown(bool splitUnknown); void setDictsState(Dict::State state); void saveBest() const; void saveLast() const; diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 138c2494791fda62ddf933589d39c6b297260e6e..0ff56500fc6f890654eb24d2c73a8bb89e2ee6ce 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -21,8 +21,13 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file { readFromFile(path); + std::size_t maxDictSize = 0; for (auto path : dicts) + { this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed}); + maxDictSize = std::max<std::size_t>(maxDictSize, this->dicts.at(path.stem().string()).size()); + } + classifier->getNN()->registerEmbeddings(maxDictSize); torch::load(classifier->getNN(), models[0]); } @@ -182,11 +187,6 @@ void ReadingMachine::trainMode(bool isTrainMode) classifier->getNN()->train(isTrainMode); } -void ReadingMachine::splitUnknown(bool splitUnknown) -{ - classifier->getNN()->setSplitUnknown(splitUnknown); -} - void ReadingMachine::setDictsState(Dict::State state) { for (auto & it : dicts) diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index a9b609034023857a245c985990cc37f1d06f2bb7..c48eb9f860578fda9cb02d75664c8d2d2602e17a 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -16,8 +16,7 @@ class ContextModuleImpl : public Submodule std::vector<std::string> columns; std::vector<int> bufferContext; std::vector<int> stackContext; - int unknownValueThreshold; - std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"}; + int 
inSize; public : @@ -25,7 +24,8 @@ class ContextModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 0d5cedda715558c166412d059bab28ce50d379c2..970e3bc535e2f30ce21fbb26c61d68fa7de28ee9 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -17,6 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule std::vector<int> focusedStack; torch::nn::Embedding wordEmbeddings{nullptr}; std::vector<std::shared_ptr<MyModule>> depthModules; + int inSize; public : @@ -24,7 +25,8 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index c105193b65c3f26fb5cfdd6e674910b70843feb0..f7814a08edcc60b52f5db3f5c0be0cc6aed0914a 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -16,6 +16,7 @@ class 
FocusedColumnModuleImpl : public Submodule std::vector<int> focusedBuffer, focusedStack; std::string column; int maxNbElements; + int inSize; public : @@ -23,7 +24,8 @@ class FocusedColumnModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 931aca8187ffaeff43871fea3413360c7e069ed8..08ace9018e1920acf96fb65e3deebd4eb57592db 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -27,6 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; + void registerEmbeddings(std::size_t nbElements) override; }; #endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 3db86517d391b5ee6384c6e9d8dba4e1b3810837..5372255039138d0e9a5b50bb796d9d466af3eafb 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -13,7 +13,6 @@ class NeuralNetworkImpl : public torch::nn::Module private : - bool splitUnknown{false}; std::string state; protected : @@ -24,8 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) 
const = 0; - bool mustSplitUnknown() const; - void setSplitUnknown(bool splitUnknown); + virtual void registerEmbeddings(std::size_t nbElements) = 0; void setState(const std::string & state); const std::string & getState() const; }; diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index f40c1b02a753999de3b7649d18d6e75765f5506d..b26c6f427354266a9210626f18efa25bc9ba93f2 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,6 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; + void registerEmbeddings(std::size_t nbElements) override; }; #endif diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index 4ded9154d4536031bc36e098ee47a15335f96054..02e1dd369cc439bc8e6b3ed79e6c761e2d34ad9f 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -14,6 +14,7 @@ class RawInputModuleImpl : public Submodule torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; int leftWindow, rightWindow; + int inSize; public : @@ -21,7 +22,8 @@ class RawInputModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 
24c68411e5b5bf3c232b561fc77a16030326a25b..f614588131531087af6a8c90216e5756391539ac 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -14,6 +14,7 @@ class SplitTransModuleImpl : public Submodule torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; int maxNbTrans; + int inSize; public : @@ -21,7 +22,8 @@ class SplitTransModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 77c1a4feb08628615d1d163369f0a9272970d475..849eb225061e9a29232763c46d7643968a1723d7 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -16,8 +16,9 @@ class Submodule : public torch::nn::Module void setFirstInputIndex(std::size_t firstInputIndex); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; - virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0; + virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; + virtual void registerEmbeddings(std::size_t nbElements) = 0; }; #endif diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index d2c1e6af087e2696960a6cdfefc03a9e3827f2b6..248da9387d4f3770484b6c11bc1b307e2e4a162c 100644 --- a/torch_modules/src/ContextModule.cpp +++ 
b/torch_modules/src/ContextModule.cpp @@ -2,23 +2,21 @@ ContextModuleImpl::ContextModuleImpl(const std::string & definition) { - std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try { - unknownValueThreshold = std::stoi(sm.str(1)); - - for (auto & index : util::split(sm.str(2), ' ')) + for (auto & index : util::split(sm.str(1), ' ')) bufferContext.emplace_back(std::stoi(index)); - for (auto & index : util::split(sm.str(3), ' ')) + for (auto & index : util::split(sm.str(2), ' ')) stackContext.emplace_back(std::stoi(index)); - columns = util::split(sm.str(4), ' '); + columns = util::split(sm.str(3), ' '); - auto subModuleType = sm.str(5); - auto subModuleArguments = util::split(sm.str(6), ' '); + auto subModuleType = sm.str(4); + auto subModuleArguments = util::split(sm.str(5), ' '); auto options = MyModule::ModuleOptions(true) .bidirectional(std::stoi(subModuleArguments[0])) @@ -26,10 +24,8 @@ ContextModuleImpl::ContextModuleImpl(const std::string & definition) .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - int inSize = std::stoi(sm.str(7)); - int outSize = std::stoi(sm.str(8)); - - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); + inSize = std::stoi(sm.str(6)); + int outSize = std::stoi(sm.str(7)); if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options)); @@ -53,7 +49,7 @@ std::size_t 
ContextModuleImpl::getInputSize() return columns.size()*(bufferContext.size()+stackContext.size()); } -void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const +void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { std::vector<long> contextIndexes; @@ -79,11 +75,6 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, D for (auto & contextElement : context) contextElement.push_back(dictIndex); - - for (auto & targetCol : unknownValueColumns) - if (col == targetCol) - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); } } @@ -96,3 +87,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } +void ContextModuleImpl::registerEmbeddings(std::size_t nbElements) +{ + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); +} + diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 7e13cdc2d1d4447f4b6141f2932dc981aed9a905..df9c2df62ebc87d324c6be4f5ff5cd4528776c96 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -27,11 +27,9 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std:: .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - int inSize = std::stoi(sm.str(7)); + inSize = std::stoi(sm.str(7)); int outSize = std::stoi(sm.str(8)); - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); - for (unsigned int i = 0; i < maxElemPerDepth.size(); i++) { std::string name = fmt::format("{}_{}", i, subModuleType); @@ -83,7 +81,7 @@ std::size_t 
DepthLayerTreeEmbeddingModuleImpl::getInputSize() return inputSize; } -void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const +void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { std::vector<long> focusedIndexes; @@ -122,3 +120,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::size_t nbElements) +{ + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); +} + diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 08cb9eb5aefc7589cebfac18aaebcfcf1204c69a..03cf9b65a6652e9bcd4c55d319f4a865772ee7ad 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -25,11 +25,9 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition) .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - int inSize = std::stoi(sm.str(7)); + inSize = std::stoi(sm.str(7)); int outSize = std::stoi(sm.str(8)); - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); - if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") @@ -61,7 +59,7 @@ std::size_t FocusedColumnModuleImpl::getInputSize() return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; } -void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const +void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { std::vector<long> focusedIndexes; @@ -134,3 +132,8 @@ 
void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } +void FocusedColumnModuleImpl::registerEmbeddings(std::size_t nbElements) +{ + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); +} + diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 47bf5b16d285451522eccdc2c5708ea4b937724b..13b7ca4936a3b1371387d35cabeb5fdb8d4b3795 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -79,7 +79,13 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi { std::vector<std::vector<long>> context(1); for (auto & mod : modules) - mod->addToContext(context, dict, config, mustSplitUnknown()); + mod->addToContext(context, dict, config); return context; } +void ModularNetworkImpl::registerEmbeddings(std::size_t nbElements) +{ + for (auto & mod : modules) + mod->registerEmbeddings(nbElements); +} + diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 987cfcb44c1e243348d6073da4252b7570658fa1..aa149fa00bf82210021569bf06da946bae6002c6 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -2,16 +2,6 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? 
torch::kCUDA : torch::kCPU); -bool NeuralNetworkImpl::mustSplitUnknown() const -{ - return splitUnknown; -} - -void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown) -{ - this->splitUnknown = splitUnknown; -} - void NeuralNetworkImpl::setState(const std::string & state) { this->state = state; diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 8dfafc2ab2b044523daf5764047f394e2699ff91..6622732208f5dfd84cb679e7f0266420046b5191 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -17,3 +17,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict return std::vector<std::vector<long>>{{0}}; } +void RandomNetworkImpl::registerEmbeddings(std::size_t) +{ +} + diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 9c5e5412bfaf929eaab7455d229208e77f4cc599..ac0f5e45b0e11cc85d8cc709cbeab1983e3a074c 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -19,11 +19,9 @@ RawInputModuleImpl::RawInputModuleImpl(const std::string & definition) .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - int inSize = std::stoi(sm.str(5)); + inSize = std::stoi(sm.str(5)); int outSize = std::stoi(sm.str(6)); - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); - if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") @@ -51,7 +49,7 @@ std::size_t RawInputModuleImpl::getInputSize() return leftWindow + rightWindow + 1; } -void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const +void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { if (leftWindow < 0 or rightWindow < 0) return; @@ -72,3 
+70,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } +void RawInputModuleImpl::registerEmbeddings(std::size_t nbElements) +{ + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); +} + diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index ab1276c10d9e16eedec4638129b0287f0001a0e9..4ddd818b596980a583bd6006087da49a02bd86a0 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -17,11 +17,9 @@ SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & d .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - int inSize = std::stoi(sm.str(3)); + inSize = std::stoi(sm.str(3)); int outSize = std::stoi(sm.str(4)); - wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize))); - if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") @@ -49,7 +47,7 @@ std::size_t SplitTransModuleImpl::getInputSize() return maxNbTrans; } -void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const +void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { auto & splitTransitions = config.getAppliableSplitTransitions(); for (auto & contextElement : context) @@ -60,3 +58,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } +void SplitTransModuleImpl::registerEmbeddings(std::size_t nbElements) +{ + wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbElements, inSize))); +} + diff --git a/trainer/include/MacaonTrain.hpp 
b/trainer/include/MacaonTrain.hpp index ad00e9deeececc2652c4df0d92fc5f538f1b86fd..9a92664a74f8b074c38d3a8adb494e14ac6a63d7 100644 --- a/trainer/include/MacaonTrain.hpp +++ b/trainer/include/MacaonTrain.hpp @@ -19,7 +19,6 @@ class MacaonTrain po::options_description getOptionsDescription(); po::variables_map checkOptions(po::options_description & od); - void fillDicts(ReadingMachine & rm, const Config & config); public : diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index c2099fd6a8a21e897fc9490e4f7732d5ca613ec3..9c08f68940c8ed59c8a3a23e68e11e36253c2c52 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -43,12 +43,14 @@ class Trainer void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); + void fillDicts(SubConfig & config); public : Trainer(ReadingMachine & machine, int batchSize); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); + void fillDicts(BaseConfig & goldConfig); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); }; diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 85c049d0aa92dfa605e9d7f1e92c32ce7c102b56..e2cfb32bb80683d6bde3a2e5adb65b67c82ef8be 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -35,6 +35,8 @@ po::options_description MacaonTrain::getOptionsDescription() "Number of examples per batch") ("dynamicOracleInterval", po::value<int>()->default_value(-1), "Number of examples per batch") + ("rarityThreshold", po::value<float>()->default_value(20.0), + "During training, the X% rarest elements will be treated as unknown values") ("machine", 
po::value<std::string>()->default_value(""), "Reading machine file content") ("help,h", "Produce this help message"); @@ -65,22 +67,6 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od) return vm; } -void MacaonTrain::fillDicts(ReadingMachine & rm, const Config & config) -{ - static std::vector<std::string> interestingColumns{"FORM", "LEMMA"}; - - for (auto & col : interestingColumns) - if (config.has(col,0,0)) - for (auto & it : rm.getDicts()) - { - it.second.countOcc(true); - for (unsigned int j = 0; j < config.getNbLines(); j++) - for (unsigned int k = 0; k < Config::nbHypothesesMax; k++) - it.second.getIndexOrInsert(config.getConst(col,j,k)); - it.second.countOcc(false); - } -} - int MacaonTrain::main() { auto od = getOptionsDescription(); @@ -96,6 +82,7 @@ int MacaonTrain::main() auto nbEpoch = variables["nbEpochs"].as<int>(); auto batchSize = variables["batchSize"].as<int>(); auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>(); + auto rarityThreshold = variables["rarityThreshold"].as<float>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? false : true; @@ -124,11 +111,27 @@ int MacaonTrain::main() BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? 
devTsvFile : "") : devTsvFile, devRawFile); - fillDicts(machine, goldConfig); - Trainer trainer(machine, batchSize); Decoder decoder(machine); + trainer.fillDicts(goldConfig); + std::size_t maxDictSize = 0; + for (auto & it : machine.getDicts()) + { + std::size_t originalSize = it.second.size(); + for (;;) + { + std::size_t lastSize = it.second.size(); + it.second.removeRareElements(); + float decrease = 100.0*(originalSize-it.second.size())/originalSize; + if (decrease >= rarityThreshold or lastSize == it.second.size()) + break; + } + maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); + } + machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); + machine.saveDicts(); + float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); auto trainInfos = machinePath.parent_path() / "train.info"; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 5306a2cf8cffb1beebd30a1198d6ae386058937d..b928da8c9f8ec168c8c0179adee0684699032e08 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -10,8 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - machine.splitUnknown(true); - machine.setDictsState(Dict::State::Open); + machine.setDictsState(Dict::State::Closed); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); trainDataset.reset(new Dataset(dir)); @@ -24,7 +23,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - machine.splitUnknown(false); machine.setDictsState(Dict::State::Closed); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); @@ -43,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::filesystem::create_directories(dir); 
config.addPredicted(machine.getPredicted()); + machine.getStrategy().reset(); config.setState(machine.getStrategy().getInitialState()); machine.getClassifier()->setState(machine.getStrategy().getInitialState()); - machine.getStrategy().reset(); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); @@ -154,8 +152,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); std::fclose(f); - machine.saveDicts(); - fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples)); } @@ -274,3 +270,66 @@ void Trainer::Examples::addClass(int goldIndex) classes.emplace_back(gold); } +void Trainer::fillDicts(BaseConfig & goldConfig) +{ + SubConfig config(goldConfig, goldConfig.getNbLines()); + + for (auto & it : machine.getDicts()) + it.second.countOcc(true); + + machine.trainMode(false); + machine.setDictsState(Dict::State::Open); + + fillDicts(config); + + for (auto & it : machine.getDicts()) + it.second.countOcc(false); +} + +void Trainer::fillDicts(SubConfig & config) +{ + torch::AutoGradMode useGrad(false); + + config.addPredicted(machine.getPredicted()); + machine.getStrategy().reset(); + config.setState(machine.getStrategy().getInitialState()); + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); + + while (true) + { + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + + try + { + machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + } catch(std::exception & e) + { + util::myThrow(fmt::format("Failed to extract context : {}", e.what())); + } + + Transition * goldTransition = nullptr; + goldTransition = 
machine.getTransitionSet().getBestAppliableTransition(config); + + if (!goldTransition) + { + config.printForDebug(stderr); + util::myThrow("No transition appliable !"); + } + + goldTransition->apply(config); + config.addToHistory(goldTransition->getName()); + + auto movement = machine.getStrategy().getMovement(config, goldTransition->getName()); + if (movement == Strategy::endMovement) + break; + + config.setState(movement.first); + machine.getClassifier()->setState(movement.first); + config.moveWordIndexRelaxed(movement.second); + + if (config.needsUpdate()) + config.update(); + } +} +