From ab0bfc2795226908cc3f6afcc3bdcd945ef1d22d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 17 Feb 2021 09:35:32 +0100
Subject: [PATCH] Working lineByLine in tsv mode

---
 common/include/util.hpp                |  2 +
 common/src/util.cpp                    | 42 ++++++++++++++
 decoder/src/MacaonDecode.cpp           | 41 ++++++++-----
 reading_machine/include/BaseConfig.hpp |  4 +-
 reading_machine/src/BaseConfig.cpp     | 50 +++-------------
 trainer/src/MacaonTrain.cpp            | 80 +++++++++++---------------
 6 files changed, 114 insertions(+), 105 deletions(-)

diff --git a/common/include/util.hpp b/common/include/util.hpp
index 478ff97..bd1a48b 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -54,6 +54,8 @@ std::string getTime();
 long float2long(float f);
 float long2float(long l);
 
+std::vector<std::vector<std::string>> readTSV(std::string_view tsvFilename);
+
 template <typename T>
 bool isEmpty(const std::vector<T> & s)
 {
diff --git a/common/src/util.cpp b/common/src/util.cpp
index d896173..f5717be 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -365,3 +365,45 @@ std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bo
   return res;
 }
 
+std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename)
+{
+  std::vector<std::vector<std::string>> sentences;
+
+  std::FILE * file = std::fopen(tsvFilename.data(), "r");
+
+  if (not file)
+    util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
+
+  char lineBuffer[100000];
+  bool inputHasBeenRead = false;
+
+  sentences.emplace_back();
+  while (!std::feof(file))
+  {
+    if (lineBuffer != std::fgets(lineBuffer, 100000, file))
+      break;
+
+    std::string_view line(lineBuffer);
+    sentences.back().emplace_back(line);
+
+    if (line.size() < 3)
+    {
+      if (!inputHasBeenRead)
+        continue;
+
+      sentences.emplace_back();
+      continue;
+    }
+
+    inputHasBeenRead = true;
+  }
+
+  if (sentences.back().empty())
+    sentences.pop_back();
+
+  std::fclose(file);
+  
+  return sentences;
+}
+
+
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index e4194e8..2e4ed9f 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -96,25 +96,38 @@ int MacaonDecode::main()
     ReadingMachine machine(machinePath, false);
     Decoder decoder(machine);
 
-    if (inputTXT.empty())
+    std::vector<util::utf8string> rawInputs;
+    if (!inputTXT.empty())
+      rawInputs = util::readFileAsUtf8(inputTXT, lineByLine);
+
+    std::vector<std::vector<std::string>> tsv;
+    if (!inputTSV.empty())
+      tsv = util::readTSV(inputTSV);
+
+    std::vector<BaseConfig> configs;
+    if (lineByLine)
     {
-      BaseConfig config(mcd, inputTSV, util::utf8string(), std::vector<int>());
-      decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
-      config.print(stdout, true);
+      if (rawInputs.size())
+        for (unsigned int i = 0; i < rawInputs.size(); i++)
+          configs.emplace_back(mcd, tsv, rawInputs[i], std::vector<int>{(int)i});
+      else
+        for (unsigned int i = 0; i < tsv.size(); i++)
+          configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>{(int)i});
     }
     else
     {
-      auto inputs = util::readFileAsUtf8(inputTXT, lineByLine);
-      for (int i = 0; i < (int)inputs.size(); i++)
-      {
-        std::vector<int> sentIndexes;
-        if (inputs.size() > 1)
-          sentIndexes.emplace_back(i);
-        BaseConfig config(mcd, inputTSV, inputs[i], sentIndexes);
-        decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
-        config.print(stdout, i == 0);
-      }
+      if (rawInputs.size())
+        configs.emplace_back(mcd, tsv, rawInputs[0], std::vector<int>());
+      else
+        configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>());
+    }
+
+    for (unsigned int i = 0; i < configs.size(); i++)
+    {
+      decoder.decode(configs[i], beamSize, beamThreshold, debug, printAdvancement);
+      configs[i].print(stdout, i == 0);
     }
+
   } catch(std::exception & e) {util::error(e);}
 
   return 0;
diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp
index a93b180..167b8fe 100644
--- a/reading_machine/include/BaseConfig.hpp
+++ b/reading_machine/include/BaseConfig.hpp
@@ -21,11 +21,11 @@ class BaseConfig : public Config
 
   void createColumns(std::string mcd);
   void readRawInput(std::string_view rawFilename);
-  void readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes);
+  void readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes);
 
   public :
 
-  BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
+  BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
   BaseConfig(const BaseConfig & other);
   BaseConfig & operator=(const BaseConfig & other) = default;
 
diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index 69905a8..0d373c6 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -43,44 +43,8 @@ void BaseConfig::readRawInput(std::string_view rawFilename)
   rawInput.replace(util::utf8char("\t"), util::utf8char(" "));
 }
 
