From bc2ede62673fecb114423419af163c225283a337 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 Feb 2020 15:13:41 +0100
Subject: [PATCH] macaon_decode is working

---
 common/include/util.hpp                    |  3 ++
 common/src/Dict.cpp                        | 10 ++---
 common/src/util.cpp                        | 15 +++++++
 decoder/src/Decoder.cpp                    |  3 ++
 decoder/src/macaon_decode.cpp              | 24 ++++++++---
 reading_machine/include/Classifier.hpp     |  1 +
 reading_machine/include/ReadingMachine.hpp | 14 +++++--
 reading_machine/src/Classifier.cpp         |  5 +++
 reading_machine/src/ReadingMachine.cpp     | 47 ++++++++++++++++++----
 trainer/src/macaon_train.cpp               | 11 ++++-
 10 files changed, 109 insertions(+), 24 deletions(-)

diff --git a/common/include/util.hpp b/common/include/util.hpp
index 6b2d077..efe509a 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -19,6 +19,7 @@
 #include <array>
 #include <unordered_map>
 #include <regex>
+#include <filesystem>
 #include <experimental/source_location>
 #include <boost/flyweight.hpp>
 #include "fmt/core.h"
@@ -33,6 +34,8 @@ void error(std::string_view message, const std::experimental::source_location &
 void error(const std::exception & e, const std::experimental::source_location & location = std::experimental::source_location::current());
 void myThrow(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
 
+std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension);
+
 std::string_view getFilenameFromPath(std::string_view s);
 
 std::vector<std::string_view> split(std::string_view s, char delimiter);
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index a02edf3..74eac88 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -19,11 +19,11 @@ void Dict::readFromFile(const char * filename)
   std::FILE * file = std::fopen(filename, "r");
 
   if (!file)
-    util::myThrow(fmt::format("could not open file \'%s\'", filename));
+    util::myThrow(fmt::format("could not open file '{}'", filename));
 
   char buffer[1048];
   if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
-    util::myThrow(fmt::format("file \'%s\' bad format", filename));
+    util::myThrow(fmt::format("file '{}' bad format", filename));
 
   Encoding encoding{Encoding::Ascii};
   if (std::string(buffer) == "Ascii")
@@ -31,12 +31,12 @@ void Dict::readFromFile(const char * filename)
   else if (std::string(buffer) == "Binary")
     encoding = Encoding::Binary;
   else
-    util::myThrow(fmt::format("file \'%s\' bad format", filename));
+    util::myThrow(fmt::format("file '{}' bad format", filename));
 
   int nbEntries;
 
   if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
-    util::myThrow(fmt::format("file \'%s\' bad format", filename));
+    util::myThrow(fmt::format("file '{}' bad format", filename));
 
   elementsToIndexes.reserve(nbEntries);
 
@@ -45,7 +45,7 @@ void Dict::readFromFile(const char * filename)
   for (int i = 0; i < nbEntries; i++)
   {
     if (!readEntry(file, &entryIndex, entryString, encoding))
-      util::myThrow(fmt::format("file \'%s\' line {} bad format", filename, i));
+      util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
     elementsToIndexes[entryString] = entryIndex;
   }
diff --git a/common/src/util.cpp b/common/src/util.cpp
index 379056f..13d4122 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -171,3 +171,18 @@ std::string util::strip(const std::string & s)
   return std::string(s.begin()+first, s.begin()+last+1);
 }
 
+// Lists the regular files located directly in 'directory' (no recursion) whose extension equals 'extension'.
+// 'extension' is compared against path::extension(), so it must include the leading dot (e.g. ".dict").
+// Throws std::filesystem::filesystem_error if 'directory' cannot be iterated.
+std::vector<std::filesystem::path> util::findFilesByExtension(std::filesystem::path directory, std::string extension)
+{
+  std::vector<std::filesystem::path> files;
+
+  // Bind each entry by const reference: 'auto entry' would copy a directory_entry (and its path) per iteration.
+  for (const auto & entry : std::filesystem::directory_iterator(directory))
+    if (entry.is_regular_file() && entry.path().extension() == extension)
+      files.push_back(entry.path());
+
+  return files;
+}
+
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index eea1ef7..664e66e 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 
 void Decoder::decode(BaseConfig & config, std::size_t beamSize)
 {
+  try
+  {
   config.setState(machine.getStrategy().getInitialState());
 
   fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
@@ -42,6 +44,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
     if (!config.moveWordIndex(movement.second))
       util::myThrow("Cannot move word index !");
   }
+  } catch(std::exception & e) {util::myThrow(e.what());}
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp
index c3c3cbf..8897d40 100644
--- a/decoder/src/macaon_decode.cpp
+++ b/decoder/src/macaon_decode.cpp
@@ -64,19 +64,31 @@ int main(int argc, char * argv[])
   auto variables = checkOptions(od, argc, argv);
 
   std::filesystem::path modelPath(variables["model"].as<std::string>());
-  auto machinePath = modelPath / ReadingMachine::defaultMachineName;
+  auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
+  auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
+  auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
   auto mcdFile = variables["mcd"].as<std::string>();
 
