From 05062ca77b2d6e933c7b847137fa6ef7d5842a74 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 13 Jun 2020 17:14:07 +0200 Subject: [PATCH] Removed pretrainedEmbeddings as a global parameter, instead submodules can now have their own pretrained w2v --- common/include/Dict.hpp | 1 + common/src/Dict.cpp | 49 +++++++++++++++++++ reading_machine/src/ReadingMachine.cpp | 2 +- .../include/AppliableTransModule.hpp | 2 +- torch_modules/include/ContextModule.hpp | 3 +- .../include/DepthLayerTreeEmbeddingModule.hpp | 2 +- torch_modules/include/DictHolder.hpp | 3 ++ torch_modules/include/DistanceModule.hpp | 2 +- torch_modules/include/FocusedColumnModule.hpp | 2 +- torch_modules/include/HistoryModule.hpp | 2 +- torch_modules/include/ModularNetwork.hpp | 2 +- torch_modules/include/NeuralNetwork.hpp | 2 +- torch_modules/include/NumericColumnModule.hpp | 2 +- torch_modules/include/RandomNetwork.hpp | 2 +- torch_modules/include/RawInputModule.hpp | 2 +- torch_modules/include/SplitTransModule.hpp | 2 +- torch_modules/include/StateNameModule.hpp | 2 +- torch_modules/include/Submodule.hpp | 2 +- torch_modules/include/UppercaseRateModule.hpp | 2 +- torch_modules/src/AppliableTransModule.cpp | 2 +- torch_modules/src/ContextModule.cpp | 16 ++++-- .../src/DepthLayerTreeEmbeddingModule.cpp | 3 +- torch_modules/src/DictHolder.cpp | 12 ++++- torch_modules/src/DistanceModule.cpp | 3 +- torch_modules/src/FocusedColumnModule.cpp | 3 +- torch_modules/src/HistoryModule.cpp | 3 +- torch_modules/src/ModularNetwork.cpp | 9 ++-- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/RandomNetwork.cpp | 2 +- torch_modules/src/RawInputModule.cpp | 3 +- torch_modules/src/SplitTransModule.cpp | 3 +- torch_modules/src/StateNameModule.cpp | 3 +- torch_modules/src/UppercaseRateModule.cpp | 2 +- trainer/include/Trainer.hpp | 1 - trainer/src/MacaonTrain.cpp | 40 +++++++-------- trainer/src/Trainer.cpp | 18 ++----- 36 files changed, 131 insertions(+), 80 deletions(-) 
diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 87741cb..6d3f27a 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -50,6 +50,7 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
+  void loadWord2Vec(std::filesystem::path & path);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index a4c060c..cab960b 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -200,3 +200,63 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
+// Inserts every word of a word2vec text file into this dict.
+// Does nothing when path is empty. The first line of the file is skipped
+// (presumably the "count dimension" header of the w2v text format), and
+// for each following line the first space-separated column is inserted.
+// The dict is opened for the duration of the load and its previous state
+// is restored afterwards, including when an error is raised.
+// Throws (via util::myThrow) if the file is missing or cannot be opened,
+// if a line has fewer than 2 columns, or if a word maps to a special index.
+void Dict::loadWord2Vec(std::filesystem::path & path)
+{
+  if (path.empty())
+    return;
+
+  if (!std::filesystem::exists(path))
+    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
+
+  std::FILE * file = std::fopen(path.c_str(), "r");
+  // exists() does not guarantee readability: fopen can still fail.
+  if (file == nullptr)
+    util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));
+
+  auto originalState = getState();
+  setState(Dict::State::Open);
+
+  char buffer[100000];
+  bool firstLine = true;
+
+  try
+  {
+    while (std::fgets(buffer, sizeof buffer, file) != nullptr)
+    {
+      if (firstLine)
+      {
+        firstLine = false;
+        continue;
+      }
+
+      auto splited = util::split(util::strip(buffer), ' ');
+
+      if (splited.size() < 2)
+        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
+
+      auto dictIndex = getIndexOrInsert(splited[0]);
+
+      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
+        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
+    }
+  } catch (std::exception & e)
+  {
+    // Release the file and restore the state before re-throwing, otherwise
+    // the handle would leak and the dict would stay open.
+    std::fclose(file);
+    setState(originalState);
+    util::myThrow(fmt::format("caught '{}'", e.what()));
+  }
+
+  std::fclose(file);
+  setState(originalState);
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 973d680..e1b8169 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -11,7 +11,7 @@ 
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); - classifier->getNN()->registerEmbeddings(""); + classifier->getNN()->registerEmbeddings(); classifier->getNN()->to(NeuralNetworkImpl::device); if (models.size() > 1) diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index c0dea7d..5e6f9e4 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index ed0ce57..123c063 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -19,6 +19,7 @@ class ContextModuleImpl : public Submodule std::vector<int> bufferContext; std::vector<int> stackContext; int inSize; + std::filesystem::path w2vFile; public : @@ -27,7 +28,7 @@ class ContextModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 277f7fb..8a60320 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -27,7 +27,7 @@ class 
DepthLayerTreeEmbeddingModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/DictHolder.hpp b/torch_modules/include/DictHolder.hpp index 6edb8e7..781045d 100644 --- a/torch_modules/include/DictHolder.hpp +++ b/torch_modules/include/DictHolder.hpp @@ -13,6 +13,7 @@ class DictHolder : public NameHolder static constexpr char * filenameTemplate = "{}.dict"; std::unique_ptr<Dict> dict; + bool pretrained{false}; private : @@ -24,6 +25,8 @@ class DictHolder : public NameHolder void saveDict(std::filesystem::path path); void loadDict(std::filesystem::path path); Dict & getDict(); + bool dictIsPretrained(); + void dictSetPretrained(bool pretrained); }; #endif diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp index b6e22d8..97a823b 100644 --- a/torch_modules/include/DistanceModule.hpp +++ b/torch_modules/include/DistanceModule.hpp @@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(DistanceModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 7ebd6c5..024c6c1 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -27,7 +27,7 @@ class FocusedColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void 
addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index 594df1f..4a0a2bb 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -24,7 +24,7 @@ class HistoryModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 7e98302..8a8cd0e 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -30,7 +30,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index f7c26b6..ee32d2b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St virtual torch::Tensor forward(torch::Tensor input) = 0; 
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; - virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; + virtual void registerEmbeddings() = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index 82e3d37..26e295a 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(NumericColumnModule); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index 1a4bad7..b20a779 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config &) override; - void registerEmbeddings(std::filesystem::path) override; + void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index d3a0e6c..00aaf18 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -24,7 +24,7 @@ class RawInputModuleImpl : public 
Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index f738cdd..3f46093 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp index e4c126e..2e1a7d4 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 71b1007..70250e0 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -21,7 +21,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual torch::Tensor 
forward(torch::Tensor input) = 0; - virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; + virtual void registerEmbeddings() = 0; std::function<std::string(const std::string &)> getFunction(const std::string functionNames); }; diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index e28366e..dcfb89c 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; - void registerEmbeddings(std::filesystem::path pretrained) override; + void registerEmbeddings() override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index 76fd5ed..c50586f 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con contextElement.emplace_back(0); } -void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path) +void AppliableTransModuleImpl::registerEmbeddings() { } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index d1a31c9..75a23b1 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -3,7 +3,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition) { setName(name); - std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + + std::regex 
regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try @@ -43,6 +44,15 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); + w2vFile = sm.str(8); + + if (!w2vFile.empty()) + { + getDict().loadWord2Vec(w2vFile); + getDict().setState(Dict::State::Closed); + dictSetPretrained(true); + } + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} })) util::myThrow(fmt::format("invalid definition '{}'", definition)); @@ -100,9 +110,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context); } -void ContextModuleImpl::registerEmbeddings(std::filesystem::path path) +void ContextModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); + loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile); } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 4894eb9..2cb88dc 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -124,9 +124,8 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path) +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); 
} diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp index 2f1958f..f712112 100644 --- a/torch_modules/src/DictHolder.cpp +++ b/torch_modules/src/DictHolder.cpp @@ -18,7 +18,7 @@ void DictHolder::saveDict(std::filesystem::path path) void DictHolder::loadDict(std::filesystem::path path) { - dict.reset(new Dict((path / filename()).c_str(), Dict::State::Open)); + dict.reset(new Dict((path / filename()).c_str(), dict->getState())); } Dict & DictHolder::getDict() @@ -26,3 +26,13 @@ Dict & DictHolder::getDict() return *dict; } +bool DictHolder::dictIsPretrained() +{ + return pretrained; +} + +void DictHolder::dictSetPretrained(bool pretrained) +{ + this->pretrained = pretrained; +} + diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 50deea0..f529537 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -107,9 +107,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void DistanceModuleImpl::registerEmbeddings(std::filesystem::path path) +void DistanceModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 29aef9e..91d22b0 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -137,9 +137,8 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path) +void FocusedColumnModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git 
a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index be36990..c326f52 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -63,9 +63,8 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void HistoryModuleImpl::registerEmbeddings(std::filesystem::path path) +void HistoryModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 22cdb3a..75ca3ae 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -101,10 +101,10 @@ std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & confi return context; } -void ModularNetworkImpl::registerEmbeddings(std::filesystem::path pretrained) +void ModularNetworkImpl::registerEmbeddings() { for (auto & mod : modules) - mod->registerEmbeddings(pretrained); + mod->registerEmbeddings(); } void ModularNetworkImpl::saveDicts(std::filesystem::path path) @@ -122,7 +122,10 @@ void ModularNetworkImpl::loadDicts(std::filesystem::path path) void ModularNetworkImpl::setDictsState(Dict::State state) { for (auto & mod : modules) - mod->getDict().setState(state); + { + if (!mod->dictIsPretrained()) + mod->getDict().setState(state); + } } void ModularNetworkImpl::setCountOcc(bool countOcc) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index c94ac66..1825023 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -83,7 +83,7 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont } } -void 
NumericColumnModuleImpl::registerEmbeddings(std::filesystem::path) +void NumericColumnModuleImpl::registerEmbeddings() { } diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 85f7c3c..7a6491b 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -18,7 +18,7 @@ std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) return std::vector<std::vector<long>>{{0}}; } -void RandomNetworkImpl::registerEmbeddings(std::filesystem::path) +void RandomNetworkImpl::registerEmbeddings() { } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 14cd3bc..c99c4ae 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -74,9 +74,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, } } -void RawInputModuleImpl::registerEmbeddings(std::filesystem::path path) +void RawInputModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 45c268a..822969f 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -63,9 +63,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -void SplitTransModuleImpl::registerEmbeddings(std::filesystem::path path) +void SplitTransModuleImpl::registerEmbeddings() { wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); - loadPretrainedW2vEmbeddings(wordEmbeddings, path); } diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 0cdc820..42edd50 
100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -36,9 +36,8 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, contextElement.emplace_back(dict.getIndexOrInsert(config.getState())); } -void StateNameModuleImpl::registerEmbeddings(std::filesystem::path path) +void StateNameModuleImpl::registerEmbeddings() { embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize)); - loadPretrainedW2vEmbeddings(embeddings, path); } diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 478651c..818db8b 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -92,7 +92,7 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont } -void UppercaseRateModuleImpl::registerEmbeddings(std::filesystem::path) +void UppercaseRateModuleImpl::registerEmbeddings() { } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index fcbae07..d566747 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -65,7 +65,6 @@ class Trainer void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); - void fillDicts(BaseConfig & goldConfig, bool debug); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); }; diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 8a60dee..63a32de 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -33,14 +33,10 @@ po::options_description MacaonTrain::getOptionsDescription() "Number of training epochs") ("batchSize", po::value<int>()->default_value(64), "Number of examples per batch") - ("rarityThreshold", po::value<float>()->default_value(70.0), - "During train, the X% rarest 
elements will be treated as unknown values") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") - ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"), + ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), "Description of what should happen during training") - ("pretrainedEmbeddings", po::value<std::string>()->default_value(""), - "File containing pretrained embeddings, w2v format") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -124,12 +120,10 @@ int MacaonTrain::main() auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); auto batchSize = variables["batchSize"].as<int>(); - auto rarityThreshold = variables["rarityThreshold"].as<float>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? 
false : true; auto machineContent = variables["machine"].as<std::string>(); - auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto trainStrategy = parseTrainStrategy(trainStrategyStr); @@ -158,23 +152,14 @@ int MacaonTrain::main() Trainer trainer(machine, batchSize); Decoder decoder(machine); - if (util::findFilesByExtension(machinePath.parent_path(), ".dict").empty()) - { - trainer.fillDicts(goldConfig, debug); - machine.removeRareDictElements(rarityThreshold); - machine.saveDicts(); - } - else + if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty()) { machine.loadDicts(); + machine.getClassifier()->getNN()->registerEmbeddings(); + machine.loadLastSaved(); + machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); } - machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); - machine.loadLastSaved(); - machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); - - fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); - float bestDevScore = computeDevScore ? 
std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); auto trainInfos = machinePath.parent_path() / "train.info"; @@ -198,10 +183,12 @@ int MacaonTrain::main() std::fclose(f); } - machine.getClassifier()->resetOptimizer(); auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; if (std::filesystem::exists(trainInfos)) + { + machine.getClassifier()->resetOptimizer(); machine.getClassifier()->loadOptimizer(optimizerCheckpoint); + } for (; currentEpoch < nbEpoch; currentEpoch++) { @@ -218,7 +205,7 @@ int MacaonTrain::main() if (entry.is_regular_file()) std::filesystem::remove(entry.path()); } - if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold)) { trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); if (!computeDevScore) @@ -229,12 +216,19 @@ int MacaonTrain::main() if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters)) { machine.resetClassifier(); - machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); + machine.getClassifier()->getNN()->registerEmbeddings(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); + fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); } machine.getClassifier()->resetOptimizer(); } + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) + { + trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + if (!computeDevScore) + trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, 
trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save)) { saved = true; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ea9031b..66efbfe 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -22,9 +22,11 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - machine.setDictsState(Dict::State::Closed); + machine.setDictsState(Dict::State::Open); extractExamples(config, debug, dir, epoch, dynamicOracle); + + machine.saveDicts(); } void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) @@ -259,20 +261,6 @@ void Trainer::Examples::addClass(int goldIndex) classes.emplace_back(gold); } -void Trainer::fillDicts(BaseConfig & goldConfig, bool debug) -{ - SubConfig config(goldConfig, goldConfig.getNbLines()); - - machine.setCountOcc(true); - - machine.trainMode(false); - machine.setDictsState(Dict::State::Open); - - fillDicts(config, debug); - - machine.setCountOcc(false); -} - void Trainer::fillDicts(SubConfig & config, bool debug) { torch::AutoGradMode useGrad(false); -- GitLab