    // ReadingMachine.cpp — implementation of the ReadingMachine class.
    #include "ReadingMachine.hpp"

    #include <cstdio>
    #include <memory>

    #include "util.hpp"
    
    
    /// Builds a reading machine from the machine description file at `path`.
    ///
    /// The default dictionary (`defaultDictName`) is registered in the Open
    /// state before the file is parsed; errors raised by readFromFile()
    /// propagate to the caller.
    ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
    {
      // getDict() falls back to this dictionary for unknown names.
      dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));

      readFromFile(path);
    }
    
    /// Builds a reading machine from an existing machine file, then restores
    /// its dictionaries and model weights.
    ///
    /// @param path    machine description file, parsed by readFromFile(); also
    ///                stored in the `path` member so save() writes next to it
    /// @param models  model files; only models[0] is loaded into the classifier
    /// @param dicts   dictionary files; each is registered under its file stem
    ///                in the Closed state
    /// @throws via util::myThrow if `models` is empty
    ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) : path(path)
    {
      // Checked up front: models[0] below would otherwise be undefined behavior.
      if (models.empty())
        util::myThrow("No model specified");

      readFromFile(path);

      // The parameter shadows the member, hence the explicit this->dicts.
      for (const auto & dictPath : dicts)
        this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Closed});

      torch::load(classifier->getNN(), models[0]);
    }
    
    void ReadingMachine::readFromFile(std::filesystem::path path)
    {
    
    Franck Dary's avatar
    Franck Dary committed
      std::FILE * file = std::fopen(path.c_str(), "r");
    
    Franck Dary's avatar
    Franck Dary committed
      std::string fileContent;
    
      std::vector<std::string> lines;
    
      while (!std::feof(file))
      {
        if (buffer != std::fgets(buffer, 1024, file))
          break;
    
        // If line is blank or commented (# or //), ignore it
        if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
          continue;
    
        if (buffer[std::strlen(buffer)-1] == '\n')
          buffer[std::strlen(buffer)-1] = '\0';
    
        lines.emplace_back(buffer);
    
    Franck Dary's avatar
    Franck Dary committed
    
      try
      {
        unsigned int curLine = 0;
        if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
          util::myThrow("No name specified");
    
        while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) (.+)"), lines[curLine++], [this](auto sm){classifier.reset(new Classifier(sm[1], sm[2], sm[3]));}));
        if (!classifier.get())
          util::myThrow("No Classifier specified");
    
    
        if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
        {
          auto predictions = std::string(sm[1]);
          auto splited = util::split(predictions, ' ');
          for (auto & prediction : splited)
            predicted.insert(std::string(prediction));
        }))
          util::myThrow("No predictions specified");
    
    
        auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
    
    
    Franck Dary's avatar
    Franck Dary committed
        strategy.reset(new Strategy(restOfFile));
    
    
    Franck Dary's avatar
    Franck Dary committed
      } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
    }
    
    
    /// Gives access to the transition set owned by the machine's classifier.
    TransitionSet & ReadingMachine::getTransitionSet()
    {
      Classifier & owner = *classifier;
      return owner.getTransitionSet();
    }
    
    /// Gives access to the strategy built from the tail of the machine file.
    Strategy & ReadingMachine::getStrategy()
    {
      Strategy & current = *strategy;
      return current;
    }
    
    
    /// Returns the dictionary registered under `state`, falling back to the
    /// default dictionary (`defaultDictName`) when there is none.
    /// @throws via util::myThrow if neither dictionary exists.
    Dict & ReadingMachine::getDict(const std::string & state)
    {
      auto found = dicts.find(state);

      // Plain lookups instead of at()+catch: falling back to the default
      // dictionary is an expected path, not an error.
      if (found == dicts.end())
      {
        found = dicts.find(defaultDictName);
        if (found == dicts.end())
          util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));
      }

      return found->second;
    }
    
    /// Non-owning pointer to the classifier; null if none has been set.
    Classifier * ReadingMachine::getClassifier()
    {
      Classifier * current = classifier.get();
      return current;
    }
    
    
    void ReadingMachine::save() const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      for (auto & it : dicts)
      {
        auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
        std::FILE * file = std::fopen(pathToDict.c_str(), "w");
        if (!file)
          util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
    
        it.second.save(file, Dict::Encoding::Ascii);
    
        std::fclose(file);
      }
    
      auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
      torch::save(classifier->getNN(), pathToClassifier);
    }
    
    
    /// Tells whether `columnName` was listed on the machine file's
    /// "Predictions" line.
    bool ReadingMachine::isPredicted(const std::string & columnName) const
    {
      const auto hit = predicted.find(columnName);
      return hit != predicted.end();
    }
    
    
    /// Read-only access to the set of predicted column names.
    const std::set<std::string> & ReadingMachine::getPredicted() const
    {
      const std::set<std::string> & columns = predicted;
      return columns;
    }