#include "ReadingMachine.hpp"

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <exception>
#include <memory>
#include <regex>
#include <string>
#include <string_view>
#include <vector>

#include "util.hpp"

namespace
{

/// Deleter that closes a C stream when its owning std::unique_ptr is destroyed.
struct FileCloser
{
  void operator()(std::FILE * file) const noexcept
  {
    std::fclose(file);
  }
};

/// Owning handle on a C stream : the stream is closed on every exit path,
/// including when an exception propagates.
using FilePtr = std::unique_ptr<std::FILE, FileCloser>;

/// @brief Read the lines of a machine file that are neither blank nor commented.
///
/// The trailing newline of each kept line is removed.
/// NOTE : the file is read in chunks of at most 1023 characters, so a longer
/// line is stored as several consecutive entries.
///
/// @param path File to read.
/// @return The kept lines, in file order.
/// @throws Whatever util::myThrow throws if the file can't be opened.
std::vector<std::string> readSignificantLines(const std::filesystem::path & path)
{
  FilePtr file(std::fopen(path.c_str(), "r"));
  if (!file)
    util::myThrow(fmt::format("couldn't open file '{}'", path.string()));

  // Built once : constructing a std::regex for every line read is needlessly costly.
  const std::regex blankOrComment("((\\s|\\t)*)(((#|//).*)|)(\n|)");

  std::vector<std::string> lines;
  char buffer[1024];

  while (std::fgets(buffer, sizeof buffer, file.get()))
  {
    // If line is blank or commented (# or //), ignore it
    if (util::doIfNameMatch(blankOrComment, buffer, [](auto){}))
      continue;

    const std::size_t length = std::strlen(buffer);
    if (length > 0 && buffer[length-1] == '\n')
      buffer[length-1] = '\0';

    lines.emplace_back(buffer);
  }

  return lines;
}

} // namespace

/// @brief Build a machine from its description file, with a single open dict.
///
/// A Dict in Open state is registered under defaultDictName before the file is parsed.
///
/// @param path Path to the machine description file.
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));

  readFromFile(path);
}

/// @brief Build a machine from its description file, then load saved dicts and a saved model.
///
/// Each dict file is registered in Closed state under the stem of its filename.
/// Only the first entry of models is loaded into the classifier's network.
///
/// @param path Path to the machine description file.
/// @param models Saved model files; must not be empty.
/// @param dicts Saved dict files.
/// @throws Whatever util::myThrow throws if models is empty.
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) : path(path)
{
  // The member is initialized above so that save() resolves its output files
  // against the machine file's directory, like with the single-argument constructor.
  readFromFile(path);

  for (const auto & dictPath : dicts)
    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Closed});

  // Indexing an empty vector is undefined behaviour, so fail explicitly instead.
  if (models.empty())
    util::myThrow("no model file specified");

  torch::load(classifier->getNN(), models[0]);
}

