Skip to content
Snippets Groups Projects
Commit 765b687d authored by Franck Dary's avatar Franck Dary
Browse files

ReadingMachine now has list of predicted columns

parent bc2ede62
No related branches found
No related tags found
No related merge requests found
......@@ -36,10 +36,10 @@ class Dict
void insert(const std::string & element);
int getIndexOrInsert(const std::string & element);
void setState(State state);
State getState();
void save(std::FILE * destination, Encoding encoding);
State getState() const;
void save(std::FILE * destination, Encoding encoding) const;
bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
};
#endif
......@@ -79,12 +79,12 @@ void Dict::setState(State state)
this->state = state;
}
Dict::State Dict::getState()
Dict::State Dict::getState() const
{
return state;
}
void Dict::save(std::FILE * destination, Encoding encoding)
void Dict::save(std::FILE * destination, Encoding encoding) const
{
fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
fprintf(destination, "Nb entries : %lu\n", elementsToIndexes.size());
......@@ -114,7 +114,7 @@ bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encod
}
}
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding)
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{
if (encoding == Encoding::Ascii)
{
......
#ifndef FEATUREFUNCTION__H
#define FEATUREFUNCTION__H
#include <map>
#include <string>
#include "Config.hpp"
class FeatureFunction
{
using Representation = std::vector<std::size_t>;
using Feature = std::function<Config::String(const Config &)>;
private :
std::map<std::string, Feature> features;
std::map<Config::String, std::size_t> dictionary;
private :
const Feature & getOrCreateFeature(const std::string & name);
public :
FeatureFunction(const std::vector<std::string_view> & lines);
Representation getRepresentation(const Config & config) const;
};
#endif
......@@ -5,7 +5,6 @@
#include <memory>
#include "Classifier.hpp"
#include "Strategy.hpp"
#include "FeatureFunction.hpp"
#include "Dict.hpp"
class ReadingMachine
......@@ -23,8 +22,8 @@ class ReadingMachine
std::filesystem::path path;
std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy;
std::unique_ptr<FeatureFunction> featureFunction;
std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
private :
......@@ -38,7 +37,8 @@ class ReadingMachine
Strategy & getStrategy();
Dict & getDict(const std::string & state);
Classifier * getClassifier();
void save();
void save() const;
bool isPredicted(const std::string & columnName) const;
};
#endif
#include "FeatureFunction.hpp"
FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines)
{
if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){}))
util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0]));
for (unsigned int i = 1; i < lines.size(); i++)
{
if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm)
{
getOrCreateFeature(fmt::format("b."));
}))
continue;
util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i]));
}
for (auto & it : features)
fmt::print("{}\n", it.first);
}
FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const
{
Representation representation;
return representation;
}
const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name)
{
auto found = features.find(name);
if (found != features.end())
return found->second;
if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();}))
return features[name];
util::myThrow(fmt::format("Unknown feature '{}'", name));
return found->second;
}
......@@ -51,11 +51,15 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
util::myThrow("No Classifier specified");
--curLine;
//std::vector<std::string_view> restOfFile;
//while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){}))
// restOfFile.emplace_back(lines[curLine++]);
//featureFunction.reset(new FeatureFunction(restOfFile));
if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
{
auto predictions = std::string(sm[1]);
auto splited = util::split(predictions, ' ');
for (auto & prediction : splited)
predicted.insert(std::string(prediction));
}))
util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
......@@ -92,7 +96,7 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get();
}
void ReadingMachine::save()
void ReadingMachine::save() const
{
for (auto & it : dicts)
{
......@@ -110,3 +114,8 @@ void ReadingMachine::save()
torch::save(classifier->getNN(), pathToClassifier);
}
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
return predicted.count(columnName);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment