Skip to content
Snippets Groups Projects
Commit 765b687d authored by Franck Dary's avatar Franck Dary
Browse files

ReadingMachine now has list of predicted columns

parent bc2ede62
No related branches found
No related tags found
No related merge requests found
...@@ -36,10 +36,10 @@ class Dict ...@@ -36,10 +36,10 @@ class Dict
void insert(const std::string & element); void insert(const std::string & element);
int getIndexOrInsert(const std::string & element); int getIndexOrInsert(const std::string & element);
void setState(State state); void setState(State state);
State getState(); State getState() const;
void save(std::FILE * destination, Encoding encoding); void save(std::FILE * destination, Encoding encoding) const;
bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding); bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding); void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
}; };
#endif #endif
...@@ -79,12 +79,12 @@ void Dict::setState(State state) ...@@ -79,12 +79,12 @@ void Dict::setState(State state)
this->state = state; this->state = state;
} }
Dict::State Dict::getState() Dict::State Dict::getState() const
{ {
return state; return state;
} }
void Dict::save(std::FILE * destination, Encoding encoding) void Dict::save(std::FILE * destination, Encoding encoding) const
{ {
fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary"); fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size()); fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
...@@ -114,7 +114,7 @@ bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encod ...@@ -114,7 +114,7 @@ bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encod
} }
} }
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{ {
if (encoding == Encoding::Ascii) if (encoding == Encoding::Ascii)
{ {
......
#ifndef FEATUREFUNCTION__H
#define FEATUREFUNCTION__H
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <string_view>
#include <vector>
#include "Config.hpp"
// Holds a set of named features and produces a numeric representation of a Config.
class FeatureFunction
{
  public :

  // These aliases are public because Representation is the return type of the
  // public getRepresentation(); leaving them in the implicit private section
  // prevented callers from naming that type.
  // Representation : a sequence of indexes.
  using Representation = std::vector<std::size_t>;
  // Feature : a callable taking a Config and returning a Config::String.
  using Feature = std::function<Config::String(const Config &)>;

  private :

  // Registered features, keyed by feature name.
  std::map<std::string, Feature> features;
  // Maps a Config::String to an index.
  // NOTE(review): presumably feature values to representation indexes — confirm.
  std::map<Config::String, std::size_t> dictionary;

  private :

  // Returns the feature registered under 'name', creating it if necessary.
  const Feature & getOrCreateFeature(const std::string & name);

  public :

  // Builds the feature set from the lines of a textual description.
  FeatureFunction(const std::vector<std::string_view> & lines);
  // Computes the representation of the given configuration.
  Representation getRepresentation(const Config & config) const;
};
#endif
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <memory> #include <memory>
#include "Classifier.hpp" #include "Classifier.hpp"
#include "Strategy.hpp" #include "Strategy.hpp"
#include "FeatureFunction.hpp"
#include "Dict.hpp" #include "Dict.hpp"
class ReadingMachine class ReadingMachine
...@@ -23,8 +22,8 @@ class ReadingMachine ...@@ -23,8 +22,8 @@ class ReadingMachine
std::filesystem::path path; std::filesystem::path path;
std::unique_ptr<Classifier> classifier; std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy; std::unique_ptr<Strategy> strategy;
std::unique_ptr<FeatureFunction> featureFunction;
std::map<std::string, Dict> dicts; std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
private : private :
...@@ -38,7 +37,8 @@ class ReadingMachine ...@@ -38,7 +37,8 @@ class ReadingMachine
Strategy & getStrategy(); Strategy & getStrategy();
Dict & getDict(const std::string & state); Dict & getDict(const std::string & state);
Classifier * getClassifier(); Classifier * getClassifier();
void save(); void save() const;
bool isPredicted(const std::string & columnName) const;
}; };
#endif #endif
#include "FeatureFunction.hpp"
// Builds the feature set from a textual description.
// lines[0] must match "Features :(.*)". Every following line must be a known
// feature directive; the only one recognised here is
// "buffer from <int> to <int>". Failures are reported through util::myThrow.
// Once parsing is done, the name of every registered feature is printed.
FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines)
{
  // Guard the lines[0] access below: indexing an empty vector is undefined behaviour.
  if (lines.empty())
    util::myThrow("Empty feature description, expected 'Features :'");

  // Compile each pattern once instead of once per processed line.
  const std::regex headerPattern("Features :(.*)");
  const std::regex bufferPattern("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)");

  if (!util::doIfNameMatch(headerPattern, lines[0], [](auto){}))
    util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0]));

  for (std::size_t i = 1; i < lines.size(); i++)
  {
    // NOTE(review): the two captured bounds are not used yet; every buffer
    // directive registers the same feature name "b.".
    if (util::doIfNameMatch(bufferPattern, lines[i], [this](auto &)
    {
      getOrCreateFeature(fmt::format("b."));
    }))
      continue;

    util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i]));
  }

  // Print the names of all registered features, one per line.
  for (const auto & it : features)
    fmt::print("{}\n", it.first);
}
// Computes the representation of 'config'.
// Currently a stub: 'config' is not read and an empty Representation is returned.
FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const
{
  return Representation{};
}
// Returns the feature registered under 'name'.
// If no such feature exists and 'name' matches a known feature pattern, a
// default-constructed Feature is registered under 'name' and returned.
// Otherwise the failure is reported through util::myThrow.
const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name)
{
  auto found = features.find(name);
  if (found != features.end())
    return found->second;

  // NOTE(review): the pattern is an empty regex; which names it accepts depends
  // on util::doIfNameMatch — confirm this is the intended placeholder.
  if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();}))
    return features[name];

  util::myThrow(fmt::format("Unknown feature '{}'", name));

  // Reached only if util::myThrow returns. The original returned found->second
  // here, but 'found' equals features.end() at this point and dereferencing it
  // is undefined behaviour. Hand back a valid empty Feature instead.
  static const Feature unknownFeature;
  return unknownFeature;
}
...@@ -51,11 +51,15 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -51,11 +51,15 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
util::myThrow("No Classifier specified"); util::myThrow("No Classifier specified");
--curLine; --curLine;
//std::vector<std::string_view> restOfFile;
//while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){}))
// restOfFile.emplace_back(lines[curLine++]);
//featureFunction.reset(new FeatureFunction(restOfFile)); if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
{
auto predictions = std::string(sm[1]);
auto splited = util::split(predictions, ' ');
for (auto & prediction : splited)
predicted.insert(std::string(prediction));
}))
util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end()); auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
...@@ -92,7 +96,7 @@ Classifier * ReadingMachine::getClassifier() ...@@ -92,7 +96,7 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get(); return classifier.get();
} }
void ReadingMachine::save() void ReadingMachine::save() const
{ {
for (auto & it : dicts) for (auto & it : dicts)
{ {
...@@ -110,3 +114,8 @@ void ReadingMachine::save() ...@@ -110,3 +114,8 @@ void ReadingMachine::save()
torch::save(classifier->getNN(), pathToClassifier); torch::save(classifier->getNN(), pathToClassifier);
} }
// Tells whether 'columnName' belongs to the set of predicted columns.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  return predicted.find(columnName) != predicted.end();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment