From 28045459af02c15661908a8588fc0eff576ff17b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 17 Apr 2020 19:35:13 +0200
Subject: [PATCH] split unknown only when extracting train dataset

---
 decoder/src/Decoder.cpp                           |  1 +
 reading_machine/include/ReadingMachine.hpp        |  1 +
 reading_machine/src/ReadingMachine.cpp            |  5 +++++
 torch_modules/include/ContextLSTM.hpp             |  2 +-
 torch_modules/include/DepthLayerTreeEmbedding.hpp |  2 +-
 torch_modules/include/FocusedColumnLSTM.hpp       |  2 +-
 torch_modules/include/NeuralNetwork.hpp           |  6 ++++++
 torch_modules/include/RawInputLSTM.hpp            |  2 +-
 torch_modules/include/SplitTransLSTM.hpp          |  2 +-
 torch_modules/include/Submodule.hpp               |  2 +-
 torch_modules/src/ContextLSTM.cpp                 |  7 +++++--
 torch_modules/src/DepthLayerTreeEmbedding.cpp     |  2 +-
 torch_modules/src/FocusedColumnLSTM.cpp           |  2 +-
 torch_modules/src/LSTMNetwork.cpp                 | 14 +++++++-------
 torch_modules/src/NeuralNetwork.cpp               | 10 ++++++++++
 torch_modules/src/RawInputLSTM.cpp                |  2 +-
 torch_modules/src/SplitTransLSTM.cpp              |  2 +-
 trainer/src/Trainer.cpp                           |  2 ++
 18 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 9b6b3a6..33c4837 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 9eb09d0..5f3ff1c 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,6 +47,7 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void splitUnknown(bool splitUnknown);
   void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 38f79c8..138c249 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
   classifier->getNN()->train(isTrainMode);
 }
 
+void ReadingMachine::splitUnknown(bool splitUnknown)
+{
+  classifier->getNN()->setSplitUnknown(splitUnknown);
+}
+
 void ReadingMachine::setDictsState(Dict::State state)
 {
   for (auto & it : dicts)
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp
index 136029c..3e3bbac 100644
--- a/torch_modules/include/ContextLSTM.hpp
+++ b/torch_modules/include/ContextLSTM.hpp
@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(ContextLSTM);
 
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
index 436a082..2a8f7e8 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(DepthLayerTreeEmbedding);
 
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp
index 6ea836a..fd5d915 100644
--- a/torch_modules/include/FocusedColumnLSTM.hpp
+++ b/torch_modules/include/FocusedColumnLSTM.hpp
@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(FocusedColumnLSTM);
 
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index be25c87..1237f09 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   static torch::Device device;
 
+  private :
+
+  bool splitUnknown{false};
+
   protected : 
 
   static constexpr int maxNbEmbeddings = 150000;
@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
+  bool mustSplitUnknown() const;
+  void setSplitUnknown(bool splitUnknown);
 };
 TORCH_MODULE(NeuralNetwork);
 
diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp
index db17d6f..0e08560 100644
--- a/torch_modules/include/RawInputLSTM.hpp
+++ b/torch_modules/include/RawInputLSTM.hpp
@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(RawInputLSTM);
 
diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp
index f90c0ed..85d542c 100644
--- a/torch_modules/include/SplitTransLSTM.hpp
+++ b/torch_modules/include/SplitTransLSTM.hpp
@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(SplitTransLSTM);
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 437bbfa..cc38101 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -15,7 +15,7 @@ class Submodule
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
 };
 
 #endif
diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp
index 95daa69..5da02e7 100644
--- a/torch_modules/src/ContextLSTM.cpp
+++ b/torch_modules/src/ContextLSTM.cpp
@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
 
-void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
 {
   std::vector<long> contextIndexes;
 
@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
+      {
         for (auto & contextElement : context)
           contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
@@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
 
-        if (is_training())
+
+        if (splitUnknown)
           for (auto & targetCol : unknownValueColumns)
             if (col == targetCol)
               if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index 6e1342a..b506f92 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp
index 4e0da0e..e39af63 100644
--- a/torch_modules/src/FocusedColumnLSTM.cpp
+++ b/torch_modules/src/FocusedColumnLSTM.cpp
@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index cfa004e..a4f5863 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
 
-  contextLSTM->addToContext(context, dict, config);
+  contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!rawInputLSTM.is_empty())
-    rawInputLSTM->addToContext(context, dict, config);
+    rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   if (!treeEmbedding.is_empty())
-    treeEmbedding->addToContext(context, dict, config);
+    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
 
-  splitTransLSTM->addToContext(context, dict, config);
+  splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
 
   for (auto & lstm : focusedLstms)
-    lstm->addToContext(context, dict, config);
+    lstm->addToContext(context, dict, config, mustSplitUnknown());
 
-  if (!is_training() && context.size() > 1)
-    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+  if (!mustSplitUnknown() && context.size() > 1)
+    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
 
   return context;
 }
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 02e8a19..235c677 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,3 +2,13 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
+bool NeuralNetworkImpl::mustSplitUnknown() const
+{
+  return splitUnknown;
+}
+
+void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
+{
+  this->splitUnknown = splitUnknown;
+}
+
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp
index 2aa8cfd..c6da426 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputLSTM.cpp
@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
index a83894a..99a1b35 100644
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
   return maxNbTrans;
 }
 
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 03b7f88..95e98eb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(true);
   machine.setDictsState(Dict::State::Open);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
 
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
-- 
GitLab