From 9032ef490871ad110ee46e660071c3f6a2427d16 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 9 Nov 2021 21:05:25 +0100 Subject: [PATCH] Tried to improve pretrained --- CMakeLists.txt | 2 +- common/include/Dict.hpp | 2 + common/include/util.hpp | 2 + common/src/Dict.cpp | 10 +++- common/src/util.cpp | 56 +++++++++++++++++++ reading_machine/include/Classifier.hpp | 2 +- reading_machine/src/Classifier.cpp | 6 +- reading_machine/src/ReadingMachine.cpp | 2 +- .../include/AppliableTransModule.hpp | 2 +- torch_modules/include/ContextModule.hpp | 2 +- torch_modules/include/ContextualModule.hpp | 2 +- .../include/DepthLayerTreeEmbeddingModule.hpp | 2 +- torch_modules/include/DistanceModule.hpp | 2 +- torch_modules/include/FocusedColumnModule.hpp | 2 +- torch_modules/include/HistoryMineModule.hpp | 2 +- torch_modules/include/HistoryModule.hpp | 2 +- torch_modules/include/ModularNetwork.hpp | 2 +- torch_modules/include/NeuralNetwork.hpp | 2 +- torch_modules/include/NumericColumnModule.hpp | 2 +- torch_modules/include/RandomNetwork.hpp | 2 +- torch_modules/include/RawInputModule.hpp | 2 +- torch_modules/include/SplitTransModule.hpp | 2 +- torch_modules/include/StateNameModule.hpp | 2 +- torch_modules/include/Submodule.hpp | 4 +- torch_modules/include/UppercaseRateModule.hpp | 2 +- torch_modules/src/AppliableTransModule.cpp | 2 +- torch_modules/src/ContextModule.cpp | 4 +- torch_modules/src/ContextualModule.cpp | 4 +- .../src/DepthLayerTreeEmbeddingModule.cpp | 2 +- torch_modules/src/DistanceModule.cpp | 2 +- torch_modules/src/FocusedColumnModule.cpp | 4 +- torch_modules/src/HistoryMineModule.cpp | 2 +- torch_modules/src/HistoryModule.cpp | 2 +- torch_modules/src/ModularNetwork.cpp | 4 +- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/RandomNetwork.cpp | 2 +- torch_modules/src/RawInputModule.cpp | 2 +- torch_modules/src/SplitTransModule.cpp | 2 +- torch_modules/src/StateNameModule.cpp | 2 +- torch_modules/src/Submodule.cpp | 19 ++++++- torch_modules/src/UppercaseRateModule.cpp | 2 +- 41 files changed, 130 insertions(+), 45 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e841bd..de44daa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ if(NOT CMAKE_BUILD_TYPE) endif() set(CMAKE_CXX_FLAGS "-Wall -Wextra") -set(CMAKE_CXX_FLAGS_DEBUG "-g3") +set(CMAKE_CXX_FLAGS_DEBUG "-g3 -rdynamic") set(CMAKE_CXX_FLAGS_RELEASE "-Ofast") include_directories(fmt/include) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 93774a1..fa33fef 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -36,6 +36,7 @@ class Dict State state; bool isCountingOccs{false}; std::set<std::string> prefixes{""}; + bool locked; public : @@ -51,6 +52,7 @@ class Dict public : + void lock(); void countOcc(bool isCountingOccs); std::set<std::size_t> getSpecialIndexes(); int getIndexOrInsert(const std::string & element, const std::string & prefix); diff --git a/common/include/util.hpp b/common/include/util.hpp index 165328c..331e95c 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -29,6 +29,8 @@ void myThrow(std::string_view message, const std::experimental::source_location std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension); +std::string getStackTrace(); + std::string_view getFilenameFromPath(std::string_view s); std::vector<std::string> split(std::string_view s, char delimiter); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 190cd06..2a4f118 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -3,6 +3,7 @@ Dict::Dict(State state) { + locked = false; setState(state); insert(unknownValueStr); insert(nullValueStr); @@ -18,6 +19,12 @@ Dict::Dict(const char * filename, State state) { readFromFile(filename); setState(state); + locked = false; +} + +void Dict::lock() +{ + locked = true; } void Dict::readFromFile(const char * filename) @@ -161,7 +168,8 @@ int Dict::_getIndexOrInsert(const std::string & element, const std::string & pre void Dict::setState(State state) { - this->state = state; + if (!locked) + this->state = state; } Dict::State Dict::getState() const diff --git a/common/src/util.cpp b/common/src/util.cpp index 9a9f21a..52a362f 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -5,6 +5,8 @@ #include <iostream> #include <fstream> #include <unistd.h> +#include <execinfo.h> +#include <cxxabi.h> #include "upper2lower" float util::long2float(long l) @@ -445,3 +447,57 @@ std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename return sentences; } +std::string util::getStackTrace() +{ + std::string res; + + try + { + void * array[100]; + size_t size; + + size = backtrace(array, 100); + + char ** messages = backtrace_symbols(array, size); + + for (unsigned int i = 1; i < size && messages != NULL; ++i) + { + char *mangled_name = 0, *offset_begin = 0, *offset_end = 0; + + for (char *p = messages[i]; *p; ++p) + { + if (*p == '(') + mangled_name = p; + else if (*p == '+') + offset_begin = p; + else if (*p == ')') + { + offset_end = p; + break; + } + } + + if (mangled_name && offset_begin && offset_end && + mangled_name < offset_begin) + { + *mangled_name++ = '\0'; + *offset_begin++ = '\0'; + *offset_end++ = '\0'; + + int status = 0; + char * real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status); + + res = fmt::format("{}{}[bt] : ({}) {} : {}+{}{}", res, res.size() == 0 ? "" : "\n", i, messages[i], status == 0 ? real_name : mangled_name, offset_begin, offset_end); + } + else + res = fmt::format("{}\n[bt] : ({}) {}", res, i, messages[i]); + } + } + catch (std::exception & e) + { + error(e); + } + + return res; +} + diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 41285a3..01329e5 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -37,7 +37,7 @@ class Classifier public : - Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train); + Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained=false); TransitionSet & getTransitionSet(const std::string & state); NeuralNetwork & getNN(); const std::string & getName() const; diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index e2dd7c9..86b3d65 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -3,7 +3,7 @@ #include "RandomNetwork.hpp" #include "ModularNetwork.hpp" -Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path) +Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained) : path(path) { this->name = name; std::size_t curIndex = 0; @@ -79,12 +79,12 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std getNN()->eval(); getNN()->loadDicts(path); - getNN()->registerEmbeddings(); + getNN()->registerEmbeddings(loadPretrained); if (!train) { torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice()); - getNN()->registerEmbeddings(); + getNN()->registerEmbeddings(loadPretrained); } else if (std::filesystem::exists(getLastFilename())) { diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 7c06ebd..2b40fb6 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -175,7 +175,7 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold) void ReadingMachine::resetClassifiers() { for (unsigned int i = 0; i < classifiers.size(); i++) - classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train)); + classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train, true)); } int ReadingMachine::getNbParameters() const diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index 98f5fe1..47dbd15 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index c2e0668..c9c9b33 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -31,7 +31,7 @@ class ContextModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp index 8483b1a..cf4e81b 100644 --- a/torch_modules/include/ContextualModule.hpp +++ b/torch_modules/include/ContextualModule.hpp @@ -32,7 +32,7 @@ class ContextualModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(ContextualModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 3621e6e..e2e606e 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -28,7 +28,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp index bafa0b8..23817df 100644 --- a/torch_modules/include/DistanceModule.hpp +++ b/torch_modules/include/DistanceModule.hpp @@ -27,7 +27,7 @@ class DistanceModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(DistanceModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index a7df331..622dab0 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -30,7 +30,7 @@ class FocusedColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryMineModule.hpp b/torch_modules/include/HistoryMineModule.hpp index 7f6afd6..7ca3477 100644 --- a/torch_modules/include/HistoryMineModule.hpp +++ b/torch_modules/include/HistoryMineModule.hpp @@ -26,7 +26,7 @@ class HistoryMineModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(HistoryMineModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index b4a725b..ff859e8 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -26,7 +26,7 @@ class HistoryModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index e2d4643..910cc40 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -33,7 +33,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); torch::Tensor forward(torch::Tensor input, const std::string & state) override; torch::Tensor extractContext(Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 6e2319b..f34f966 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -16,7 +16,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; virtual torch::Tensor extractContext(Config & config) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(bool loadPretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index 3ee9cb2..8b39dba 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -25,7 +25,7 @@ class NumericColumnModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(NumericColumnModule); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index 33d99a1..68909fd 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input, const std::string & state) override; torch::Tensor extractContext(Config &) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index 0ca658b..ef0f605 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -26,7 +26,7 @@ class RawInputModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index b88491e..10f8d8c 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -25,7 +25,7 @@ class SplitTransModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp index ace1cbc..a4c62e4 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -22,7 +22,7 @@ class StateNameModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index f4722bf..7d5a544 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -21,12 +21,12 @@ class Submodule : public torch::nn::Module, public DictHolder static void setReloadPretrained(bool reloadPretrained); void setFirstInputIndex(std::size_t firstInputIndex); - void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix); + void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix, bool loadPretrained); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(torch::Tensor & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void registerEmbeddings() = 0; + virtual void registerEmbeddings(bool loadPretrained) = 0; std::function<std::string(const std::string &)> getFunction(const std::string functionNames); }; diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index 9495661..8576de9 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(torch::Tensor & context, const Config & config) override; - void registerEmbeddings() override; + void registerEmbeddings(bool loadPretrained) override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index 7a5c830..56373eb 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -28,7 +28,7 @@ void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Confi context[firstInputIndex+i] = appliableTrans[i]; } -void AppliableTransModuleImpl::registerEmbeddings() +void AppliableTransModuleImpl::registerEmbeddings(bool) { } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 67bbb29..48f9a00 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -184,7 +184,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input) return myModule->forward(context).reshape({input.size(0), -1}); } -void ContextModuleImpl::registerEmbeddings() +void ContextModuleImpl::registerEmbeddings(bool loadPretrained) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes())); @@ -192,7 +192,7 @@ void ContextModuleImpl::registerEmbeddings() for (auto & p : pathes) { auto splited = util::split(p, ','); - loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]); + loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained); } } diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index bd825f7..564c95f 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -231,7 +231,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1}); } -void ContextualModuleImpl::registerEmbeddings() +void ContextualModuleImpl::registerEmbeddings(bool loadPretrained) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes())); @@ -240,7 +240,7 @@ void ContextualModuleImpl::registerEmbeddings() for (auto & p : pathes) { auto splited = util::split(p, ','); - loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]); + loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained); } } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index acc45d5..94e703c 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -128,7 +128,7 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co } } -void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() +void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 0aebe58..fbd972c 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -110,7 +110,7 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & co } } -void DistanceModuleImpl::registerEmbeddings() +void DistanceModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 62da3de..23ebe6f 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -161,7 +161,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config } } -void FocusedColumnModuleImpl::registerEmbeddings() +void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes())); @@ -169,7 +169,7 @@ void FocusedColumnModuleImpl::registerEmbeddings() for (auto & p : pathes) { auto splited = util::split(p, ','); - loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]); + loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0], loadPretrained); } } diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp index 7d1c6f5..75a6466 100644 --- a/torch_modules/src/HistoryMineModule.cpp +++ b/torch_modules/src/HistoryMineModule.cpp @@ -66,7 +66,7 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); } -void HistoryMineModuleImpl::registerEmbeddings() +void HistoryMineModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index c897364..d116a05 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -66,7 +66,7 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & con context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); } -void HistoryModuleImpl::registerEmbeddings() +void HistoryModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index e288df2..84c0e13 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -107,10 +107,10 @@ torch::Tensor ModularNetworkImpl::extractContext(Config & config) return context; } -void ModularNetworkImpl::registerEmbeddings() +void ModularNetworkImpl::registerEmbeddings(bool loadPretrained) { for (auto & mod : modules) - mod->registerEmbeddings(); + mod->registerEmbeddings(loadPretrained); } void ModularNetworkImpl::saveDicts(std::filesystem::path path) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index a666fc3..5a1191d 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -90,7 +90,7 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config } } -void NumericColumnModuleImpl::registerEmbeddings() +void NumericColumnModuleImpl::registerEmbeddings(bool) { } diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index d27ffe9..46d1f80 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -19,7 +19,7 @@ torch::Tensor RandomNetworkImpl::extractContext(Config &) return context; } -void RandomNetworkImpl::registerEmbeddings() +void RandomNetworkImpl::registerEmbeddings(bool) { } diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index c5dc9a5..659da66 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -84,7 +84,7 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & co } } -void RawInputModuleImpl::registerEmbeddings() +void RawInputModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index ee3fa38..0e362d6 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -62,7 +62,7 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, ""); } -void SplitTransModuleImpl::registerEmbeddings() +void SplitTransModuleImpl::registerEmbeddings(bool) { if (!wordEmbeddings) wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>())); diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index b5e81af..0c9e2b7 100644 --- a/torch_modules/src/StateNameModule.cpp +++ b/torch_modules/src/StateNameModule.cpp @@ -35,7 +35,7 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & c context[firstInputIndex] = dict.getIndexOrInsert(config.getState(), ""); } -void StateNameModuleImpl::registerEmbeddings() +void StateNameModuleImpl::registerEmbeddings(bool) { if (!embeddings) embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>())); diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 51fff9d..b5f43ab 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -13,7 +13,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex) this->firstInputIndex = firstInputIndex; } -void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix) +void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix, bool loadPretrained) { if (path.empty()) return; @@ -22,6 +22,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std if (!std::filesystem::exists(path)) util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string())); + if (loadPretrained) + fmt::print(stderr, "[{}] Loading pretrained embeddings '{}'\n", util::getTime(), std::string(path)); std::vector<std::vector<float>> toAdd; @@ -35,6 +37,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std bool firstLine = true; std::size_t embeddingsSize = embeddings->parameters()[0].size(-1); + int nbLoaded = 0; try { @@ -55,6 +58,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); std::string word; + nbLoaded += 1; if (splited[0] == "<unk>") word = Dict::unknownValueStr; @@ -70,6 +74,9 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std auto dictIndex = getDict().getIndexOrInsert(word, prefix); + if (not loadPretrained) + continue; + if (embeddingsSize > splited.size()-1) util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); @@ -103,6 +110,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std if (firstLine) util::myThrow(fmt::format("file '{}' is empty", path.string())); + if (not loadPretrained) + { + getDict().setState(Dict::State::Closed); + embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained()); + getDict().lock(); + return; + } + if (!toAdd.empty()) { auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize); @@ -116,6 +131,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std getDict().setState(originalState); embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained()); + + fmt::print(stderr, "[{}] Done loading {} embeddings. Frozen={}\n", util::getTime(), nbLoaded, !embeddings->weight.requires_grad()); } std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames) diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 8d86c74..a057439 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -89,7 +89,7 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config } } -void UppercaseRateModuleImpl::registerEmbeddings() +void UppercaseRateModuleImpl::registerEmbeddings(bool) { } -- GitLab