From 3489e3885fd1ceb262157040d399650d1000af68 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 Feb 2020 20:50:45 +0100
Subject: [PATCH] Config is now aware of what is predicted

---
 decoder/src/Decoder.cpp                    |  2 +
 reading_machine/include/Config.hpp         |  8 +++-
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/Config.cpp             | 52 +++++++++++++++++++++-
 reading_machine/src/ReadingMachine.cpp     |  5 +++
 trainer/src/Trainer.cpp                    |  1 +
 6 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 8e74076..543dbbf 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 
 void Decoder::decode(BaseConfig & config, std::size_t beamSize)
 {
+  config.addPredicted(machine.getPredicted());
+
   try
   {
   config.setState(machine.getStrategy().getInitialState());
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 5f34b24..82b3444 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -30,6 +30,7 @@ class Config
   private :
 
   std::vector<String> lines;
+  std::set<std::string> predicted; // Column names registered through addPredicted() and looked up by isPredicted().
 
   protected :
 
@@ -61,6 +62,8 @@ class Config
   String & get(int colIndex, int lineIndex, int hypothesisIndex);
   const String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const;
   String & getLastNotEmpty(int colIndex, int lineIndex);
+  String & getLastNotEmptyHyp(int colIndex, int lineIndex); // Last non-empty entry among offsets nbHypothesesMax..1 of the cell; the offset-1 entry when all are empty.
+  const String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const; // Read-only counterpart of getLastNotEmptyHyp.
   const String & getLastNotEmptyConst(int colIndex, int lineIndex) const;
   ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
   ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
@@ -75,6 +78,8 @@ class Config
   const String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const;
   String & getLastNotEmpty(const std::string & colName, int lineIndex);
   const String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const;
+  String & getLastNotEmptyHyp(const std::string & colName, int lineIndex); // Same lookup, with the column given by name (resolved through getColIndex).
+  const String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const; // Read-only, name-based variant of the same lookup.
   String & getFirstEmpty(int colIndex, int lineIndex);
   String & getFirstEmpty(const std::string & colName, int lineIndex);
   bool hasCharacter(int letterIndex) const;
@@ -100,7 +105,8 @@ class Config
   void setState(const std::string state);
   bool stateIsDone() const;
   std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
-
+  void addPredicted(const std::set<std::string> & predicted); // Inserts every given column name into this configuration's predicted set.
+  bool isPredicted(const std::string & colName) const; // True when colName is present in the predicted set.
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 3e9eaf5..11058e9 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -39,6 +39,7 @@ class ReadingMachine
   Classifier * getClassifier();
   void save() const;
   bool isPredicted(const std::string & columnName) const;
+  const std::set<std::string> & getPredicted() const; // Read-only reference to the set that isPredicted() looks column names up in.
 };
 
 #endif
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 2f6e7cf..bd640ff 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -70,7 +70,10 @@ void Config::print(FILE * dest) const
       continue;
     }
     for (unsigned int i = 0; i < getNbColumns()-1; i++)
-      fmt::print(dest, "{}{}", getLastNotEmptyConst(i, getFirstLineIndex()+line), i < getNbColumns()-2 ? "\t" : "\n");
+    {
+      auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, getFirstLineIndex()+line) : getLastNotEmptyConst(i, getFirstLineIndex()+line);
+      fmt::print(dest, "{}{}", colContent, i < getNbColumns()-2 ? "\t" : "\n");
+    }
     if (getLastNotEmptyConst(EOSColName, getFirstLineIndex()+line) == EOSSymbol1)
       fmt::print(dest, "\n");
   }
@@ -105,7 +108,10 @@ void Config::printForDebug(FILE * dest) const
     toPrint.emplace_back();
     toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : "");
     for (unsigned int i = 0; i < getNbColumns(); i++)
-      toPrint.back().emplace_back(util::shrink(getLastNotEmptyConst(i, line), maxWordLength));
+    { // Predicted columns are read through getLastNotEmptyHypConst, the others through getLastNotEmptyConst; both take `line` unchanged, exactly like the call this loop body replaces.
+      auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, line) : getLastNotEmptyConst(i, line);
+      toPrint.back().emplace_back(util::shrink(colContent, maxWordLength));
+    }
   }
 
   std::vector<std::size_t> colLength(toPrint[0].size(), 0);
@@ -167,6 +173,17 @@ Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex)
   return lines[baseIndex];
 }
 
+Config::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex)
+{
+  // Walk offsets nbHypothesesMax..1 of the cell and stop at the first non-empty entry; offset 0 is never examined.
+  const int cellStart = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
+  int hyp = nbHypothesesMax;
+  while (hyp > 0 && util::isEmpty(lines[cellStart+hyp]))
+    --hyp;
+  // All of those entries are empty: hand back the entry at offset 1.
+  return lines[cellStart + (hyp > 0 ? hyp : 1)];
+}
+
 Config::String & Config::getFirstEmpty(int colIndex, int lineIndex)
 {
   int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
@@ -194,16 +211,37 @@ const Config::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex)
   return lines[baseIndex];
 }
 
+const Config::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const
+{
+  // Read-only lookup: walk offsets nbHypothesesMax..1 of the cell and stop at the first non-empty entry.
+  const int cellStart = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
+  int hyp = nbHypothesesMax;
+  while (hyp > 0 && util::isEmpty(lines[cellStart+hyp]))
+    --hyp;
+  // All of those entries are empty: hand back the entry at offset 1.
+  return lines[cellStart + (hyp > 0 ? hyp : 1)];
+}
+
 Config::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex)
 {
   return getLastNotEmpty(getColIndex(colName), lineIndex);
 }
 
+Config::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex)
+{
+  return getLastNotEmptyHyp(getColIndex(colName), lineIndex); // Resolve the column name to its index, then delegate to the index-based overload.
+}
+
 const Config::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const
 {
   return getLastNotEmptyConst(getColIndex(colName), lineIndex);
 }
 
+const Config::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const
+{
+  return getLastNotEmptyHypConst(getColIndex(colName), lineIndex); // Resolve the column name to its index, then delegate to the index-based overload.
+}
+
 Config::ValueIterator Config::getIterator(int colIndex, int lineIndex, int hypothesisIndex)
 {
   return lines.begin() + getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex) + hypothesisIndex;
@@ -393,3 +431,13 @@ std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict &
   return context;
 }
 
+void Config::addPredicted(const std::set<std::string> & predicted)
+{
+  this->predicted.insert(predicted.begin(), predicted.end()); // The parameter shadows the member, hence this->; the given names are merged into the existing set.
+}
+
+bool Config::isPredicted(const std::string & colName) const
+{
+  return predicted.find(colName) != predicted.end(); // true exactly when colName is in the predicted set
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 0a7838c..2b5cb61 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -119,3 +119,8 @@ bool ReadingMachine::isPredicted(const std::string & columnName) const
   return predicted.count(columnName);
 }
 
+const std::set<std::string> & ReadingMachine::getPredicted() const
+{
+  return predicted; // Returned by const reference: no copy is made and callers cannot modify the set through it.
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 51d1a3b..6496aa3 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 
 void Trainer::createDataset(SubConfig & config)
 {
+  config.addPredicted(machine.getPredicted());
   config.setState(machine.getStrategy().getInitialState());
 
   std::vector<torch::Tensor> contexts;
-- 
GitLab