Skip to content
Snippets Groups Projects
Select Git revision
  • 4ce19f6f30c1ed7e161443e62f60a12bde62fa09
  • master default protected
  • fullUD
  • movementInAction
4 results

macaon_decode.cpp

Blame
  • ReadingMachine.cpp 5.37 KiB
    #include "ReadingMachine.hpp"
    #include <fstream>
    #include <memory>
    #include <regex>
    #include <string>
    #include <vector>
    #include "util.hpp"
    
    /// Builds a reading machine by parsing the description file at `path`.
    /// @param path  machine description file; its parent directory is later used to
    ///              locate the files it references.
    /// @param train flag stored on the machine and forwarded to every Classifier it creates.
    ReadingMachine::ReadingMachine(std::filesystem::path path, bool train)
      : path(path),
        train(train)
    {
      // The member has just been initialised from the parameter, so both hold the same value.
      readFromFile(this->path);
    }
    
    void ReadingMachine::readFromFile(std::filesystem::path path)
    {
      std::FILE * file = std::fopen(path.c_str(), "r");
      if (not file)
        util::myThrow(fmt::format("can't open file '{}'", path.string()));
    
      char buffer[1024];
      std::string fileContent;
      std::vector<std::string> lines;
      while (!std::feof(file))
      {
        if (buffer != std::fgets(buffer, 1024, file))
          break;
        // If line is blank or commented (# or //), ignore it
        if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
          continue;
    
        if (buffer[std::strlen(buffer)-1] == '\n')
          buffer[std::strlen(buffer)-1] = '\0';
    
        lines.emplace_back(buffer);
      }
      std::fclose(file);
    
      try
      {
        unsigned int curLine = 0;
        if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
          util::myThrow("No name specified");
    
        while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine], [this,path,&lines,&curLine](auto sm)
          {
            curLine++;
            classifierDefinitions.emplace_back();
            classifierNames.emplace_back(sm.str(1));
            if (lines[curLine] != "{")
              util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
    
            for (curLine++; curLine < lines.size(); curLine++)
            {
              if (lines[curLine] == "}")
              {
                curLine++;
                break;
              }
              classifierDefinitions.back().emplace_back(lines[curLine]);
            }
            classifiers.emplace_back(new Classifier(sm.str(1), path.parent_path(), classifierDefinitions.back(), train));
            for (auto state : classifiers.back()->getStates())
              state2classifier[state] = classifiers.size()-1;
          }));
        if (classifiers.empty())
          util::myThrow("No Classifier specified");
    
        util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
        {
          this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
          curLine++;
        });
    
        if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
        {
          auto predictions = sm.str(1);
          auto splited = util::split(predictions, ' ');
          for (auto & prediction : splited)
            predicted.insert(std::string(prediction));
        }))
          util::myThrow("No predictions specified");
    
        if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
          {
            strategyDefinition.clear();
            if (lines[curLine] != "{")
              util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
    
            for (curLine++; curLine < lines.size(); curLine++)
            {
              if (lines[curLine] == "}")
                break;
              strategyDefinition.emplace_back(lines[curLine]);
            }
          }))
          util::myThrow("No Strategy specified");
    
      } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
    }
    
    /// Returns the transition set of the classifier registered for `state`.
    /// An unknown state makes state2classifier.at() throw.
    TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
    {
      const auto owner = state2classifier.at(state);
      return classifiers[owner]->getTransitionSet();
    }
    
    /// Returns true when a split-word transition set has been loaded.
    bool ReadingMachine::hasSplitWordTransitionSet() const
    {
      return splitWordTransitionSet != nullptr;
    }
    
    /// Returns the split-word transition set loaded from the optional "Splitwords" line.
    /// Callers should check hasSplitWordTransitionSet() first. When no set was loaded, this
    /// reports an error through util::myThrow instead of dereferencing a null pointer.
    TransitionSet & ReadingMachine::getSplitWordTransitionSet()
    {
      if (not splitWordTransitionSet)
        util::myThrow("no Splitwords transition set was specified for this reading machine");
      return *splitWordTransitionSet;
    }
    
    const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
    {
      return strategyDefinition;
    }
    
    /// Returns a non-owning pointer to the classifier registered for `state`.
    /// An unknown state makes state2classifier.at() throw.
    Classifier * ReadingMachine::getClassifier(const std::string & state)
    {
      auto & owner = classifiers[state2classifier.at(state)];
      return owner.get();
    }
    
    void ReadingMachine::saveDicts() const
    {
      for (auto & classifier : classifiers)
        classifier->saveDicts();
    }
    
    void ReadingMachine::saveBest() const
    {
      saveDicts();
      for (auto & classifier : classifiers)
        classifier->saveBest();
    }
    
    void ReadingMachine::saveLast() const
    {
      saveDicts();
      for (auto & classifier : classifiers)
        classifier->saveLast();
    }
    
    /// Returns true when `columnName` belongs to the set of predicted columns.
    bool ReadingMachine::isPredicted(const std::string & columnName) const
    {
      return predicted.find(columnName) != predicted.end();
    }
    
    const std::set<std::string> & ReadingMachine::getPredicted() const
    {
      return predicted;
    }
    
    void ReadingMachine::trainMode(bool isTrainMode)
    {
      for (auto & classifier : classifiers)
        classifier->getNN()->train(isTrainMode);
    }
    
    /// Forwards `state` to getNN()->setDictsState() on every classifier.
    void ReadingMachine::setDictsState(Dict::State state)
    {
      for (unsigned int index = 0; index < classifiers.size(); index++)
        classifiers[index]->getNN()->setDictsState(state);
    }
    
    void ReadingMachine::setCountOcc(bool countOcc)
    {
      for (auto & classifier : classifiers)
        classifier->getNN()->setCountOcc(countOcc);
    }
    
    void ReadingMachine::removeRareDictElements(float rarityThreshold)
    {
      for (auto & classifier : classifiers)
        classifier->getNN()->removeRareDictElements(rarityThreshold);
    }
    
    /// Replaces every classifier with a fresh instance rebuilt from its recorded name and
    /// definition lines. The previous instance is destroyed by reset().
    void ReadingMachine::resetClassifiers()
    {
      for (unsigned int slot = 0; slot < classifiers.size(); ++slot)
      {
        auto * rebuilt = new Classifier(classifierNames[slot], path.parent_path(), classifierDefinitions[slot], train);
        classifiers[slot].reset(rebuilt);
      }
    }
    
    int ReadingMachine::getNbParameters() const
    {
      int sum = 0;
    
      for (auto & classifier : classifiers)
        sum += classifier->getNbParameters();
    
      return sum;
    }
    
    void ReadingMachine::resetOptimizers()
    {
      for (auto & classifier : classifiers)
        classifier->resetOptimizer();
    }