From 0089639fe6a542da173a4fe773a5b67401b7fb4e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 17 May 2020 18:34:31 +0200
Subject: [PATCH] Add HistoryModule
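
HistoryModule embeds up to NbElem of the most recent entries of a
configuration's history (padding with Dict::nullValueStr when fewer are
available) and encodes them with a recurrent submodule, either an LSTM
or a GRU. Its definition string must match:

  NbElem{n} (LSTM|GRU){bidirectional num_layers dropout complete} In{size} Out{size}

A hypothetical definition line, assuming the usual 'Name : definition'
layout parsed by ModularNetwork (all values illustrative only):

  History : NbElem{10} LSTM{1 1 0.3 1} In{64} Out{64}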

---
 torch_modules/include/HistoryModule.hpp  | 31 +++++++++++
 torch_modules/include/ModularNetwork.hpp |  1 +
 torch_modules/src/HistoryModule.cpp      | 68 ++++++++++++++++++++++++
 torch_modules/src/ModularNetwork.cpp     |  2 +
 4 files changed, 102 insertions(+)
 create mode 100644 torch_modules/include/HistoryModule.hpp
 create mode 100644 torch_modules/src/HistoryModule.cpp

diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
new file mode 100644
index 0000000..abcd26f
--- /dev/null
+++ b/torch_modules/include/HistoryModule.hpp
@@ -0,0 +1,31 @@
+#ifndef HISTORYMODULE__H
+#define HISTORYMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class HistoryModuleImpl : public Submodule
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
+  int maxNbElements; // number of history entries fed to the recurrent submodule
+  int inSize; // size of each history embedding (input size of the recurrent submodule)
+
+  public :
+
+  HistoryModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(HistoryModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 40b1919..f49ba3f 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -11,6 +11,7 @@
 #include "StateNameModule.hpp"
 #include "UppercaseRateModule.hpp"
 #include "NumericColumnModule.hpp"
+#include "HistoryModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
new file mode 100644
index 0000000..bc9434b
--- /dev/null
+++ b/torch_modules/src/HistoryModule.cpp
@@ -0,0 +1,68 @@
+#include "HistoryModule.hpp"
+
+HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & definition)
+{
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); // expected: NbElem{n} (LSTM|GRU){bidirectional num_layers dropout complete} In{size} Out{size}
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            maxNbElements = std::stoi(sm.str(1));
+
+            auto subModuleType = sm.str(2);
+            auto subModuleArguments = util::split(sm.str(3), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            inSize = std::stoi(sm.str(4));
+            int outSize = std::stoi(sm.str(5));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
+{
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))); // select this module's maxNbElements context columns, embed them, encode with the LSTM/GRU
+}
+
+std::size_t HistoryModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(maxNbElements); // output size of the recurrent submodule for a sequence of maxNbElements elements
+}
+
+std::size_t HistoryModuleImpl::getInputSize()
+{
+  return maxNbElements; // one dictionary-index column per history element
+}
+
+void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & dict = getDict();
+
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbElements; i++)
+      if (config.hasHistory(i))
+        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); // pad when the history has fewer than maxNbElements entries
+}
+
+void HistoryModuleImpl::registerEmbeddings()
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize))); // one inSize-dimensional vector per dictionary entry
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index c79791c..11b6962 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -31,6 +31,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
+    else if (splited.first == "History")
+      modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
     else if (splited.first == "NumericColumn")
       modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
     else if (splited.first == "UppercaseRate")
-- 
GitLab