-
Franck Dary authored
ReadingMachine.cpp 5.37 KiB
#include "ReadingMachine.hpp"
#include <fstream>
#include <memory>
#include "util.hpp"
// Builds a reading machine by parsing the description file located at 'path'.
// 'train' is kept as a member and handed to every Classifier built from the file.
ReadingMachine::ReadingMachine(std::filesystem::path path, bool train) : path(path), train(train)
{
  // The member was just initialized from the parameter, so both hold the same path.
  readFromFile(this->path);
}
void ReadingMachine::readFromFile(std::filesystem::path path)
{
std::FILE * file = std::fopen(path.c_str(), "r");
if (not file)
util::myThrow(fmt::format("can't open file '{}'", path.string()));
char buffer[1024];
std::string fileContent;
std::vector<std::string> lines;
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 1024, file))
break;
// If line is blank or commented (# or //), ignore it
if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
continue;
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
lines.emplace_back(buffer);
}
std::fclose(file);
try
{
unsigned int curLine = 0;
if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
util::myThrow("No name specified");
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine], [this,path,&lines,&curLine](auto sm)
{
curLine++;
classifierDefinitions.emplace_back();
classifierNames.emplace_back(sm.str(1));
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
{
curLine++;
break;
}
classifierDefinitions.back().emplace_back(lines[curLine]);
}
classifiers.emplace_back(new Classifier(sm.str(1), path.parent_path(), classifierDefinitions.back(), train));
for (auto state : classifiers.back()->getStates())
state2classifier[state] = classifiers.size()-1;
}));
if (classifiers.empty())
util::myThrow("No Classifier specified");
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
{
this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
curLine++;
});
if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
{
auto predictions = sm.str(1);
auto splited = util::split(predictions, ' ');
for (auto & prediction : splited)
predicted.insert(std::string(prediction));
}))
util::myThrow("No predictions specified");
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{
strategyDefinition.clear();
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
strategyDefinition.emplace_back(lines[curLine]);
}
}))
util::myThrow("No Strategy specified");
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
// Returns the transition set owned by the classifier in charge of 'state'.
// The state lookup goes through getClassifier(), which uses .at() and so throws for an unknown state.
TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
{
  return getClassifier(state)->getTransitionSet();
}
// True when a "Splitwords" transition set was loaded from the machine file.
bool ReadingMachine::hasSplitWordTransitionSet() const
{
  return static_cast<bool>(splitWordTransitionSet);
}
// Returns the "Splitwords" transition set.
// The machine file may omit the Splitwords line, in which case the pointer is null.
// Raise a clear error in that case instead of dereferencing null (undefined behavior).
// Callers can test hasSplitWordTransitionSet() first.
TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{
  if (not splitWordTransitionSet)
    util::myThrow("No split word transition set specified");

  return *splitWordTransitionSet;
}
const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
{
return strategyDefinition;
}
// Returns a non-owning pointer to the classifier responsible for 'state'.
// The lookup uses .at(), so an unknown state throws rather than being inserted.
Classifier * ReadingMachine::getClassifier(const std::string & state)
{
  const auto index = state2classifier.at(state);
  return classifiers[index].get();
}
void ReadingMachine::saveDicts() const
{
for (auto & classifier : classifiers)
classifier->saveDicts();
}
void ReadingMachine::saveBest() const
{
saveDicts();
for (auto & classifier : classifiers)
classifier->saveBest();
}
void ReadingMachine::saveLast() const
{
saveDicts();
for (auto & classifier : classifiers)
classifier->saveLast();
}
// True when 'columnName' was listed on the machine file's "Predictions" line.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  return predicted.find(columnName) != predicted.end();
}
const std::set<std::string> & ReadingMachine::getPredicted() const
{
return predicted;
}
void ReadingMachine::trainMode(bool isTrainMode)
{
for (auto & classifier : classifiers)
classifier->getNN()->train(isTrainMode);
}
// Applies the same dictionary state to the network of every classifier.
void ReadingMachine::setDictsState(Dict::State state)
{
  for (const auto & current : classifiers)
  {
    auto network = current->getNN();
    network->setDictsState(state);
  }
}
void ReadingMachine::setCountOcc(bool countOcc)
{
for (auto & classifier : classifiers)
classifier->getNN()->setCountOcc(countOcc);
}
void ReadingMachine::removeRareDictElements(float rarityThreshold)
{
for (auto & classifier : classifiers)
classifier->getNN()->removeRareDictElements(rarityThreshold);
}
// Replaces every classifier with a freshly constructed one.
// Each new classifier is built from the name and definition lines stored when the machine file was read.
void ReadingMachine::resetClassifiers()
{
  for (std::size_t index = 0; index < classifiers.size(); ++index)
  {
    auto & slot = classifiers[index];
    slot.reset(new Classifier(classifierNames[index], path.parent_path(), classifierDefinitions[index], train));
  }
}
int ReadingMachine::getNbParameters() const
{
int sum = 0;
for (auto & classifier : classifiers)
sum += classifier->getNbParameters();
return sum;
}
void ReadingMachine::resetOptimizers()
{
for (auto & classifier : classifiers)
classifier->resetOptimizer();
}