Skip to content
Snippets Groups Projects
ReadingMachine.cpp 5.35 KiB
#include "ReadingMachine.hpp"
#include <fstream>
#include "util.hpp"

ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  readFromFile(path);
}

// Builds a reading machine from the description file at `path`, loads its
// dictionaries, and restores the classifier's network weights from the single
// file in `models`.
// Throws (via util::myThrow) if `models` does not contain exactly one path, or
// if loading the weights fails.
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
{
  // Validate before doing any work: an empty list used to fall through to
  // models[0], an out-of-bounds read.
  if (models.empty())
    util::myThrow("no model file specified");
  if (models.size() > 1)
    util::myThrow("having more than one model file is not supported");

  readFromFile(path);

  loadDicts();
  classifier->getNN()->registerEmbeddings();
  classifier->getNN()->to(NeuralNetworkImpl::device);

  const std::filesystem::path & modelFile = models.front();
  try
  {
    torch::load(classifier->getNN(), modelFile);
  } catch (const std::exception & e)
  {
    util::myThrow(fmt::format("error when loading '{}' : {}", modelFile.string(), e.what()));
  }
}

// Parses the reading-machine description file at `path` and fills `name`,
// `classifier`, the optional `splitWordTransitionSet`, `predicted` and `strategy`.
//
// Layout read below (blank lines and lines starting with # or // are skipped):
//   Name : <name>
//   Classifier : <classifier name>
//   {
//     <classifier definition lines>
//   }
//   Splitwords : <file relative to this file's directory>      (optional)
//   Predictions : <space separated column names>
//   Strategy
//   {
//     <strategy definition lines>
//   }
//
// Every parsing error is reported through util::myThrow, prefixed with the file path.
void ReadingMachine::readFromFile(std::filesystem::path path)
{
  // std::ifstream closes the file on every exit path (including exceptions),
  // and std::getline reads lines of any length; the former fixed 1024-byte
  // fgets buffer silently split longer lines into several entries.
  std::ifstream file(path);
  if (not file)
    util::myThrow(fmt::format("can't open file '{}'", path.string()));

  // Matches blank lines and commented lines (# or //). Built once, not per line.
  const std::regex ignoredLine("((\\s|\\t)*)(((#|//).*)|)(\n|)");

  std::vector<std::string> lines;
  for (std::string line; std::getline(file, line);)
  {
    // Tolerate CRLF files: otherwise "{\r" never equals "{", and '.' in the
    // comment regex does not match '\r', so such comment lines were kept.
    if (!line.empty() and line.back() == '\r')
      line.pop_back();

    if (util::doIfNameMatch(ignoredLine, line, [](auto){}))
      continue;

    lines.push_back(line);
  }

  try
  {
    std::size_t curLine = 0;

    // Bounds-checked access: a truncated file used to index past the end of
    // `lines` (undefined behavior); it now produces a clear error instead.
    auto lineAt = [&lines](std::size_t index) -> const std::string &
    {
      if (index >= lines.size())
        util::myThrow("unexpected end of file");
      return lines[index];
    };

    // Collects the lines of a "{ ... }" block whose opening brace is at curLine.
    // On return, curLine points at the closing "}" (or equals lines.size() if
    // the closing brace is missing).
    auto readBracedBlock = [&lines, &curLine, &lineAt]()
    {
      if (lineAt(curLine) != "{")
        util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lineAt(curLine)));

      std::vector<std::string> definition;
      for (curLine++; curLine < lines.size() and lines[curLine] != "}"; curLine++)
        definition.emplace_back(lines[curLine]);
      return definition;
    };

    if (!util::doIfNameMatch(std::regex("Name : (.+)"), lineAt(curLine++), [this](auto sm){name = sm[1];}))
      util::myThrow("No name specified");

    // Exactly one classifier block is read; afterwards curLine is moved past its "}".
    util::doIfNameMatch(std::regex("Classifier : (.+)"), lineAt(curLine++), [this,&path,&curLine,&readBracedBlock](auto sm)
      {
        auto classifierDefinition = readBracedBlock();
        classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
        curLine++; // skip the closing "}"
      });
    if (!classifier.get())
      util::myThrow("No Classifier specified");

    // Optional line: only consumed when it matches.
    util::doIfNameMatch(std::regex("Splitwords : (.+)"), lineAt(curLine), [this,&path,&curLine](auto sm)
    {
      this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
      curLine++;
    });

    if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lineAt(curLine++), [this](auto sm)
    {
      auto predictions = sm.str(1);
      auto splited = util::split(predictions, ' ');
      for (auto & prediction : splited)
        predicted.insert(std::string(prediction));
    }))
      util::myThrow("No predictions specified");

    if (!util::doIfNameMatch(std::regex("Strategy"), lineAt(curLine++), [this,&readBracedBlock](auto)
      {
        auto strategyDefinition = readBracedBlock();
        strategy.reset(new Strategy(strategyDefinition));
      }))
      util::myThrow("No Strategy specified");

  } catch(const std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}

