// ReadingMachine.cpp
#include "ReadingMachine.hpp"

#include <memory>

#include "util.hpp"

// (blame) Franck Dary committed
/// @brief Build a machine from its description file, starting with a single open dict.
///
/// A dict is registered under defaultDictName in the Dict::State::Open state
/// before the machine file is parsed by readFromFile().
///
/// @param path Path of the machine description file; also stored in the
///             `path` member, which save() uses to locate its output directory.
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  dicts.emplace(defaultDictName, Dict::State::Open);

  readFromFile(path);
}

/// @brief Build a machine from its description file and load existing dicts and model weights.
///
/// @param path   Path of the machine description file; stored in the `path`
///               member so that save() resolves its output directory.
/// @param models Model files; only the first one is loaded into the classifier's network.
/// @param dicts  Dict files; each is loaded in the Dict::State::Closed state and
///               registered under its file stem.
/// @throws Whatever util::myThrow raises when `models` is empty.
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) : path(path)
{
  // Fail before doing any work: models[0] below would otherwise be undefined behaviour.
  if (models.empty())
    util::myThrow("no model file specified");

  readFromFile(path);

  for (const auto & dictPath : dicts)
    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Closed});

  torch::load(classifier->getNN(), models[0]);
}

void ReadingMachine::readFromFile(std::filesystem::path path)
{
Franck Dary's avatar
Franck Dary committed
  std::FILE * file = std::fopen(path.c_str(), "r");
Franck Dary's avatar
Franck Dary committed
  std::string fileContent;
  std::vector<std::string> lines;
  while (!std::feof(file))
  {
    if (buffer != std::fgets(buffer, 1024, file))
      break;
    // If line is blank or commented (# or //), ignore it
    if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
      continue;
    if (buffer[std::strlen(buffer)-1] == '\n')
      buffer[std::strlen(buffer)-1] = '\0';

    lines.emplace_back(buffer);
Franck Dary's avatar
Franck Dary committed

  try
  {
    unsigned int curLine = 0;
    if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
      util::myThrow("No name specified");

    while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) (.+)"), lines[curLine++], [this,path](auto sm){classifier.reset(new Classifier(sm.str(1), sm.str(2), path.parent_path() / sm.str(3)));}));
Franck Dary's avatar
Franck Dary committed
    if (!classifier.get())
      util::myThrow("No Classifier specified");

    if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
    {
      auto predictions = sm.str(1);
      auto splited = util::split(predictions, ' ');
      for (auto & prediction : splited)
        predicted.insert(std::string(prediction));
    }))
      util::myThrow("No predictions specified");

    auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());

Franck Dary's avatar
Franck Dary committed
    strategy.reset(new Strategy(restOfFile));

Franck Dary's avatar
Franck Dary committed
  } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}

/// @brief Give access to the transition set held by the classifier.
/// @return Reference returned by the classifier's getTransitionSet().
TransitionSet & ReadingMachine::getTransitionSet()
{
  return (*classifier).getTransitionSet();
}

/// @brief Give access to the strategy built by readFromFile().
/// @return Reference to the object pointed to by the `strategy` member.
Strategy & ReadingMachine::getStrategy()
{
  Strategy & current = *strategy;
  return current;
}

// (blame) Franck Dary committed
/// @brief Look up the dict registered for a state, falling back to the default dict.
///
/// @param state Key to look up in `dicts`.
/// @return The dict stored under `state` if there is one, otherwise the dict
///         stored under defaultDictName.
/// @throws Whatever util::myThrow raises when neither `state` nor
///         defaultDictName is registered.
Dict & ReadingMachine::getDict(const std::string & state)
{
  if (auto found = dicts.find(state); found != dicts.end())
    return found->second;

  // No dict for this state: fall back to the default one.
  auto fallback = dicts.find(defaultDictName);
  if (fallback == dicts.end())
    util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));

  return fallback->second;
}

/// @brief Give non-owning access to the classifier.
/// @return The pointer held by the `classifier` member (as returned by its get()).
Classifier * ReadingMachine::getClassifier()
{
  Classifier * rawClassifier = classifier.get();
  return rawClassifier;
}

void ReadingMachine::save() const
Franck Dary's avatar
Franck Dary committed
{
  for (auto & it : dicts)
  {
    auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
    std::FILE * file = std::fopen(pathToDict.c_str(), "w");
    if (!file)
      util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));

    it.second.save(file, Dict::Encoding::Ascii);

    std::fclose(file);
  }

  auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
  torch::save(classifier->getNN(), pathToClassifier);
}

/// @brief Tell whether a column name is part of the `predicted` set.
/// @param columnName Name to look up.
/// @return true if `columnName` is in `predicted`, false otherwise.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  const bool isMember = predicted.find(columnName) != predicted.end();
  return isMember;
}

const std::set<std::string> & ReadingMachine::getPredicted() const
{
  return predicted;
}

// (blame) Franck Dary committed
/// @brief Switch the classifier's network and every dict between training and non-training mode.
/// @param isTrainMode Forwarded to the network's train(); dicts are set to
///                    Dict::State::Open when true and Dict::State::Closed when false.
void ReadingMachine::trainMode(bool isTrainMode)
{
  const auto dictState = isTrainMode ? Dict::State::Open : Dict::State::Closed;

  classifier->getNN()->train(isTrainMode);

  for (auto & entry : dicts)
  {
    entry.second.setState(dictState);
  }
}

std::map<std::string, Dict> & ReadingMachine::getDicts()
{
  return dicts;
}