From 085a30f2f015b4a679854c9cdc35f2aa8c29e1f2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 16 Jun 2020 00:07:25 +0200
Subject: [PATCH] w2v in ContextModule is now relative path

---
 common/include/Dict.hpp                  |  2 +-
 common/src/Dict.cpp                      |  2 +-
 reading_machine/include/Classifier.hpp   |  4 ++--
 reading_machine/src/Classifier.cpp       | 10 +++++-----
 torch_modules/include/ContextModule.hpp  |  3 ++-
 torch_modules/include/ModularNetwork.hpp |  2 +-
 torch_modules/src/ContextModule.cpp      |  6 +++---
 torch_modules/src/ModularNetwork.cpp     |  4 ++--
 8 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 32571b3..efd5806 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -52,7 +52,7 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
-  void loadWord2Vec(std::filesystem::path & path);
+  void loadWord2Vec(std::filesystem::path path);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index cdf09df..49e678f 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -217,7 +217,7 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
-void Dict::loadWord2Vec(std::filesystem::path & path)
+void Dict::loadWord2Vec(std::filesystem::path path)
 {
    if (path.empty())
     return;
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index ed41b29..a5f7d21 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -24,8 +24,8 @@ class Classifier
 
   private :
 
-  void initNeuralNetwork(const std::vector<std::string> & definition);
-  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path);
+  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path);
 
   public :
 
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 753c40e..323bdb0 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -58,7 +58,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
 
-  initNeuralNetwork(definition);
+  initNeuralNetwork(definition, path.parent_path());
 }
 
 int Classifier::getNbParameters() const
@@ -89,7 +89,7 @@ const std::string & Classifier::getName() const
   return name;
 }
 
-void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
+void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path)
 {
   std::map<std::string,std::size_t> nbOutputsPerState;
   for (auto & it : this->transitionSets)
@@ -108,7 +108,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   if (networkType == "Random")
     this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
   else if (networkType == "Modular")
-    initModular(definition, curIndex, nbOutputsPerState);
+    initModular(definition, curIndex, nbOutputsPerState, path);
   else
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
 
@@ -141,7 +141,7 @@ void Classifier::setState(const std::string & state)
   nn->setState(state);
 }
 
-void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
+void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path)
 {
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
   std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
@@ -157,7 +157,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
     modulesDefinitions.emplace_back(definition[curIndex]);
   }
 
-  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
+  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path));
 }
 
 void Classifier::resetOptimizer()
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 63f7a3b..7ff6c79 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -19,11 +19,12 @@ class ContextModuleImpl : public Submodule
   std::vector<std::function<std::string(const std::string &)>> functions;
   std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
   int inSize;
+  std::filesystem::path path;
   std::filesystem::path w2vFile;
 
   public :
 
-  ContextModuleImpl(std::string name, const std::string & definition);
+  ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 8a8cd0e..3fa417b 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -27,7 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
+  ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config) override;
   void registerEmbeddings() override;
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 4b973c0..364f2cb 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -1,6 +1,6 @@
 #include "ContextModule.hpp"
 
-ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition)
+ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
 {
   setName(name);
 
@@ -50,7 +50,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
 
             if (!w2vFile.empty())
             {
-              getDict().loadWord2Vec(w2vFile);
+              getDict().loadWord2Vec(this->path / w2vFile);
               getDict().setState(Dict::State::Closed);
               dictSetPretrained(true);
             }
@@ -144,6 +144,6 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile);
+  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
 }
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 75ca3ae..685060f 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -1,6 +1,6 @@
 #include "ModularNetwork.hpp"
 
-ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
+ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path)
 {
   setName(name);
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
@@ -28,7 +28,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
     std::string nameH = fmt::format("{}_{}", getName(), name);
     if (splited.first == "Context")
-      modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
+      modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
     else if (splited.first == "History")
-- 
GitLab