From 2a0a58f1e321047940ee0c3ff26c61d766653a8f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 5 Aug 2020 11:57:42 +0200
Subject: [PATCH] Do not lock dicts when the pretrained file is empty

---
 common/include/Dict.hpp                   |  2 +-
 common/src/Dict.cpp                       |  8 ++++++--
 torch_modules/src/ContextModule.cpp       | 10 +++++++---
 torch_modules/src/ContextualModule.cpp    |  9 ++++++---
 torch_modules/src/FocusedColumnModule.cpp |  9 ++++++---
 5 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index dda547b..031d463 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -54,7 +54,7 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
-  void loadWord2Vec(std::filesystem::path path, std::string prefix);
+  bool loadWord2Vec(std::filesystem::path path, std::string prefix);
   bool isSpecialValue(const std::string & value);
 };
 
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index d4e7ba2..10cc040 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -223,10 +223,10 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
-void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
+bool Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
 {
    if (path.empty())
-    return;
+    return false;
 
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
@@ -238,6 +238,7 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
   char buffer[100000];
 
   bool firstLine = true;
+  bool pretrained = false;
 
   try
   {
@@ -262,6 +263,7 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
         continue;
       }
 
+      pretrained = true;
       auto splited = util::split(util::strip(buffer), ' ');
 
       if (splited.size() < 2)
@@ -287,6 +289,8 @@ void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
   setState(originalState);
+
+  return pretrained;
 }
 
 bool Dict::isSpecialValue(const std::string & value)
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 2e9a383..6629b18 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -58,9 +58,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
                 auto splited = util::split(p, ',');
                 if (splited.size() != 2)
                   util::myThrow("expected 'prefix,pretrained.w2v'");
-                getDict().loadWord2Vec(this->path / splited[1], splited[0]);
-                getDict().setState(Dict::State::Closed);
-                dictSetPretrained(true);
+
+                auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+                if (pretrained)
+                {
+                  getDict().setState(Dict::State::Closed);
+                  dictSetPretrained(true);
+                }
               }
             }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 8b76987..4829596 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -63,9 +63,12 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
                 auto splited = util::split(p, ',');
                 if (splited.size() != 2)
                   util::myThrow("expected 'prefix,file.w2v'");
-                getDict().loadWord2Vec(this->path / splited[1], splited[0]);
-                getDict().setState(Dict::State::Closed);
-                dictSetPretrained(true);
+                auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+                if (pretrained)
+                {
+                  getDict().setState(Dict::State::Closed);
+                  dictSetPretrained(true);
+                }
               }
             }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 1ef134b..08d5945 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -49,9 +49,12 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
                 auto splited = util::split(p, ',');
                 if (splited.size() != 2)
                   util::myThrow("expected 'prefix,pretrained.w2v'");
-                getDict().loadWord2Vec(this->path / splited[1], splited[0]);
-                getDict().setState(Dict::State::Closed);
-                dictSetPretrained(true);
+                auto pretrained = getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+                if (pretrained)
+                {
+                  getDict().setState(Dict::State::Closed);
+                  dictSetPretrained(true);
+                }
               }
             }
 
-- 
GitLab