Skip to content
Snippets Groups Projects
Classifier.cpp 5.11 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "util.hpp"
    #include "OneWordNetwork.hpp"
    #include "ConcatWordsNetwork.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    #include "CNNNetwork.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    #include "LSTMNetwork.hpp"
    
    
    /// Builds a classifier from its name, a topology description and a transition-set file.
    /// @param name     identifier later returned by getName().
    /// @param topology network description handed to initNeuralNetwork() (e.g. "OneWord(0)").
    /// @param tsFile   path forwarded to the TransitionSet constructor.
    /// @throws whatever the TransitionSet constructor or initNeuralNetwork() throws.
    Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
    {
      this->name = name;
      this->transitionSet.reset(new TransitionSet(tsFile));

      // The transition set must exist before the network is built: the network
      // constructors are sized with transitionSet->size().
      initNeuralNetwork(topology);
    }
    
    /// Gives mutable access to the transition set held by this classifier.
    /// @return the object pointed to by `transitionSet`. It is dereferenced
    ///         without a null check, so it must already have been set.
    TransitionSet & Classifier::getTransitionSet()
    {
      TransitionSet & owned = *this->transitionSet;
      return owned;
    }
    
    
    /// Returns the network built by initNeuralNetwork(), viewed as a NeuralNetwork.
    ///
    /// NOTE(review): `nn` is reinterpreted in place rather than converted. This
    /// presumes NeuralNetwork has exactly the same object representation as the
    /// declared type of `nn` (e.g. a holder wrapping the same smart pointer) —
    /// TODO confirm against the header. A reinterpret_cast between unrelated
    /// class types is not sanctioned by the standard.
    NeuralNetwork & Classifier::getNN()
    {
      return reinterpret_cast<NeuralNetwork&>(nn);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    const std::string & Classifier::getName() const
    {
      return name;
    }
    
    
    void Classifier::initNeuralNetwork(const std::string & topology)
    {
      static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
      {
        {
    
    Franck Dary's avatar
    Franck Dary committed
          std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
    
          "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
          [this,topology](auto sm)
          {
            this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
          }
        },
        {
    
    Franck Dary's avatar
    Franck Dary committed
          std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
          "ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.",
          [this,topology](auto sm)
    
    Franck Dary's avatar
    Franck Dary committed
            this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          std::regex("CNN\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
          "CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
    
    Franck Dary's avatar
    Franck Dary committed
          [this,topology](auto sm)
          {
    
            std::vector<int> focusedBuffer, focusedStack, maxNbElements;
    
            std::vector<std::string> focusedColumns, columns;
            for (auto s : util::split(std::string(sm[5]), ','))
    
            for (auto s : util::split(std::string(sm[6]), ','))
    
              focusedBuffer.push_back(std::stoi(std::string(s)));
    
            for (auto s : util::split(std::string(sm[7]), ','))
    
              focusedStack.push_back(std::stoi(std::string(s)));
    
            for (auto s : util::split(std::string(sm[8]), ','))
    
              focusedColumns.emplace_back(s);
            for (auto s : util::split(std::string(sm[9]), ','))
    
              maxNbElements.push_back(std::stoi(std::string(s)));
            if (focusedColumns.size() != maxNbElements.size())
              util::myThrow("focusedColumns.size() != maxNbElements.size()");
    
            this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
    
    Franck Dary's avatar
    Franck Dary committed
          }
        },
    
    Franck Dary's avatar
    Franck Dary committed
        {
          std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
          "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
          [this,topology](auto sm)
          {
            std::vector<int> focusedBuffer, focusedStack, maxNbElements;
            std::vector<std::string> focusedColumns, columns;
            for (auto s : util::split(std::string(sm[5]), ','))
              columns.emplace_back(s);
            for (auto s : util::split(std::string(sm[6]), ','))
              focusedBuffer.push_back(std::stoi(std::string(s)));
            for (auto s : util::split(std::string(sm[7]), ','))
              focusedStack.push_back(std::stoi(std::string(s)));
            for (auto s : util::split(std::string(sm[8]), ','))
              focusedColumns.emplace_back(s);
            for (auto s : util::split(std::string(sm[9]), ','))
              maxNbElements.push_back(std::stoi(std::string(s)));
            if (focusedColumns.size() != maxNbElements.size())
              util::myThrow("focusedColumns.size() != maxNbElements.size()");
            this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
          }
        },
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
          "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
    
    Franck Dary's avatar
    Franck Dary committed
          [this,topology](auto sm)
          {
    
            this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
    
      };
    
      for (auto & initializer : initializers)
        if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
    
        {
          this->nn->to(NeuralNetworkImpl::device);
    
    
      std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
      for (auto & initializer : initializers)
        errorMessage += std::get<1>(initializer) + "\n";
    
      util::myThrow(errorMessage);
    }