#include "ReadingMachine.hpp"
#include "util.hpp"

namespace
{

// Closes the wrapped C stream when it goes out of scope, so the handle is
// released even if reading the file throws.
struct FileCloser
{
  std::FILE * handle;

  ~FileCloser()
  {
    if (handle)
      std::fclose(handle);
  }
};

// Reads the file at 'path' and returns its lines, in order, skipping the ones
// that are blank or commented (# or //). A trailing '\n' is stripped from each
// returned line.
// Lines are read through a 1024 byte buffer, so a longer line is returned in
// several pieces.
// Reports an error through util::myThrow when the file cannot be opened.
std::vector<std::string> readSignificantLines(const std::filesystem::path & path)
{
  std::FILE * file = std::fopen(path.c_str(), "r");
  if (not file)
    util::myThrow(fmt::format("can't open file '{}'", path.string()));
  FileCloser closer{file};

  // Built once: constructing a std::regex for every line read is wasteful.
  const std::regex blankOrComment("((\\s|\\t)*)(((#|//).*)|)(\n|)");

  char buffer[1024];
  std::vector<std::string> lines;

  while (std::fgets(buffer, sizeof buffer, file))
  {
    // If line is blank or commented (# or //), ignore it
    if (util::doIfNameMatch(blankOrComment, buffer, [](auto){}))
      continue;

    const std::size_t length = std::strlen(buffer);
    if (length > 0 && buffer[length-1] == '\n')
      buffer[length-1] = '\0';

    lines.emplace_back(buffer);
  }

  return lines;
}

} // namespace

// Builds a machine from its description file, without loading dicts or model weights.
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  readFromFile(path);
}

// Builds a machine from its description file, then loads the dicts, registers
// the embeddings, moves the network to NeuralNetworkImpl::device and loads the
// weights stored in the single file listed in 'models'.
// Errors are reported through util::myThrow when 'models' does not contain
// exactly one file, or when torch::load fails.
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
{
  readFromFile(path);

  loadDicts();
  classifier->getNN()->registerEmbeddings();
  classifier->getNN()->to(NeuralNetworkImpl::device);

  // models[0] is read below, it must exist.
  if (models.empty())
    util::myThrow("no model file was given");
  if (models.size() > 1)
    util::myThrow("having more than one model file is not supported");

  try
  {
    torch::load(classifier->getNN(), models[0]);
  } catch (std::exception & e)
  {
    util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
  }
}

// Parses the machine description stored in 'path'.
// After blank and commented lines are removed, the expected content is, in order :
//   Name : <name>
//   Classifier : <name>   followed by a definition enclosed in '{' and '}' lines
//   Splitwords : <file>   (optional, relative to the directory of 'path')
//   Predictions : <space separated names>
//   Strategy              followed by a definition enclosed in '{' and '}' lines
// Any std::exception raised while parsing is reported through util::myThrow,
// prefixed with the name of the file.
void ReadingMachine::readFromFile(std::filesystem::path path)
{
  const std::vector<std::string> lines = readSignificantLines(path);

  // Bounds-aware access : past the end of the file an empty line is returned.
  // It matches none of the patterns below, so a truncated file ends up in the
  // corresponding "No ... specified" error instead of reading out of bounds.
  const std::string missingLine;
  auto lineAt = [&lines, &missingLine](std::size_t index) -> const std::string &
  {
    return index < lines.size() ? lines[index] : missingLine;
  };

  // Reads a definition enclosed in '{' and '}' lines. 'cursor' must designate
  // the '{' line; on return it designates the '}' line, or is equal to
  // lines.size() when no '}' was found.
  auto readBlock = [&lines, &lineAt](unsigned int & cursor)
  {
    if (lineAt(cursor) != "{")
      util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lineAt(cursor)));

    std::vector<std::string> definition;
    for (cursor++; cursor < lines.size(); cursor++)
    {
      if (lines[cursor] == "}")
        break;
      definition.emplace_back(lines[cursor]);
    }

    return definition;
  };

  try
  {
    unsigned int curLine = 0;

    if (!util::doIfNameMatch(std::regex("Name : (.+)"), lineAt(curLine++), [this](auto sm){name = sm[1];}))
      util::myThrow("No name specified");

    // The line tested by each iteration is always consumed, including the one
    // that fails to match and ends the loop. After a classifier block, that
    // line is the closing '}'.
    while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lineAt(curLine++),
      [this,&path,&readBlock,&curLine](auto sm)
      {
        std::vector<std::string> classifierDefinition = readBlock(curLine);
        classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
      }));

    if (!classifier.get())
      util::myThrow("No Classifier specified");

    // Optional line, only consumed when it is present.
    util::doIfNameMatch(std::regex("Splitwords : (.+)"), lineAt(curLine),
      [this,&path,&curLine](auto sm)
      {
        this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
        curLine++;
      });

    if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lineAt(curLine++),
      [this](auto sm)
      {
        auto predictions = sm.str(1);
        auto splited = util::split(predictions, ' ');
        for (auto & prediction : splited)
          predicted.insert(std::string(prediction));
      }))
      util::myThrow("No predictions specified");

    if (!util::doIfNameMatch(std::regex("Strategy"), lineAt(curLine++),
      [this,&readBlock,&curLine](auto sm)
      {
        std::vector<std::string> strategyDefinition = readBlock(curLine);
        strategy.reset(new Strategy(strategyDefinition));
      }))
      util::myThrow("No Strategy specified");

  } catch (const std::exception & e)
  {
    util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));
  }
}

