From 2a0a58f1e321047940ee0c3ff26c61d766653a8f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 5 Aug 2020 11:57:42 +0200 Subject: [PATCH] Do not lock dicts when the pretraiend file is empty --- common/include/Dict.hpp | 2 +- common/src/Dict.cpp | 8 ++++++-- torch_modules/src/ContextModule.cpp | 10 +++++++--- torch_modules/src/ContextualModule.cpp | 9 ++++++--- torch_modules/src/FocusedColumnModule.cpp | 9 ++++++--- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index dda547b..031d463 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -54,7 +54,7 @@ class Dict std::size_t size() const; int getNbOccs(int index) const; void removeRareElements(); - void loadWord2Vec(std::filesystem::path path, std::string prefix); + bool loadWord2Vec(std::filesystem::path path, std::string prefix); bool isSpecialValue(const std::string & value); }; diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index d4e7ba2..10cc040 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -223,10 +223,10 @@ void Dict::removeRareElements() nbOccs = newNbOccs; } -void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix) +bool Dict::loadWord2Vec(std::filesystem::path path, std::string prefix) { if (path.empty()) - return; + return false; if (!std::filesystem::exists(path)) util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string())); @@ -238,6 +238,7 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix) char buffer[100000]; bool firstLine = true; + bool pretrained = false; try { @@ -262,6 +263,7 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix) continue; } + pretrained = true; auto splited = util::split(util::strip(buffer), ' '); if (splited.size() < 2) @@ -287,6 +289,8 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix) util::myThrow(fmt::format("file '{}' is empty", path.string())); setState(originalState); + + return pretrained; } bool Dict::isSpecialValue(const std::string & value) diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 2e9a383..6629b18 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -58,9 +58,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin auto splited = util::split(p, ','); if (splited.size() != 2) util::myThrow("expected 'prefix,pretrained.w2v'"); - getDict().loadWord2Vec(this->path / splited[1], splited[0]); - getDict().setState(Dict::State::Closed); - dictSetPretrained(true); + + auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]); + if (pretrained) + { + getDict().setState(Dict::State::Closed); + dictSetPretrained(true); + } } } diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 8b76987..4829596 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -63,9 +63,12 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & auto splited = util::split(p, ','); if (splited.size() != 2) util::myThrow("expected 'prefix,file.w2v'"); - getDict().loadWord2Vec(this->path / splited[1], splited[0]); - getDict().setState(Dict::State::Closed); - dictSetPretrained(true); + auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]); + if (pretrained) + { + getDict().setState(Dict::State::Closed); + dictSetPretrained(true); + } } } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 1ef134b..08d5945 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -49,9 +49,12 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st auto splited = util::split(p, ','); if (splited.size() != 2) util::myThrow("expected 'prefix,pretrained.w2v'"); - getDict().loadWord2Vec(this->path / splited[1], splited[0]); - getDict().setState(Dict::State::Closed); - dictSetPretrained(true); + auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]); + if (pretrained) + { + getDict().setState(Dict::State::Closed); + dictSetPretrained(true); + } } } -- GitLab