From 765b687de555e1271218f7ff5e812eda9e3a17fe Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 Feb 2020 16:18:30 +0100
Subject: [PATCH] ReadingMachine now has a list of predicted columns

---
 common/include/Dict.hpp                     |  6 +--
 common/src/Dict.cpp                         |  6 +--
 reading_machine/include/FeatureFunction.hpp | 28 -------------
 reading_machine/include/ReadingMachine.hpp  |  6 +--
 reading_machine/src/FeatureFunction.cpp     | 45 ---------------------
 reading_machine/src/ReadingMachine.cpp      | 19 ++++++---
 6 files changed, 23 insertions(+), 87 deletions(-)
 delete mode 100644 reading_machine/include/FeatureFunction.hpp
 delete mode 100644 reading_machine/src/FeatureFunction.cpp

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index d87df12..fb005e7 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -36,10 +36,10 @@ class Dict
   void insert(const std::string & element);
   int getIndexOrInsert(const std::string & element);
   void setState(State state);
-  State getState();
-  void save(std::FILE * destination, Encoding encoding);
+  State getState() const;
+  void save(std::FILE * destination, Encoding encoding) const;
   bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
-  void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding);
+  void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 74eac88..d09f149 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -79,12 +79,12 @@ void Dict::setState(State state)
   this->state = state;
 }
 
-Dict::State Dict::getState()
+Dict::State Dict::getState() const
 {
   return state;
 }
 
-void Dict::save(std::FILE * destination, Encoding encoding)
+void Dict::save(std::FILE * destination, Encoding encoding) const
 {
   fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
   fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
@@ -114,7 +114,7 @@ bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encod
   }
 }
 
-void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding)
+void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
 {
   if (encoding == Encoding::Ascii)
   {
diff --git a/reading_machine/include/FeatureFunction.hpp b/reading_machine/include/FeatureFunction.hpp
deleted file mode 100644
index ed860b0..0000000
--- a/reading_machine/include/FeatureFunction.hpp
+++ /dev/null
@@ -1,28 +0,0 @@
-#ifndef FEATUREFUNCTION__H
-#define FEATUREFUNCTION__H
-
-#include <map>
-#include <string>
-#include "Config.hpp"
-
-class FeatureFunction
-{
-  using Representation = std::vector<std::size_t>;
-  using Feature = std::function<Config::String(const Config &)>;
-
-  private :
-
-  std::map<std::string, Feature> features;
-  std::map<Config::String, std::size_t> dictionary;
-
-  private :
-
-  const Feature & getOrCreateFeature(const std::string & name);
-
-  public :
-
-  FeatureFunction(const std::vector<std::string_view> & lines);
-  Representation getRepresentation(const Config & config) const;
-};
-
-#endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 1c08bd8..3e9eaf5 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -5,7 +5,6 @@
 #include <memory>
 #include "Classifier.hpp"
 #include "Strategy.hpp"
-#include "FeatureFunction.hpp"
 #include "Dict.hpp"
 
 class ReadingMachine
@@ -23,8 +22,8 @@ class ReadingMachine
   std::filesystem::path path;
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
-  std::unique_ptr<FeatureFunction> featureFunction;
   std::map<std::string, Dict> dicts;
+  std::set<std::string> predicted;
 
   private :
 
@@ -38,7 +37,8 @@ class ReadingMachine
   Strategy & getStrategy();
   Dict & getDict(const std::string & state);
   Classifier * getClassifier();
-  void save();
+  void save() const;
+  bool isPredicted(const std::string & columnName) const;
 };
 
 #endif
diff --git a/reading_machine/src/FeatureFunction.cpp b/reading_machine/src/FeatureFunction.cpp
deleted file mode 100644
index c516220..0000000
--- a/reading_machine/src/FeatureFunction.cpp
+++ /dev/null
@@ -1,45 +0,0 @@
-#include "FeatureFunction.hpp"
-
-FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines)
-{
-  if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){}))
-  util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0]));
-
-  for (unsigned int i = 1; i < lines.size(); i++)
-  {
-    if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm)
-    {
-      getOrCreateFeature(fmt::format("b."));
-    }))
-      continue;
-
-    util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i]));
-  }
-
-  for (auto & it : features)
-    fmt::print("{}\n", it.first);
-}
-
-FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const
-{
-  Representation representation;
-
-  return representation;
-}
-
-const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name)
-{
-  auto found = features.find(name);
-
-  if (found != features.end())
-    return found->second;
-
-  if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();}))
-    return features[name];
-
-
-  util::myThrow(fmt::format("Unknown feature '{}'", name));
-
-  return found->second;
-}
-
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index c9d5f6d..0a7838c 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -51,11 +51,15 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
       util::myThrow("No Classifier specified");
 
     --curLine;
-    //std::vector<std::string_view> restOfFile;
-    //while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){}))
-    //  restOfFile.emplace_back(lines[curLine++]);
 
-    //featureFunction.reset(new FeatureFunction(restOfFile));
+    if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
+    {
+      auto predictions = std::string(sm[1]);
+      auto splited = util::split(predictions, ' ');
+      for (auto & prediction : splited)
+        predicted.insert(std::string(prediction));
+    }))
+      util::myThrow("No predictions specified");
 
     auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
 
@@ -92,7 +96,7 @@ Classifier * ReadingMachine::getClassifier()
   return classifier.get();
 }
 
-void ReadingMachine::save()
+void ReadingMachine::save() const
 {
   for (auto & it : dicts)
   {
@@ -110,3 +114,8 @@ void ReadingMachine::save()
   torch::save(classifier->getNN(), pathToClassifier);
 }
 
+// Tells whether columnName belongs to the 'predicted' set of column names.
+bool ReadingMachine::isPredicted(const std::string & columnName) const
+{
+  return predicted.find(columnName) != predicted.end();
+}
-- 
GitLab