// Returns the transition set of the classifier.
TransitionSet & ReadingMachine::getTransitionSet()
{
  return classifier->getTransitionSet();
}

// Tells whether a split word transition set has been loaded (Splitwords line).
bool ReadingMachine::hasSplitWordTransitionSet() const
{
  return splitWordTransitionSet.get() != nullptr;
}

// Returns the split word transition set.
// Must only be called when hasSplitWordTransitionSet() is true, the pointer is
// dereferenced without any check.
TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{
  return *splitWordTransitionSet;
}

// Returns the strategy, the pointer is dereferenced without any check.
Strategy & ReadingMachine::getStrategy()
{
  return *strategy;
}

// Returns a non-owning pointer to the classifier (nullptr when there is none).
Classifier * ReadingMachine::getClassifier()
{
  return classifier.get();
}

// Asks the network to save its dicts in the directory of the machine file.
void ReadingMachine::saveDicts() const
{
  classifier->getNN()->saveDicts(path.parent_path());
}

// Asks the network to load its dicts from the directory of the machine file.
void ReadingMachine::loadDicts()
{
  classifier->getNN()->loadDicts(path.parent_path());
}

// Saves the dicts, then the network. The network file is located in the
// directory of the machine file, its name is obtained by formatting
// 'modelNameTemplate' with the name of the classifier.
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
  saveDicts();

  auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
  torch::save(classifier->getNN(), pathToClassifier);
}

// Saves the machine using the defaultModelFilename template.
void ReadingMachine::saveBest() const
{
  save(defaultModelFilename);
}

// Saves the machine using the lastModelFilename template.
void ReadingMachine::saveLast() const
{
  save(lastModelFilename);
}

// Tells whether 'columnName' was listed in the Predictions line.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  return predicted.count(columnName);
}

// Returns the names listed in the Predictions line.
const std::set<std::string> & ReadingMachine::getPredicted() const
{
  return predicted;
}

// Forwards 'isTrainMode' to the train() method of the network.
void ReadingMachine::trainMode(bool isTrainMode)
{
  classifier->getNN()->train(isTrainMode);
}

// Forwards 'state' to the network, which applies it to its dicts.
void ReadingMachine::setDictsState(Dict::State state)
{
  classifier->getNN()->setDictsState(state);
}

// Loads the network from the first file returned by util::findFilesByExtension
// for the directory of the machine file and the lastModelFilename template
// formatted with an empty name. Does nothing when no file is returned.
void ReadingMachine::loadLastSaved()
{
  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
  if (!lastSavedModel.empty())
    torch::load(classifier->getNN(), lastSavedModel[0]);
}

// Forwards 'countOcc' to the network.
void ReadingMachine::setCountOcc(bool countOcc)
{
  classifier->getNN()->setCountOcc(countOcc);
}

// Forwards 'rarityThreshold' to the network.
void ReadingMachine::removeRareDictElements(float rarityThreshold)
{
  classifier->getNN()->removeRareDictElements(rarityThreshold);
}