From 765b687de555e1271218f7ff5e812eda9e3a17fe Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 12 Feb 2020 16:18:30 +0100 Subject: [PATCH] ReadingMachine now has list of predicted columns --- common/include/Dict.hpp | 6 +-- common/src/Dict.cpp | 6 +-- reading_machine/include/FeatureFunction.hpp | 28 ------------- reading_machine/include/ReadingMachine.hpp | 6 +-- reading_machine/src/FeatureFunction.cpp | 45 --------------------- reading_machine/src/ReadingMachine.cpp | 19 ++++++--- 6 files changed, 23 insertions(+), 87 deletions(-) delete mode 100644 reading_machine/include/FeatureFunction.hpp delete mode 100644 reading_machine/src/FeatureFunction.cpp diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index d87df12..fb005e7 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -36,10 +36,10 @@ class Dict void insert(const std::string & element); int getIndexOrInsert(const std::string & element); void setState(State state); - State getState(); - void save(std::FILE * destination, Encoding encoding); + State getState() const; + void save(std::FILE * destination, Encoding encoding) const; bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding); - void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding); + void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const; }; #endif diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 74eac88..d09f149 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -79,12 +79,12 @@ void Dict::setState(State state) this->state = state; } -Dict::State Dict::getState() +Dict::State Dict::getState() const { return state; } -void Dict::save(std::FILE * destination, Encoding encoding) +void Dict::save(std::FILE * destination, Encoding encoding) const { fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary"); fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size()); @@ -114,7 +114,7 @@ bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encod } } -void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) +void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const { if (encoding == Encoding::Ascii) { diff --git a/reading_machine/include/FeatureFunction.hpp b/reading_machine/include/FeatureFunction.hpp deleted file mode 100644 index ed860b0..0000000 --- a/reading_machine/include/FeatureFunction.hpp +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef FEATUREFUNCTION__H -#define FEATUREFUNCTION__H - -#include <map> -#include <string> -#include "Config.hpp" - -class FeatureFunction -{ - using Representation = std::vector<std::size_t>; - using Feature = std::function<Config::String(const Config &)>; - - private : - - std::map<std::string, Feature> features; - std::map<Config::String, std::size_t> dictionary; - - private : - - const Feature & getOrCreateFeature(const std::string & name); - - public : - - FeatureFunction(const std::vector<std::string_view> & lines); - Representation getRepresentation(const Config & config) const; -}; - -#endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 1c08bd8..3e9eaf5 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -5,7 +5,6 @@ #include <memory> #include "Classifier.hpp" #include "Strategy.hpp" -#include "FeatureFunction.hpp" #include "Dict.hpp" class ReadingMachine @@ -23,8 +22,8 @@ class ReadingMachine std::filesystem::path path; std::unique_ptr<Classifier> classifier; std::unique_ptr<Strategy> strategy; - std::unique_ptr<FeatureFunction> featureFunction; std::map<std::string, Dict> dicts; + std::set<std::string> predicted; private : @@ -38,7 +37,8 @@ class ReadingMachine Strategy & getStrategy(); Dict & getDict(const std::string & state); Classifier * getClassifier(); - void save(); + void save() const; + bool isPredicted(const std::string & columnName) const; }; #endif diff --git a/reading_machine/src/FeatureFunction.cpp b/reading_machine/src/FeatureFunction.cpp deleted file mode 100644 index c516220..0000000 --- a/reading_machine/src/FeatureFunction.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "FeatureFunction.hpp" - -FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines) -{ - if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){})) - util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0])); - - for (unsigned int i = 1; i < lines.size(); i++) - { - if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm) - { - getOrCreateFeature(fmt::format("b.")); - })) - continue; - - util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i])); - } - - for (auto & it : features) - fmt::print("{}\n", it.first); -} - -FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const -{ - Representation representation; - - return representation; -} - -const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name) -{ - auto found = features.find(name); - - if (found != features.end()) - return found->second; - - if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();})) - return features[name]; - - - util::myThrow(fmt::format("Unknown feature '{}'", name)); - - return found->second; -} - diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index c9d5f6d..0a7838c 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -51,11 +51,15 @@ void ReadingMachine::readFromFile(std::filesystem::path path) util::myThrow("No Classifier specified"); --curLine; - //std::vector<std::string_view> restOfFile; - //while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){})) - // restOfFile.emplace_back(lines[curLine++]); - //featureFunction.reset(new FeatureFunction(restOfFile)); + if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm) + { + auto predictions = std::string(sm[1]); + auto splited = util::split(predictions, ' '); + for (auto & prediction : splited) + predicted.insert(std::string(prediction)); + })) + util::myThrow("No predictions specified"); auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end()); @@ -92,7 +96,7 @@ Classifier * ReadingMachine::getClassifier() return classifier.get(); } -void ReadingMachine::save() +void ReadingMachine::save() const { for (auto & it : dicts) { @@ -110,3 +114,8 @@ void ReadingMachine::save() torch::save(classifier->getNN(), pathToClassifier); } +bool ReadingMachine::isPredicted(const std::string & columnName) const +{ + return predicted.count(columnName); +} + -- GitLab