-void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes)
+void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes)
 {
-  std::vector<std::vector<std::string>> sentences;
-
-  std::FILE * file = std::fopen(tsvFilename.data(), "r");
-
-  if (not file)
-    util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
-
-  char lineBuffer[100000];
-  bool inputHasBeenRead = false;
-
-  sentences.emplace_back();
-  while (!std::feof(file))
-  {
-    if (lineBuffer != std::fgets(lineBuffer, 100000, file))
-      break;
-
-    std::string_view line(lineBuffer);
-    sentences.back().emplace_back(line);
-
-    if (line.size() < 3)
-    {
-      if (!inputHasBeenRead)
-        continue;
-
-      sentences.emplace_back();
-      continue;
-    }
-
-    inputHasBeenRead = true;
-  }
-
-  if (sentences.back().empty())
-    sentences.pop_back();
-
-  std::fclose(file);
-
   for (unsigned int i = 0; i < (sentencesIndexes.size() ? sentencesIndexes.size() : sentences.size()); i++)
   {
     int targetSentenceIndex = sentencesIndexes.size() ? sentencesIndexes[i] : i;
@@ -115,7 +79,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<in
       if (usualNbCol == -1)
         usualNbCol = splited.size();
       if ((int)splited.size() != usualNbCol)
-        util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol));
+        util::myThrow(fmt::format("in tsv file line {} is invalid, it shoud have {} columns", line, usualNbCol));
 
       // Ignore empty nodes
       if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos)
@@ -182,10 +146,10 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
 {
 }
 
-BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
+BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
 {
-  if (tsvFilename.empty() and rawInput.empty())
-    util::myThrow("tsvFilename and rawInput can't be both empty");
+  if (sentences.empty() and rawInput.empty())
+    util::myThrow("sentences and rawInput can't be both empty");
 
   createColumns(mcd);
 
@@ -193,8 +157,8 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util
     this->rawInput = rawInput;
 
 // sentencesIndexes = index of sentences to keep. Empty vector == keep all sentences.
-  if (not tsvFilename.empty())
-    readTSVInput(tsvFilename, sentencesIndexes);
+  if (not sentences.empty())
+    readTSVInput(sentences, sentencesIndexes);
 
   if (!has(0,wordIndex,0))
     addLines(1);
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 91d0ea8..02d58fd 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -117,29 +117,6 @@ inline auto decay(Optimizer &optimizer, double rate) -> void
   }
 }
 
-int countNbSentences(std::string_view filename)
-{
-  std::FILE * file = std::fopen(filename.data(), "r");
-
-  if (not file)
-    util::myThrow(fmt::format("Cannot open file '{}'", filename));
-
-  char lineBuffer[100000];
-  int nbSent = 0;
-
-  while (!std::feof(file))
-  {
-    if (lineBuffer != std::fgets(lineBuffer, 100000, file))
-      break;
-    if (std::string(lineBuffer).size() < 3)
-      nbSent++;
-  }
-
-  std::fclose(file);
-
-  return nbSent;
-}
-
 int MacaonTrain::main()
 {
   auto od = getOptionsDescription();
@@ -188,9 +165,13 @@ int MacaonTrain::main()
   try
   {
 
-  ReadingMachine machine(machinePath.string(), true);
+  std::vector<std::vector<std::string>> trainTsv, devTsv;
+  if (!trainTsvFile.empty())
+    trainTsv = util::readTSV(trainTsvFile);
+  if (!devTsvFile.empty())
+    devTsv = util::readTSV(devTsvFile);
 
-  int nbSentencesTrain = countNbSentences(trainTsvFile);
+  ReadingMachine machine(machinePath.string(), true);
 
   std::vector<util::utf8string> trainRawInputs;
   if (!trainRawFile.empty())
@@ -199,23 +180,20 @@ int MacaonTrain::main()
   if (lineByLine)
   {
     if (trainRawInputs.size())
-      for (int i = 0; i < (int)trainRawInputs.size(); i++)
-        goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i});
+      for (unsigned int i = 0; i < trainRawInputs.size(); i++)
+        goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[i], std::vector<int>{(int)i});
     else
-      for (int i = 0; i < nbSentencesTrain; i++)
-        goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>{i});
+      for (unsigned int i = 0; i < trainTsv.size(); i++)
+        goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>{(int)i});
   }
   else
   {
     if (trainRawInputs.size())
-      goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>());
+      goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[0], std::vector<int>());
     else
-      goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>());
+      goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>());
   }
 
-
-  int nbSentencesDev = countNbSentences(devTsvFile);
-
   std::vector<util::utf8string> devRawInputs;
   if (!devRawFile.empty())
     devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine);
@@ -223,18 +201,18 @@ int MacaonTrain::main()
   if (lineByLine)
   {
     if (devRawInputs.size())
-      for (int i = 0; i < (int)devRawInputs.size(); i++)
-        devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i});
+      for (unsigned int i = 0; i < devRawInputs.size(); i++)
+        devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
     else
-      for (int i = 0; i < nbSentencesDev; i++)
-        devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i});
+      for (unsigned int i = 0; i < devTsv.size(); i++)
+        devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
   }
   else
   {
     if (devRawInputs.size())
-      devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>());
+      devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
     else
-      devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>());
+      devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
   }
     
   Trainer trainer(machine, batchSize);
@@ -322,13 +300,23 @@ int MacaonTrain::main()
     if (computeDevScore)
     {
       std::vector<BaseConfig> devConfigs;
-      for (int i = 0; i < (int)devRawInputs.size(); i++)
-        if (devRawInputs.size() == 1)
-          devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>());
+      if (lineByLine)
+      {
+        if (devRawInputs.size())
+          for (unsigned int i = 0; i < devRawInputs.size(); i++)
+            devConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
+        else
+          for (unsigned int i = 0; i < devTsv.size(); i++)
+            devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
+      }
+      else
+      {
+        if (devRawInputs.size())
+          devConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
         else
-          devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i});
-      if (devConfigs.empty())
-        devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>());
+          devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
+      }
+
       for (auto & devConfig : devConfigs)
         decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
 
-- 
GitLab