Newer
Older
Franck Dary
committed
#include "ReadingMachine.hpp"
#include <cstdio>
#include <memory>
#include "util.hpp"
// Stores `path` and restores whatever state was previously saved in its
// parent directory: the last saved model, if one is found, is loaded into the
// classifier's network, and every saved dictionary file is registered in the
// Open state under the stem of its file name. When no dictionary ends up
// registered under `defaultDictName`, one is created in the Open state.
//
// NOTE(review): `classifier` is dereferenced below, but nothing visible in
// this constructor initialises it (the sibling constructor calls
// readFromFile(path) first). Confirm against the repository that this
// constructor is not missing the same call.
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  // Formatting the filename templates with an empty name presumably yields
  // the bare suffix handed to findFilesByExtension -- TODO confirm in util.
  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
  auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));

  // Only the first model file found is restored.
  if (!lastSavedModel.empty())
    torch::load(classifier->getNN(), lastSavedModel.front());

  // Bind by const reference under a distinct name: the loop variable used to
  // be called `path`, which shadowed the constructor parameter and copied
  // every entry.
  for (const auto & dictPath : savedDicts)
    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Open});

  // try_emplace leaves the map untouched when the default dict already exists.
  dicts.try_emplace(defaultDictName, Dict::State::Open);
}
// Builds a machine from the description file at `path`, registers each file
// of `dicts` as a Closed dictionary keyed by the stem of its file name, and
// loads the first entry of `models` into the classifier's network.
// Reports through util::myThrow when `models` is empty.
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) : path(path)
{
  // The member `path` is initialised above, as in the single-argument
  // constructor, so that save() resolves its output directory from it.
  readFromFile(path);

  // The parameter `dicts` hides the member of the same name, hence `this->`.
  for (const auto & dictPath : dicts)
    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Closed});

  // Indexing an empty vector is undefined behaviour; fail with a message instead.
  if (models.empty())
    util::myThrow(fmt::format("no model file given for machine '{}'", path.string()));

  torch::load(classifier->getNN(), models.front());
}
// Reads a machine description from the file at `path`.
// The file's lines are collected (blank and commented lines are skipped),
// then matched in order against:
//   "Name : ..."         mandatory, stored in `name`
//   "Classifier : ..."   followed by a "{" ... "}" block of definition lines
//   "Splitwords : ..."   its match result is ignored, so it appears optional
//   "Predictions : ..."  mandatory, space-separated names added to `predicted`
// Any std::exception raised during the matching phase is re-reported through
// util::myThrow with the file path prepended.
//
// NOTE(review): the text of this function looks damaged. `file`, `lines` and
// `predictions` are used but never declared here, two lambdas have no opening
// or closing brace, `restOfFile` is built but never used, and the stray
// "Franck Dary" / "committed" lines are not C++. Restore the missing
// statements from the repository before relying on these comments.
void ReadingMachine::readFromFile(std::filesystem::path path)
{
// `file` is presumably a std::FILE* opened on `path` just above -- TODO confirm.
if (not file)
util::myThrow(fmt::format("can't open file '{}'", path.string()));
Franck Dary
committed
// Buffer handed to std::fgets: each read returns at most 1023 characters, so
// a longer line would come back in several pieces.
char buffer[1024];
Franck Dary
committed
// Read until end of file, or until fgets reports a failure.
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 1024, file))
break;
// If line is blank or commented (# or //), ignore it
if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
continue;
Franck Dary
committed
// Drop the trailing newline, if any, before storing the line.
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
lines.emplace_back(buffer);
Franck Dary
committed
}
std::fclose(file);
try
{
// Index of the next entry of `lines` to interpret.
unsigned int curLine = 0;
// The first kept line must carry the machine name.
if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
util::myThrow("No name specified");
// NOTE(review): the lambda opened on the next line is never closed in the
// visible text, and neither is this `while` condition.
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
std::vector<std::string> classifierDefinition;
// The classifier header must be followed by a line holding only "{".
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
// Collect every line up to the closing "}" (or the end of the file).
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
classifierDefinition.emplace_back(lines[curLine]);
}
// Build the classifier from the captured name, the description path and the
// collected definition lines.
classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
if (!classifier.get())
util::myThrow("No Classifier specified");
// The split-word transition file is resolved relative to the directory that
// holds the description file.
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
{
this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
curLine++;
// NOTE(review): the lambda above is not closed in the visible text.
if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
{
// `predictions` presumably holds the captured group of the match -- its
// declaration is not visible here.
auto splited = util::split(predictions, ' ');
for (auto & prediction : splited)
predicted.insert(std::string(prediction));
}))
util::myThrow("No predictions specified");
// Views over the lines that follow the Predictions line; not used in the
// visible text.
auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
// Forwards to the classifier and hands back its transition set.
TransitionSet & ReadingMachine::getTransitionSet()
{
  return (*classifier).getTransitionSet();
}
// Tells whether a split-word transition set is currently held.
bool ReadingMachine::hasSplitWordTransitionSet() const
{
  const auto * heldSet = splitWordTransitionSet.get();
  return heldSet != nullptr;
}
// Returns the split-word transition set by reference.
// The pointer is dereferenced without a check here.
TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{
  TransitionSet & heldSet = *splitWordTransitionSet;
  return heldSet;
}
// Returns the strategy held by this machine, by reference.
Strategy & ReadingMachine::getStrategy()
{
  Strategy & current = *strategy;
  return current;
}
// Returns the dictionary registered under `state`. When there is none, falls
// back to the one registered under `defaultDictName`; a missing default
// dictionary is reported through util::myThrow.
Dict & ReadingMachine::getDict(const std::string & state)
{
  const auto requested = dicts.find(state);
  if (requested != dicts.end())
    return requested->second;

  const auto fallback = dicts.find(defaultDictName);
  if (fallback == dicts.end())
    util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));

  return fallback->second;
}
// Exposes the classifier as a raw pointer obtained from the holder's get().
Classifier * ReadingMachine::getClassifier()
{
  Classifier * const held = classifier.get();
  return held;
}
Franck Dary
committed
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
for (auto & it : dicts)
{
auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
std::FILE * file = std::fopen(pathToDict.c_str(), "w");
if (!file)
util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
it.second.save(file, Dict::Encoding::Ascii);
std::fclose(file);
}
Franck Dary
committed
auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier);
}
Franck Dary
committed
void ReadingMachine::saveBest() const
{
save(defaultModelFilename);
}
void ReadingMachine::saveLast() const
{
save(lastModelFilename);
}
// Tells whether `columnName` belongs to the set of predicted names.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  const auto position = predicted.find(columnName);
  return position != predicted.end();
}
const std::set<std::string> & ReadingMachine::getPredicted() const
{
return predicted;
}
void ReadingMachine::trainMode(bool isTrainMode)
{
classifier->getNN()->train(isTrainMode);
for (auto & it : dicts)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
}
Franck Dary
committed
std::map<std::string, Dict> & ReadingMachine::getDicts()
{
return dicts;
}