From 82827388de33b2b1539b1958e8d24d52bebcc14d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 19 Feb 2020 21:08:07 +0100
Subject: [PATCH] Added base class to neural network

---
 decoder/src/Decoder.cpp                       |  2 +-
 reading_machine/include/Classifier.hpp        | 10 +++--
 reading_machine/include/Config.hpp            |  1 -
 reading_machine/src/Classifier.cpp            | 42 +++++++++++++++++--
 reading_machine/src/Config.cpp                | 27 ------------
 torch_modules/include/ConcatWordsNetwork.hpp  | 24 +++++++++++
 torch_modules/include/NeuralNetwork.hpp       | 25 +++++++++++
 torch_modules/include/OneWordNetwork.hpp      | 25 +++++++++++
 torch_modules/include/TestNetwork.hpp         | 27 ------------
 torch_modules/src/ConcatWordsNetwork.cpp      | 37 ++++++++++++++++
 torch_modules/src/NeuralNetwork.cpp           | 34 +++++++++++++++
 .../{TestNetwork.cpp => OneWordNetwork.cpp}   | 10 ++---
 trainer/include/Trainer.hpp                   |  1 -
 trainer/src/Trainer.cpp                       |  2 +-
 14 files changed, 198 insertions(+), 69 deletions(-)
 create mode 100644 torch_modules/include/ConcatWordsNetwork.hpp
 create mode 100644 torch_modules/include/NeuralNetwork.hpp
 create mode 100644 torch_modules/include/OneWordNetwork.hpp
 delete mode 100644 torch_modules/include/TestNetwork.hpp
 create mode 100644 torch_modules/src/ConcatWordsNetwork.cpp
 create mode 100644 torch_modules/src/NeuralNetwork.cpp
 rename torch_modules/src/{TestNetwork.cpp => OneWordNetwork.cpp} (76%)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 0c0f9bb..deddf3f 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -19,7 +19,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
       config.printForDebug(stderr);
 
     auto dictState = machine.getDict(config.getState()).getState();
-    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
     machine.getDict(config.getState()).setState(dictState);
 
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 35f0611..1131db7 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -3,7 +3,7 @@
 
 #include <string>
 #include "TransitionSet.hpp"
-#include "TestNetwork.hpp"
+#include "NeuralNetwork.hpp"
 
 class Classifier
 {
@@ -11,13 +11,17 @@ class Classifier
 
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
-  TestNetwork nn{nullptr};
+  std::shared_ptr<NeuralNetworkImpl> nn;
+
+  private :
+
+  void initNeuralNetwork(const std::string & topology);
 
   public :
 
   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
   TransitionSet & getTransitionSet();
-  TestNetwork & getNN();
+  NeuralNetwork & getNN();
   const std::string & getName() const;
 };
 
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index b18bc88..310e2f4 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -107,7 +107,6 @@ class Config
   String getState() const;
   void setState(const std::string state);
   bool stateIsDone() const;
-  std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
   void addPredicted(const std::set<std::string> & predicted);
   bool isPredicted(const std::string & colName) const;
   int getLastPoppedStack() const;
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 13e25b4..6ddd264 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,10 +1,13 @@
 #include "Classifier.hpp"
+#include "util.hpp"
+#include "OneWordNetwork.hpp"
+#include "ConcatWordsNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
   this->name = name;
   this->transitionSet.reset(new TransitionSet(tsFile));
-  this->nn = TestNetwork(transitionSet->size(), 5);
+  initNeuralNetwork(topology);
 }
 
 TransitionSet & Classifier::getTransitionSet()
@@ -12,9 +15,9 @@ TransitionSet & Classifier::getTransitionSet()
   return *transitionSet;
 }
 
-TestNetwork & Classifier::getNN()
+NeuralNetwork & Classifier::getNN()
 {
-  return nn;
+  return reinterpret_cast<NeuralNetwork&>(nn);
 }
 
 const std::string & Classifier::getName() const
@@ -22,3 +25,36 @@ const std::string & Classifier::getName() const
   return name;
 }
 
+void Classifier::initNeuralNetwork(const std::string & topology)
+{
+  static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
+  {
+    {
+      std::regex("OneWord\\((\\d+)\\)"),
+      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
+      [this,topology](auto sm)
+      {
+        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
+      }
+    },
+    {
+      std::regex("ConcatWords"),
+      "ConcatWords : Concatenate embeddings of words in context.",
+      [this,topology](auto)
+      {
+        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size()));
+      }
+    },
+  };
+
+  for (auto & initializer : initializers)
+    if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
+      return;
+
+  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
+  for (auto & initializer : initializers)
+    errorMessage += std::get<1>(initializer) + "\n";
+
+  util::myThrow(errorMessage);
+}
+
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 286386a..dd8ca3e 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -455,33 +455,6 @@ bool Config::stateIsDone() const
   return !has(0, wordIndex+1, 0) and !hasStack(0);
 }
 
