From d2655230d8b685911d3e3b7ffe0dceab409a351b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 6 May 2020 20:45:20 +0200
Subject: [PATCH] Added NumericColumnModule

---
 torch_modules/include/ModularNetwork.hpp      |  1 +
 torch_modules/include/NumericColumnModule.hpp | 31 +++++++
 torch_modules/src/ModularNetwork.cpp          |  2 +
 torch_modules/src/NumericColumnModule.cpp     | 88 +++++++++++++++++++
 4 files changed, 122 insertions(+)
 create mode 100644 torch_modules/include/NumericColumnModule.hpp
 create mode 100644 torch_modules/src/NumericColumnModule.cpp

diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 70c159c..7e721b9 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -9,6 +9,7 @@
 #include "DepthLayerTreeEmbeddingModule.hpp"
 #include "StateNameModule.hpp"
 #include "UppercaseRateModule.hpp"
+#include "NumericColumnModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
new file mode 100644
index 0000000..baa2bc8
--- /dev/null
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -0,0 +1,31 @@
+#ifndef NUMERICCOLUMNMODULE__H
+#define NUMERICCOLUMNMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class NumericColumnModuleImpl : public Submodule
+{ // Submodule that reads a numeric column at focused buffer/stack positions and feeds the values to an LSTM or GRU.
+  private :
+  // Data members; all but outSize are filled by the constructor from the textual definition.
+  int outSize{0}; // zero-initialized so it can never be read with an indeterminate value
+  std::vector<int> focusedBuffer, focusedStack; // indexes parsed from Buffer{...} and Stack{...}
+  std::shared_ptr<MyModule> myModule{nullptr}; // recurrent module (LSTM or GRU) registered by the constructor
+  std::string column; // column name passed to config.getAsFeature in addToContext
+
+  public :
+
+  NumericColumnModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(NumericColumnModule);
+
+#endif
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 5edf1b4..b4e23cc 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -27,6 +27,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
+    else if (splited.first == "NumericColumn")
+      modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
     else if (splited.first == "UppercaseRate")
       modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
     else if (splited.first == "Focused")
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
new file mode 100644
index 0000000..60ca6e4
--- /dev/null
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -0,0 +1,88 @@
+#include "NumericColumnModule.hpp"
+#include "NeuralNetwork.hpp"
+
+NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition)
+{ // Parses 'Column{c} Buffer{i ...} Stack{i ...} TYPE{bidirectional num_layers dropout complete} Out{size}' (TYPE is LSTM or GRU); failures are reported through util::myThrow.
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            column = sm.str(1);
+            // Space-separated indexes to focus on in the buffer.
+            for (auto & index : util::split(sm.str(2), ' '))
+              focusedBuffer.emplace_back(std::stoi(index));
+            // Space-separated indexes to focus on in the stack.
+            for (auto & index : util::split(sm.str(3), ' '))
+              focusedStack.emplace_back(std::stoi(index));
+            // Type of the recurrent submodule and its space-separated arguments.
+            auto subModuleType = sm.str(4);
+            auto subModuleArguments = util::split(sm.str(5), ' ');
+            // Bounds-checked at(): a missing argument raises an exception handled by the catch below instead of reading out of range.
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments.at(0)))
+              .num_layers(std::stoi(subModuleArguments.at(1)))
+              .dropout(std::stof(subModuleArguments.at(2)))
+              .complete(std::stoi(subModuleArguments.at(3)));
+            // Assign the member: declaring a local 'int outSize' here would shadow it and leave the member unset.
+            outSize = std::stoi(sm.str(6));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(1, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(1, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+          } catch (const std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
+{ // Reinterprets this module's columns of `input` as doubles (bit-copied into long slots by addToContext), converts them to float and runs them through myModule.
+  auto context = input.narrow(1, firstInputIndex, getInputSize()).contiguous(); // named contiguous tensor: it owns the memory from_blob reads below
+  void * dataPtr = context.data_ptr(); // stays valid while `context` lives; context.flatten().data_ptr() could point into an already destroyed temporary copy
+  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::kDouble).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1});
+  return myModule->forward(values);
+}
+
+std::size_t NumericColumnModuleImpl::getOutputSize()
+{ // Output size as reported by the wrapped recurrent module for getInputSize() input slots.
+  return myModule->getOutputSize(getInputSize());
+}
+
+std::size_t NumericColumnModuleImpl::getInputSize()
+{ // One input slot per focused buffer index plus one per focused stack index.
+  return focusedBuffer.size() + focusedStack.size();
+}
+
+void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{ // Appends to every element of `context` one slot per focused index, holding the column value as a double bit-copied into a long.
+  std::vector<long> focusedIndexes;
+  // Buffer positions are resolved through config.getRelativeWordIndex.
+  for (int index : focusedBuffer)
+    focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
+  // Stack positions that do not exist are recorded as -1.
+  for (int index : focusedStack)
+    if (config.hasStack(index))
+      focusedIndexes.emplace_back(config.getStack(index));
+    else
+      focusedIndexes.emplace_back(-1);
+  static_assert(sizeof(long) == sizeof(double), "a double must fit exactly in one long context slot"); // otherwise the memcpy below would write past the slot
+  for (auto & contextElement : context)
+    for (auto index : focusedIndexes)
+    {
+      double res = 0.0; // value used when the index is negative
+      if (index >= 0)
+        res = std::stof(config.getAsFeature(column, index).get());
+      // Reserve a slot, then copy the raw bytes of the double into it (forward() reads them back as kDouble).
+      contextElement.emplace_back(0);
+      std::memcpy(&contextElement.back(), &res, sizeof res);
+    }
+}
+
+void NumericColumnModuleImpl::registerEmbeddings()
+{ // Empty body: this override registers nothing.
+}
+
-- 
GitLab