From 95131d6d8186067630bb3457395ad04071964833 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 10 Jun 2020 13:36:31 +0200
Subject: [PATCH] Added DistanceModule

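DistanceModule embeds, for every (from, to) pair of words taken from the
given buffer and stack positions, their relative distance counted in
non-comment lines (Config::getRelativeDistance). Pairs involving a
missing word are mapped to the null value, and distances whose absolute
value exceeds Threshold to the unknown value. The resulting embeddings
are fed to an LSTM, GRU or Concat submodule.

Example definition (the positions, threshold and sizes are illustrative
only; it assumes the usual "Name : Definition" form of the machine file,
and the LSTM arguments are bidirectional, layers, dropout, complete):

  Distance : FromBuffer{0 1 2} FromStack{0 1} ToBuffer{0} ToStack{0 1} Threshold{10} LSTM{1 1 0.0 1} In{64} Out{128}
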
---
 reading_machine/include/Config.hpp       |   1 +
 reading_machine/src/Config.cpp           |  26 +++++
 torch_modules/include/DistanceModule.hpp |  34 +++++++
 torch_modules/include/ModularNetwork.hpp |   1 +
 torch_modules/src/DistanceModule.cpp     | 115 +++++++++++++++++++++++
 torch_modules/src/ModularNetwork.cpp     |   2 +
 6 files changed, 179 insertions(+)
 create mode 100644 torch_modules/include/DistanceModule.hpp
 create mode 100644 torch_modules/src/DistanceModule.cpp

diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index faa594e..71fb1ee 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -98,6 +98,7 @@ class Config
   std::size_t & getStackRef(int relativeIndex);
 
   long getRelativeWordIndex(int relativeIndex) const;
+  long getRelativeDistance(int fromIndex, int toIndex) const;
 
   public :
 
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 4e320b9..c967a4f 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -675,6 +675,32 @@ long Config::getRelativeWordIndex(int relativeIndex) const
   return -1;
 }
 
+long Config::getRelativeDistance(int fromIndex, int toIndex) const
+{
+  if (toIndex < fromIndex)
+  {
+    for (int index = fromIndex, counter = 0; has(0,index,0); --index)
+      if (!isCommentPredicted(index))
+      {
+        if (index == toIndex)
+          return counter;
+        --counter;
+      }
+  }
+  else
+  {
+    for (int index = fromIndex, counter = 0; has(0,index,0); ++index)
+      if (!isCommentPredicted(index))
+      {
+        if (index == toIndex)
+          return counter;
+        ++counter;
+      }
+  }
+
+  return 0;
+}
+
 long Config::getRelativeWordIndex(Object object, int relativeIndex) const
 {
   if (object == Object::Buffer)
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
new file mode 100644
index 0000000..b6e22d8
--- /dev/null
+++ b/torch_modules/include/DistanceModule.hpp
@@ -0,0 +1,34 @@
+#ifndef DISTANCEMODULE__H
+#define DISTANCEMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+#include "Concat.hpp"
+
+class DistanceModuleImpl : public Submodule
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
+  std::vector<int> fromBuffer, fromStack;
+  std::vector<int> toBuffer, toStack;
+  int threshold;
+  int inSize;
+
+  public :
+
+  DistanceModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings(std::filesystem::path pretrained) override;
+};
+TORCH_MODULE(DistanceModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index a6a6c3e..7e98302 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -12,6 +12,7 @@
 #include "UppercaseRateModule.hpp"
 #include "NumericColumnModule.hpp"
 #include "HistoryModule.hpp"
+#include "DistanceModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
new file mode 100644
index 0000000..50deea0
--- /dev/null
+++ b/torch_modules/src/DistanceModule.cpp
@@ -0,0 +1,115 @@
+#include "DistanceModule.hpp"
+
+DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & definition)
+{
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)FromBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)FromStack\\{(.*)\\}(?:(?:\\s|\\t)*)ToBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)ToStack\\{(.*)\\}(?:(?:\\s|\\t)*)Threshold\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            for (auto & index : util::split(sm.str(1), ' '))
+              fromBuffer.emplace_back(std::stoi(index));
+
+            for (auto & index : util::split(sm.str(2), ' '))
+              fromStack.emplace_back(std::stoi(index));
+
+            for (auto & index : util::split(sm.str(3), ' '))
+              toBuffer.emplace_back(std::stoi(index));
+
+            for (auto & index : util::split(sm.str(4), ' '))
+              toStack.emplace_back(std::stoi(index));
+
+            threshold = std::stoi(sm.str(5));
+
+            auto subModuleType = sm.str(6);
+            auto subModuleArguments = util::split(sm.str(7), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            inSize = std::stoi(sm.str(8));
+            int outSize = std::stoi(sm.str(9));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
+{
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+}
+
+std::size_t DistanceModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(getInputSize());
+}
+
+std::size_t DistanceModuleImpl::getInputSize()
+{
+  return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size());
+}
+
+void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & dict = getDict();
+  std::vector<long> fromIndexes, toIndexes;
+
+  for (int index : fromBuffer)
+    fromIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : fromStack)
+    if (config.hasStack(index))
+      fromIndexes.emplace_back(config.getStack(index));
+    else
+      fromIndexes.emplace_back(-1);
+
+  for (int index : toBuffer)
+    toIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : toStack)
+    if (config.hasStack(index))
+      toIndexes.emplace_back(config.getStack(index));
+    else
+      toIndexes.emplace_back(-1);
+
+  for (auto & contextElement : context)
+  {
+    for (auto from : fromIndexes)
+      for (auto to : toIndexes)
+      {
+        if (from == -1 or to == -1)
+        {
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        long dist = std::abs(config.getRelativeDistance(from, to));
+
+        if (dist <= threshold)
+          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}", dist)));
+        else
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
+      }
+  }
+}
+
+void DistanceModuleImpl::registerEmbeddings(std::filesystem::path path)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index f9707a1..22cdb3a 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -45,6 +45,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
     else if (splited.first == "AppliableTrans")
       modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
+    else if (splited.first == "Distance")
+      modules.emplace_back(register_module(name, DistanceModule(nameH, splited.second)));
     else if (splited.first == "DepthLayerTree")
       modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
     else if (splited.first == "MLP")
-- 
GitLab