From 050839edd5b7727159ad8ac8811a1b91452718a4 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 5 Apr 2020 16:14:33 +0200
Subject: [PATCH] Sequential

---
 decoder/src/Decoder.cpp                   |  3 +--
 reading_machine/include/Classifier.hpp    |  2 +-
 reading_machine/include/Config.hpp        |  1 +
 reading_machine/include/SubConfig.hpp     |  7 ++-----
 reading_machine/include/TransitionSet.hpp |  5 +++++
 reading_machine/src/Classifier.cpp        |  4 ++--
 reading_machine/src/Config.cpp            | 24 ++++++++++++++++++++++-
 reading_machine/src/ReadingMachine.cpp    |  8 +++++++-
 reading_machine/src/SubConfig.cpp         |  2 +-
 reading_machine/src/TransitionSet.cpp     | 11 +++++++++++
 torch_modules/src/LSTMNetwork.cpp         | 14 +++++++++----
 trainer/src/MacaonTrain.cpp               |  4 ++--
 trainer/src/Trainer.cpp                   |  6 +-----
 13 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 7d563c8..3a2caef 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -81,8 +81,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       break;
 
     config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+    config.moveWordIndexRelaxed(movement.second);
   }
   } catch(std::exception & e) {util::myThrow(e.what());}
 
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 1131db7..013a097 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -19,7 +19,7 @@ class Classifier
 
   public :
 
-  Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
+  Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles); // tsFiles : file names handed to the TransitionSet constructor
   TransitionSet & getTransitionSet();
   NeuralNetwork & getNN();
   const std::string & getName() const;
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index ea092f0..c9b7d4d 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -108,6 +108,7 @@ class Config
   bool isTokenPredicted(std::size_t lineIndex) const;
   bool moveWordIndex(int relativeMovement);
   bool canMoveWordIndex(int relativeMovement) const;
+  void moveWordIndexRelaxed(int relativeMovement);
   bool moveCharacterIndex(int relativeMovement);
   bool canMoveCharacterIndex(int relativeMovement) const;
   bool rawInputOnlySeparatorsLeft() const;
diff --git a/reading_machine/include/SubConfig.hpp b/reading_machine/include/SubConfig.hpp
index 2f1efd1..05b2541 100644
--- a/reading_machine/include/SubConfig.hpp
+++ b/reading_machine/include/SubConfig.hpp
@@ -8,16 +8,13 @@ class SubConfig : public Config
 {
   private :
 
-  static constexpr std::size_t spanSize = 800;
-
-  private :
-
   const BaseConfig & model;
+  std::size_t spanSize;
   std::size_t firstLineIndex{0};
 
   public :
 
-  SubConfig(BaseConfig & model);
+  SubConfig(BaseConfig & model, std::size_t spanSize);
   bool update();
   bool needsUpdate();
   std::size_t getNbColumns() const override;
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index a1bc2c1..d0c7c1f 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -12,8 +12,13 @@ class TransitionSet
 
   std::vector<Transition> transitions;
 
+  private :
+
+  void addTransitionsFromFile(const std::string & filename);
+
   public :
 
+  TransitionSet(const std::vector<std::string> & filenames);
   TransitionSet(const std::string & filename);
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index a58c3b1..54dccff 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -6,10 +6,10 @@
 #include "LSTMNetwork.hpp"
 #include "RandomNetwork.hpp"
 
-Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
+Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles)
 {
   this->name = name;
-  this->transitionSet.reset(new TransitionSet(tsFile));
+  this->transitionSet.reset(new TransitionSet(tsFiles));
   initNeuralNetwork(topology);
 }
 
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index bb2366c..8681cce 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -431,6 +431,28 @@ bool Config::moveWordIndex(int relativeMovement)
   return true;
 }
 
+// Best-effort counterpart of moveWordIndex : moves wordIndex by up to relativeMovement
+// non-comment lines, stopping (instead of failing) where has() finds nothing further.
+void Config::moveWordIndexRelaxed(int relativeMovement)
+{
+  const int increment = relativeMovement > 0 ? 1 : -1;
+  bool blocked = false; // set once has() rejects the next index : later steps could not move either
+
+  for (int nbMovements = 0; !blocked and nbMovements != relativeMovement; nbMovements += increment)
+    do
+    {
+      blocked = !has(0,wordIndex+increment,0);
+      if (blocked)
+        break;
+      wordIndex += increment;
+    }
+    while (isComment(wordIndex));
+
+  // The walk may have ended on a comment line : step back the other way (result of moveWordIndex is not checked).
+  if (isComment(wordIndex))
+    moveWordIndex(-increment);
+}
+
 bool Config::canMoveWordIndex(int relativeMovement) const
 {
   int nbMovements = 0;
@@ -515,7 +537,7 @@ void Config::setState(const std::string state)
 bool Config::stateIsDone() const
 {
   if (!rawInput.empty())
-    return rawInputOnlySeparatorsLeft();
+    return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0); // with raw input, additionally require the same end test as the branch below
 
   return !has(0, wordIndex+1, 0) and !hasStack(0);
 }
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 96b5e71..9d4a854 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -57,7 +57,13 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
     if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
       util::myThrow("No name specified");
 
