// Classifier.cpp
// NOTE: this text was captured from a GitLab blame page; page navigation text
// has been removed and is not part of the source.
  • #include "util.hpp"
    
// (GitLab blame annotation: committed by Franck Dary)
    #include "LSTMNetwork.hpp"
    
// (GitLab blame annotation: committed by Franck Dary)
    #include "RandomNetwork.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles)
    
    Franck Dary's avatar
    Franck Dary committed
      this->transitionSet.reset(new TransitionSet(tsFiles));
    
      initNeuralNetwork(topology);
    
    TransitionSet & Classifier::getTransitionSet()
    {
      return *transitionSet;
    }
    
    
    NeuralNetwork & Classifier::getNN()
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
      return reinterpret_cast<NeuralNetwork&>(nn);
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    const std::string & Classifier::getName() const
    {
      return name;
    }
    
    
    void Classifier::initNeuralNetwork(const std::string & topology)
    {
      static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
      {
    
    Franck Dary's avatar
    Franck Dary committed
        {
          std::regex("Random"),
          "Random : Output is chosen at random.",
          [this,topology](auto sm)
          {
            this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
          }
        },
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
          "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
    
    Franck Dary's avatar
    Franck Dary committed
          [this,topology](auto sm)
          {
    
            std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
    
    Franck Dary's avatar
    Franck Dary committed
            std::vector<std::string> focusedColumns, columns;
    
            for (auto s : util::split(sm.str(2), ','))
              bufferContext.emplace_back(std::stoi(s));
            for (auto s : util::split(sm.str(3), ','))
              stackContext.emplace_back(std::stoi(s));
            for (auto s : util::split(sm.str(4), ','))
    
    Franck Dary's avatar
    Franck Dary committed
              columns.emplace_back(s);
    
            for (auto s : util::split(sm.str(5), ','))
              focusedBuffer.push_back(std::stoi(s));
            for (auto s : util::split(sm.str(6), ','))
              focusedStack.push_back(std::stoi(s));
            for (auto s : util::split(sm.str(7), ','))
    
    Franck Dary's avatar
    Franck Dary committed
              focusedColumns.emplace_back(s);
    
            for (auto s : util::split(sm.str(8), ','))
              maxNbElements.push_back(std::stoi(s));
    
    Franck Dary's avatar
    Franck Dary committed
            if (focusedColumns.size() != maxNbElements.size())
              util::myThrow("focusedColumns.size() != maxNbElements.size()");
    
            this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
    
    Franck Dary's avatar
    Franck Dary committed
          }
        },
    
      std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
    
    
      for (auto & initializer : initializers)
    
          if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
          {
            this->nn->to(NeuralNetworkImpl::device);
            return;
          }
        }
        catch (std::exception & e)
        {
          errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage);
          break;
    
    
      for (auto & initializer : initializers)
        errorMessage += std::get<1>(initializer) + "\n";
    
      util::myThrow(errorMessage);
    }