-std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
-{
-  std::stack<int> leftContext;
-  for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
-    if (isToken(index))
-      leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
-
-  std::vector<long> context;
-
-  while ((int)context.size() < leftBorder-(int)leftContext.size())
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  while (!leftContext.empty())
-  {
-    context.emplace_back(leftContext.top());
-    leftContext.pop();
-  }
-
-  for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
-    if (isToken(index))
-      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
-
-  while ((int)context.size() < leftBorder+rightBorder+1)
-    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-
-  return context;
-}
-
 void Config::addPredicted(const std::set<std::string> & predicted)
 {
   this->predicted.insert(predicted.begin(), predicted.end());
diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
new file mode 100644
index 0000000..4f67b3a
--- /dev/null
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -0,0 +1,24 @@
+#ifndef CONCATWORDSNETWORK__H
+#define CONCATWORDSNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class ConcatWordsNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear{nullptr};
+
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;
+
+  public :
+
+  ConcatWordsNetworkImpl(int nbOutputs);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<torch::Tensor> & denseParameters() override;
+  std::vector<torch::Tensor> & sparseParameters() override;
+};
+
+#endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
new file mode 100644
index 0000000..1846219
--- /dev/null
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -0,0 +1,25 @@
+#ifndef NEURALNETWORK__H
+#define NEURALNETWORK__H
+
+#include <torch/torch.h>
+#include "Config.hpp"
+#include "Dict.hpp"
+
+class NeuralNetworkImpl : public torch::nn::Module
+{
+  private : 
+
+  int leftBorder{5};
+  int rightBorder{5};
+
+  public :
+
+  virtual std::vector<torch::Tensor> & denseParameters() = 0;
+  virtual std::vector<torch::Tensor> & sparseParameters() = 0;
+  virtual torch::Tensor forward(torch::Tensor input) = 0;
+  std::vector<long> extractContext(Config & config, Dict & dict) const;
+  int getContextSize() const;
+};
+TORCH_MODULE(NeuralNetwork);
+
+#endif
diff --git a/torch_modules/include/OneWordNetwork.hpp b/torch_modules/include/OneWordNetwork.hpp
new file mode 100644
index 0000000..29edb7d
--- /dev/null
+++ b/torch_modules/include/OneWordNetwork.hpp
@@ -0,0 +1,25 @@
+#ifndef ONEWORDNETWORK__H
+#define ONEWORDNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class OneWordNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear{nullptr};
+  int focusedIndex;
+
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;
+
+  public :
+
+  OneWordNetworkImpl(int nbOutputs, int focusedIndex);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<torch::Tensor> & denseParameters() override;
+  std::vector<torch::Tensor> & sparseParameters() override;
+};
+
+#endif
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
deleted file mode 100644
index 27b92e8..0000000
--- a/torch_modules/include/TestNetwork.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#ifndef TESTNETWORK__H
-#define TESTNETWORK__H
-
-#include <torch/torch.h>
-#include "Config.hpp"
-
-class TestNetworkImpl : public torch::nn::Module
-{
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear{nullptr};
-  int focusedIndex;
-
-  std::vector<torch::Tensor> _denseParameters;
-  std::vector<torch::Tensor> _sparseParameters;
-
-  public :
-
-  TestNetworkImpl(int nbOutputs, int focusedIndex);
-  torch::Tensor forward(torch::Tensor input);
-  std::vector<torch::Tensor> & denseParameters();
-  std::vector<torch::Tensor> & sparseParameters();
-};
-TORCH_MODULE(TestNetwork);
-
-#endif
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
new file mode 100644
index 0000000..1343d53
--- /dev/null
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -0,0 +1,37 @@
+#include "ConcatWordsNetwork.hpp"
+
+ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
+{
+  constexpr int embeddingsSize = 30;
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+  auto params = wordEmbeddings->parameters();
+  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+
+  linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
+  params = linear->parameters();
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+}
+
+std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
+{
+  return _denseParameters;
+}
+
+std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
+{
+  return _sparseParameters;
+}
+
+torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
+{
+  // input dim = {batch, sequence, embeddings}
+  auto wordsAsEmb = wordEmbeddings(input);
+  // reshaped dim = {batch, sequence of embeddings}
+  auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
+
+  auto res = torch::softmax(linear(reshaped), reshaped.dim() == 2 ? 1 : 0);
+
+  return res;
+}
+
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
new file mode 100644
index 0000000..ab8921e
--- /dev/null
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -0,0 +1,34 @@
+#include "NeuralNetwork.hpp"
+
+std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::stack<int> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (config.isToken(index))
+      leftContext.push(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+
+  std::vector<long> context;
+
+  while ((int)context.size() < leftBorder-(int)leftContext.size())
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+    if (config.isToken(index))
+      context.emplace_back(dict.getIndexOrInsert(config.getLastNotEmptyConst("FORM", index)));
+
+  while ((int)context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  return context;
+}
+
+int NeuralNetworkImpl::getContextSize() const
+{
+  return 1 + leftBorder + rightBorder;
+}
+
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp
similarity index 76%
rename from torch_modules/src/TestNetwork.cpp
rename to torch_modules/src/OneWordNetwork.cpp
index 19debf8..5cfa4f7 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/OneWordNetwork.cpp
@@ -1,6 +1,6 @@
-#include "TestNetwork.hpp"
+#include "OneWordNetwork.hpp"
 
-TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
+OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
 {
   constexpr int embeddingsSize = 30;
 
@@ -15,17 +15,17 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
   this->focusedIndex = focusedIndex;
 }
 
-std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
+std::vector<torch::Tensor> & OneWordNetworkImpl::denseParameters()
 {
   return _denseParameters;
 }
 
-std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
+std::vector<torch::Tensor> & OneWordNetworkImpl::sparseParameters()
 {
   return _sparseParameters;
 }
 
-torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
+torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
   auto wordsAsEmb = wordEmbeddings(input);
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 69dde8d..0f9c3ec 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -4,7 +4,6 @@
 #include "ReadingMachine.hpp"
 #include "ConfigDataset.hpp"
 #include "SubConfig.hpp"
-#include "TestNetwork.hpp"
 
 class Trainer
 {
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 195cc5c..8e84642 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -25,7 +25,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
       util::myThrow("No transition appliable !");
     }
 
-    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
     contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-- 
GitLab