// Transition set owned by the machine's classifier.
TransitionSet & ReadingMachine::getTransitionSet()
{
  Classifier & owner = *classifier;
  return owner.getTransitionSet();
}

// True when a split-word transition set was loaded (from the optional
// "Splitwords" line of the machine file).
bool ReadingMachine::hasSplitWordTransitionSet() const
{
  return splitWordTransitionSet != nullptr;
}

// Split-word transition set declared by the optional "Splitwords" line.
// Throws (via util::myThrow) if none was declared; previously this dereferenced
// a null pointer. Callers can test hasSplitWordTransitionSet() first.
TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{
  if (!splitWordTransitionSet)
    util::myThrow("this reading machine has no split word transition set");
  return *splitWordTransitionSet;
}

// Strategy built from the "Strategy" block of the machine file.
Strategy & ReadingMachine::getStrategy()
{
  Strategy & current = *strategy;
  return current;
}

// Non-owning pointer to the classifier; the machine keeps ownership.
Classifier * ReadingMachine::getClassifier()
{
  Classifier * const raw = classifier.get();
  return raw;
}

// Writes the classifier network's dictionaries into the machine file's directory.
void ReadingMachine::saveDicts() const
{
  const auto directory = path.parent_path();
  classifier->getNN()->saveDicts(directory);
}

// Reads the classifier network's dictionaries from the machine file's directory.
void ReadingMachine::loadDicts()
{
  const auto directory = path.parent_path();
  classifier->getNN()->loadDicts(directory);
}

// Saves the dictionaries, then the classifier network to
// <machine file directory>/<modelNameTemplate formatted with the classifier name>.
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
  saveDicts();

  const std::filesystem::path directory = path.parent_path();
  const std::filesystem::path target = directory / fmt::format(modelNameTemplate, classifier->getName());
  torch::save(classifier->getNN(), target);
}

void ReadingMachine::saveBest() const
{
  save(defaultModelFilename);
}

void ReadingMachine::saveLast() const
{
  save(lastModelFilename);
}

// True if `columnName` was listed on the "Predictions" line.
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
  return predicted.find(columnName) != predicted.end();
}

const std::set<std::string> & ReadingMachine::getPredicted() const
{
  return predicted;
}

// Switches the classifier network between training (true) and evaluation (false) mode.
void ReadingMachine::trainMode(bool isTrainMode)
{
  auto && network = classifier->getNN();
  network->train(isTrainMode);
}

// Forwards the dictionary state to the classifier network's dictionaries.
void ReadingMachine::setDictsState(Dict::State state)
{
  auto && network = classifier->getNN();
  network->setDictsState(state);
}

// Restores the classifier network from a last-saved model file in the machine
// file's directory, if one is found; does nothing otherwise.
void ReadingMachine::loadLastSaved()
{
  // lastModelFilename formatted with an empty name — presumably the suffix
  // shared by every last-saved model file.
  auto suffix = fmt::format(ReadingMachine::lastModelFilename, "");
  auto candidates = util::findFilesByExtension(path.parent_path(), suffix);
  if (candidates.empty())
    return;

  // NOTE(review): if several files match, whichever is returned first is loaded.
  torch::load(classifier->getNN(), candidates[0]);
}

// Forwards the occurrence-counting flag to the classifier network.
void ReadingMachine::setCountOcc(bool countOcc)
{
  auto && network = classifier->getNN();
  network->setCountOcc(countOcc);
}

// Forwards to the classifier network, which prunes dictionary entries using `rarityThreshold`.
void ReadingMachine::removeRareDictElements(float rarityThreshold)
{
  auto && network = classifier->getNN();
  network->removeRareDictElements(rarityThreshold);
}