-  ReadingMachine machine(machinePath.string());
-  Decoder decoder(machine);
+  if (dictPaths.empty())
+    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
+  if (modelPaths.empty())
+    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
 
-  BaseConfig config(mcdFile, inputTSV, inputTXT);
+  try
+  {
+    ReadingMachine machine(machinePath, modelPaths, dictPaths);
+    Decoder decoder(machine);
+
+    BaseConfig config(mcdFile, inputTSV, inputTXT);
 
-  decoder.decode(config, 1);
+    decoder.decode(config, 1);
 
-  config.print(stdout);
+    fmt::print(stderr, "\n");
+  
+    config.print(stdout);
+  } catch(std::exception & e) {util::error(e);}
 
   return 0;
 }
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 0e8b120..35f0611 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -18,6 +18,7 @@ class Classifier
   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
   TransitionSet & getTransitionSet();
   TestNetwork & getNN();
+  const std::string & getName() const;
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 8db2de1..1c08bd8 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -12,9 +12,10 @@ class ReadingMachine
 {
   public :
 
-  static inline const std::string defaultMachineName = "machine.rm";
-  static inline const std::string defaultModelName = "{}.pt";
-  static inline const std::string defaultDictName = "{}.dict";
+  static inline const std::string defaultMachineFilename = "machine.rm";
+  static inline const std::string defaultModelFilename = "{}.pt";
+  static inline const std::string defaultDictFilename = "{}.dict";
+  static inline const std::string defaultDictName = "_default_";
 
   private :
 
@@ -25,14 +26,19 @@ class ReadingMachine
   std::unique_ptr<FeatureFunction> featureFunction;
   std::map<std::string, Dict> dicts;
 
+  private :
+
+  void readFromFile(std::filesystem::path path);
+
   public :
 
   ReadingMachine(std::filesystem::path path);
-  ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts);
+  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
   TransitionSet & getTransitionSet();
   Strategy & getStrategy();
   Dict & getDict(const std::string & state);
   Classifier * getClassifier();
+  void save();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 34c5f60..13e25b4 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -17,3 +17,8 @@ TestNetwork & Classifier::getNN()
   return nn;
 }
 
+const std::string & Classifier::getName() const
+{
+  return name; // Returned by reference, not copied: callers must not keep it beyond this Classifier's lifetime.
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index d1d9da6..c9d5f6d 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -3,8 +3,23 @@
 
 ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
 {
-  dicts.emplace(std::make_pair("", Dict::State::Open));
+  dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
 
+  readFromFile(path);
+}
+
+ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) : path(path)
+{
+  // Rebuilds a trained machine from 'path' (also kept as member, save() needs it), 'dicts' (Closed, keyed by stem) and models[0].
+  if (models.empty()) util::myThrow(fmt::format("no model file given for machine '{}'", path.string()));
+  readFromFile(path);
+  for (const auto & dictPath : dicts)
+    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Closed});
+  torch::load(classifier->getNN(), models[0]);
+}
+
+void ReadingMachine::readFromFile(std::filesystem::path path)
+{
   std::FILE * file = std::fopen(path.c_str(), "r");
 
   char buffer[1024];
@@ -49,11 +64,6 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
   } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
 }
 
-ReadingMachine::ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts)
-{
-
-}
-
 TransitionSet & ReadingMachine::getTransitionSet()
 {
   return classifier->getTransitionSet();
@@ -68,8 +78,11 @@ Dict & ReadingMachine::getDict(const std::string & state)
 {
   auto found = dicts.find(state);
 
-  if (found == dicts.end())
-    return dicts.at("");
+  try
+  {
+    if (found == dicts.end())
+      return dicts.at(defaultDictName);
+  } catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));}
 
   return found->second;
 }
@@ -79,3 +92,21 @@ Classifier * ReadingMachine::getClassifier()
   return classifier.get();
 }
 
+// Writes each dict (Ascii encoding) and the classifier's network into the directory that contains 'path'.
+void ReadingMachine::save()
+{
+  for (auto & [dictName, dict] : dicts)
+  {
+    auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, dictName);
+    // Owning the handle through unique_ptr closes it on every exit path, including a throw from Dict::save.
+    std::unique_ptr<std::FILE, decltype(&std::fclose)> file(std::fopen(pathToDict.c_str(), "w"), &std::fclose);
+    if (!file)
+      util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
+
+    dict.save(file.get(), Dict::Encoding::Ascii);
+  }
+
+  auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
+  torch::save(classifier->getNN(), pathToClassifier);
+}
+
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 896b9ae..d58127d 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -82,13 +82,22 @@ int main(int argc, char * argv[])
   Decoder decoder(machine);
   BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
 
+  float bestDevScore = 0;
+
   for (int i = 0; i < nbEpoch; i++)
   {
     float loss = trainer.epoch();
     auto devConfig = devGoldConfig;
     decoder.decode(devConfig, 1);
     decoder.evaluate(devConfig, modelPath, devTsvFile);
-    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
+    float devScore = decoder.getF1Score("UPOS");
+    bool saved = devScore > bestDevScore;
+    if (saved)
+    {
+      bestDevScore = devScore;
+      machine.save();
+    }
+    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
   }
 
   return 0;
-- 
GitLab