From 26cc83e4c8331984320522e5d091098f760374d6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 15 Mar 2020 13:34:31 +0100
Subject: [PATCH] NeuralNetwork::extractContext can now generate multiple
 variants of context

---
 decoder/src/Decoder.cpp                 |   4 +-
 torch_modules/include/CNNNetwork.hpp    |   2 +-
 torch_modules/include/NeuralNetwork.hpp |   2 +-
 torch_modules/include/RLTNetwork.hpp    |   2 +-
 torch_modules/src/CNNNetwork.cpp        | 163 +++++++++++++-----------
 torch_modules/src/NeuralNetwork.cpp     |   4 +-
 torch_modules/src/RLTNetwork.cpp        |   4 +-
 trainer/src/Trainer.cpp                 |  12 +-
 8 files changed, 104 insertions(+), 89 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 1d81309..44a8edc 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -25,10 +25,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       config.printForDebug(stderr);
 
     auto dictState = machine.getDict(config.getState()).getState();
-    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
     machine.getDict(config.getState()).setState(dictState);
 
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 
     int chosenTransition = -1;
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index 2edac49..0cd54b8 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -33,7 +33,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 
   CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 1ca0919..34bf14b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -27,7 +27,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
+  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
   std::vector<long> extractContextIndexes(const Config & config) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp
index 7d350b3..b996def 100644
--- a/torch_modules/include/RLTNetwork.hpp
+++ b/torch_modules/include/RLTNetwork.hpp
@@ -23,7 +23,7 @@ class RLTNetworkImpl : public NeuralNetworkImpl
 
   RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 5e9696e..9f9d6a1 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -74,118 +74,129 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
 }
 
-std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   if (dict.size() >= maxNbEmbeddings)
     util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
 
   std::vector<long> contextIndexes = extractContextIndexes(config);
-  std::vector<long> context;
+  std::vector<std::vector<long>> context;
+  context.emplace_back();
 
   if (rawInputSize > 0)
   {
     for (int i = 0; i < leftWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
       else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
     for (int i = 0; i <= rightWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
       else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
   }
 
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
-        if (col == "FORM" || col == "LEMMA")
-          if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
-            dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
 
-        context.push_back(dictIndex);
-      }
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
 
-  for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
-  {
-    auto & col = focusedColumns[colIndex];
+        if (is_training())
+          if (col == "FORM" || col == "LEMMA")
+            if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
+            {
+              context.emplace_back(context.back());
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+            }
+      }
 
-    std::vector<int> focusedIndexes;
-    for (auto relIndex : focusedBufferIndexes)
+  for (auto & contextElement : context)
+    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
     {
-      int index = relIndex + leftBorder;
-      if (index < 0 || index >= (int)contextIndexes.size())
-        focusedIndexes.push_back(-1);
-      else
-        focusedIndexes.push_back(contextIndexes[index]);
-    }
-    for (auto index : focusedStackIndexes)
-    {
-      if (!config.hasStack(index))
-        focusedIndexes.push_back(-1);
-      else if (!config.has(col, config.getStack(index), 0))
-        focusedIndexes.push_back(-1);
-      else
-        focusedIndexes.push_back(config.getStack(index));
-    }
+      auto & col = focusedColumns[colIndex];
 
-    for (auto index : focusedIndexes)
-    {
-      if (index == -1)
+      std::vector<int> focusedIndexes;
+      for (auto relIndex : focusedBufferIndexes)
       {
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-        continue;
+        int index = relIndex + leftBorder;
+        if (index < 0 || index >= (int)contextIndexes.size())
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(contextIndexes[index]);
       }
-
-      std::vector<std::string> elements;
-      if (col == "FORM")
+      for (auto index : focusedStackIndexes)
       {
-        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
-
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          if (i < (int)asUtf8.size())
-            elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
+        if (!config.hasStack(index))
+          focusedIndexes.push_back(-1);
+        else if (!config.has(col, config.getStack(index), 0))
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(config.getStack(index));
       }
-      else if (col == "FEATS")
-      {
-        auto splited = util::split(config.getAsFeature(col, index).get(), '|');
 
-        for (int i = 0; i < maxNbElements[colIndex]; i++)
-          if (i < (int)splited.size())
-            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
-          else
-            elements.emplace_back(Dict::nullValueStr);
-      }
-      else if (col == "ID")
+      for (auto index : focusedIndexes)
       {
-        if (config.isTokenPredicted(index))
-          elements.emplace_back("ID(TOKEN)");
-        else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("ID(MULTIWORD)");
-        else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("ID(EMPTYNODE)");
+        if (index == -1)
+        {
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        std::vector<std::string> elements;
+        if (col == "FORM")
+        {
+          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)asUtf8.size())
+              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "FEATS")
+        {
+          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)splited.size())
+              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "ID")
+        {
+          if (config.isTokenPredicted(index))
+            elements.emplace_back("ID(TOKEN)");
+          else if (config.isMultiwordPredicted(index))
+            elements.emplace_back("ID(MULTIWORD)");
+          else if (config.isEmptyNodePredicted(index))
+            elements.emplace_back("ID(EMPTYNODE)");
+        }
+        else
+        {
+          elements.emplace_back(config.getAsFeature(col, index));
+        }
+
+        if ((int)elements.size() != maxNbElements[colIndex])
+          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+
+        for (auto & element : elements)
+          contextElement.emplace_back(dict.getIndexOrInsert(element));
       }
-      else
-      {
-        elements.emplace_back(config.getAsFeature(col, index));
-      }
-
-      if ((int)elements.size() != maxNbElements[colIndex])
-        util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
-
-      for (auto & element : elements)
-        context.emplace_back(dict.getIndexOrInsert(element));
     }
-  }
+
+  if (!is_training() && context.size() > 1)
+    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
 
   return context;
 }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 37c206c..fef5519 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -35,7 +35,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
   return context;
 }
 
-std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> indexes = extractContextIndexes(config);
   std::vector<long> context;
@@ -47,7 +47,7 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict
       else
         context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
 
-  return context;
+  return {context};
 }
 
 int NeuralNetworkImpl::getContextSize() const
diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp
index 85223e7..38fe642 100644
--- a/torch_modules/src/RLTNetwork.cpp
+++ b/torch_modules/src/RLTNetwork.cpp
@@ -79,7 +79,7 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
   return linear2(torch::relu(linear1(representation)));
 }
 
-std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> contextIndexes;
   std::stack<int> leftContext;
@@ -183,6 +183,6 @@ std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) c
       else
         context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
 
-  return context;
+  return {context};
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 2963701..501af8e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -48,20 +48,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
       util::myThrow("No transition appliable !");
     }
 
+    std::vector<std::vector<long>> context;
+
     try
     {
-      auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-      contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
+      context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      for (auto & element : context)
+        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndex;
 
-    classes.emplace_back(gold);
+    for (auto & element : context)
+      classes.emplace_back(gold);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
-- 
GitLab