-    while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) (.+)"), lines[curLine++], [this,path](auto sm){classifier.reset(new Classifier(sm.str(1), sm.str(2), path.parent_path() / sm.str(3)));}));
+    while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) \\{(.+)\\}"), lines[curLine++], [this,path](auto sm)
+      {
+        std::vector<std::string> tsFiles = util::split(sm.str(3), ' '); // space-separated names captured between the braces
+        for (auto & tsFile : tsFiles)
+          tsFile = (path.parent_path() / tsFile).string(); // resolved against path's parent directory; explicit conversion, path only converts implicitly to its native string_type
+        classifier.reset(new Classifier(sm.str(1), sm.str(2), tsFiles));
+      }));
     if (!classifier.get())
       util::myThrow("No Classifier specified");
 
diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp
index 1b63ed2..c571b3c 100644
--- a/reading_machine/src/SubConfig.cpp
+++ b/reading_machine/src/SubConfig.cpp
@@ -1,6 +1,6 @@
 #include "SubConfig.hpp"
 
-SubConfig::SubConfig(BaseConfig & model) : Config(model.rawInput), model(model)
+SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(model.rawInput), model(model), spanSize(spanSize)
 {
   wordIndex = model.wordIndex;
   characterIndex = model.characterIndex;
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index 5d0df94..a6ed1b0 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -2,6 +2,17 @@
 #include <limits>
 
 TransitionSet::TransitionSet(const std::string & filename)
+{
+  addTransitionsFromFile(filename);
+}
+
+TransitionSet::TransitionSet(const std::vector<std::string> & filenames)
+{
+  // Feed each listed file to addTransitionsFromFile, following the order of the vector.
+  for (const std::string & tsFilename : filenames) addTransitionsFromFile(tsFilename);
+}
+
+void TransitionSet::addTransitionsFromFile(const std::string & filename)
 {
   FILE * file = std::fopen(filename.c_str(), "r");
   if (!file)
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 201feea..7ab2dc2 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -40,7 +40,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
     totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size());
   }
 
-  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
+  linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
 }
 
@@ -51,16 +51,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
 
   auto embeddings = embeddingsDropout(wordEmbeddings(input));
 
-  auto splitTrans = embeddings.narrow(1, 0, Config::maxNbAppliableSplitTransitions);
+  auto state = embeddings.narrow(1, 0, 1).squeeze(1);
 
-  auto context = embeddings.narrow(1, splitTrans.size(1)+rawInputSize, getContextSize());
+  auto splitTrans = embeddings.narrow(1, 1, Config::maxNbAppliableSplitTransitions);
+
+  auto context = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize, getContextSize());
 
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
 
-  auto elementsEmbeddings = embeddings.narrow(1, splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(splitTrans.size(1)+rawInputSize+context.size(1)));
+  auto elementsEmbeddings = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(1+splitTrans.size(1)+rawInputSize+context.size(1)));
 
   std::vector<torch::Tensor> lstmOutputs;
 
+  lstmOutputs.emplace_back(state); // the state embedding is pushed before any other output
+  // NOTE(review): rawLetters below still narrows from splitTrans.size(1); with the state slot now at index 0 it presumably should start at 1+splitTrans.size(1) -- confirm
   if (rawInputSize != 0)
   {
     auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize);
@@ -110,6 +114,8 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   std::vector<std::vector<long>> context;
   context.emplace_back();
 
+  context.back().emplace_back(dict.getIndexOrInsert(config.getState())); // state index goes first, matching embeddings.narrow(1, 0, 1) in forward()
+
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++)
     if (i < (int)splitTransitions.size())
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 7b8e60f..ffce48f 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -101,7 +101,7 @@ int MacaonTrain::main()
 
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
-  SubConfig config(goldConfig);
+  SubConfig config(goldConfig, goldConfig.getNbLines());
 
   fillDicts(machine, goldConfig);
 
@@ -109,7 +109,7 @@ int MacaonTrain::main()
   trainer.createDataset(config, debug);
   if (!computeDevScore)
   {
-    SubConfig devConfig(devGoldConfig);
+    SubConfig devConfig(devGoldConfig, devGoldConfig.getNbLines());
     trainer.createDevDataset(devConfig, debug);
   }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 57e7e51..6459bb3 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -99,11 +99,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
       break;
 
     config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-    {
-      config.printForDebug(stderr);
-      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
-    }
+    config.moveWordIndexRelaxed(movement.second);
 
     if (config.needsUpdate())
       config.update();
-- 
GitLab