/// @brief Parse a machine description file and initialize name, classifier,
/// predicted columns and strategy from it.
///
/// Blank and commented lines are skipped. The remaining lines must be, in order :
///  - `Name : <name>`
///  - one or more `Classifier : <a> <b> <c>` lines; the three fields are given to
///    the Classifier constructor, the third one resolved against the directory of
///    the machine file. Each matching line replaces the previous classifier.
///  - `Predictions : <space separated names>`
///  - every following line is handed to the Strategy constructor.
///
/// @param path Path to the machine description file.
/// @throws Whatever util::myThrow throws if the file can't be opened or if a
/// required line is missing; parsing errors are prefixed with the file path.
void ReadingMachine::readFromFile(std::filesystem::path path)
{
  std::vector<std::string> lines = readSignificantLines(path);

  try
  {
    const std::regex nameRegex("Name : (.+)");
    const std::regex classifierRegex("Classifier : (.+) (.+) (.+)");
    const std::regex predictionsRegex("Predictions : (.+)");

    // Every access to lines is guarded by a bounds check : an empty or truncated
    // file must produce an error message, not an out-of-range read.
    std::size_t curLine = 0;

    if (curLine >= lines.size() || !util::doIfNameMatch(nameRegex, lines[curLine], [this](auto sm){name = sm[1];}))
      util::myThrow("No name specified");
    ++curLine;

    // Consume consecutive Classifier lines; curLine stops on the first line that doesn't match.
    while (curLine < lines.size() && util::doIfNameMatch(classifierRegex, lines[curLine], [this,path](auto sm){classifier.reset(new Classifier(sm.str(1), sm.str(2), path.parent_path() / sm.str(3)));}))
      ++curLine;

    if (!classifier.get())
      util::myThrow("No Classifier specified");

    if (curLine >= lines.size() || !util::doIfNameMatch(predictionsRegex, lines[curLine], [this](auto sm)
      {
        // Kept in a named local so that whatever util::split returns can safely refer to it.
        auto predictions = sm.str(1);
        auto splited = util::split(predictions, ' ');
        for (auto & prediction : splited)
          predicted.insert(std::string(prediction));
      }))
      util::myThrow("No predictions specified");
    ++curLine;

    auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());

    strategy.reset(new Strategy(restOfFile));
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));
  }
}

/// @return The transition set of the classifier.
TransitionSet & ReadingMachine::getTransitionSet()
{
  return classifier->getTransitionSet();
}

/// @return The strategy built from the machine file.
Strategy & ReadingMachine::getStrategy()
{
  return *strategy;
}

/// @brief Get the dict registered under a name, falling back on the default dict.
///
/// @param state Name the dict was registered under.
/// @return The dict registered under state if there is one, otherwise the one
/// registered under defaultDictName.
/// @throws Whatever util::myThrow throws if neither of them exists.
Dict & ReadingMachine::getDict(const std::string & state)
{
  auto found = dicts.find(state);
  if (found != dicts.end())
    return found->second;

  auto fallback = dicts.find(defaultDictName);
  if (fallback == dicts.end())
    util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));

  return fallback->second;
}

/// @return Non-owning pointer to the classifier.
Classifier * ReadingMachine::getClassifier()
{
  return classifier.get();
}

/// @brief Write every dict and the classifier's network next to the machine file.
///
/// Dict filenames are built from defaultDictFilename and the dict's name, the
/// model filename from defaultModelFilename and the classifier's name.
///
/// @throws Whatever util::myThrow throws if a dict file can't be created.
void ReadingMachine::save() const
{
  for (const auto & [dictName, dict] : dicts)
  {
    auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, dictName);

    // Owned by a FilePtr so the stream is closed even if Dict::save throws.
    FilePtr file(std::fopen(pathToDict.c_str(), "w"));
    if (!file)
      util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));

    dict.save(file.get(), Dict::Encoding::Ascii);
  }

  auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
  torch::save(classifier->getNN(), pathToClassifier);
}

/// @param columnName Name to look up.
/// @return true if columnName was listed on the `Predictions` line of the machine file.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  return predicted.count(columnName) > 0;
}

/// @return The names listed on the `Predictions` line of the machine file.
const std::set<std::string> & ReadingMachine::getPredicted() const
{
  return predicted;
}

/// @brief Switch the network and every dict between training and non-training mode.
///
/// @param isTrainMode true puts the network in train mode and every dict in Open
/// state; false does the opposite (Closed state).
void ReadingMachine::trainMode(bool isTrainMode)
{
  classifier->getNN()->train(isTrainMode);

  const auto dictState = isTrainMode ? Dict::State::Open : Dict::State::Closed;
  for (auto & entry : dicts)
    entry.second.setState(dictState);
}

/// @return Every dict of the machine, indexed by name.
std::map<std::string, Dict> & ReadingMachine::getDicts()
{
  return dicts;
}