Skip to content
Snippets Groups Projects
Classifier.cpp 3.17 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "util.hpp"
    #include "OneWordNetwork.hpp"
    #include "ConcatWordsNetwork.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    #include "CNNNetwork.hpp"
    
    
    /// Creates a classifier.
    ///
    /// @param name      Name stored for getName().
    /// @param topology  Network description, parsed by initNeuralNetwork().
    /// @param tsFile    File the TransitionSet is loaded from.
    ///
    /// The transition set must be loaded before the network is built, because
    /// every network constructor is sized with transitionSet->size().
    /// initNeuralNetwork() reports an unknown topology through util::myThrow.
    Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
    {
      this->name = name;
      this->transitionSet.reset(new TransitionSet(tsFile));

      initNeuralNetwork(topology);
    }
    /// Gives mutable access to the transition set this classifier loaded in its
    /// constructor, by dereferencing the owned `transitionSet` pointer.
    TransitionSet & Classifier::getTransitionSet()
    {
      TransitionSet & loaded = *this->transitionSet;
      return loaded;
    }
    
    
    /// Returns the network built by initNeuralNetwork(), viewed as a NeuralNetwork.
    ///
    /// NOTE(review): `nn` is filled with nn.reset(new ...Impl(...)), i.e. it is a
    /// smart-pointer-like member, and this call reinterprets that member's storage
    /// as a NeuralNetwork. That is only sound if NeuralNetwork's object
    /// representation is exactly that of `nn`'s type (e.g. a holder whose only data
    /// member is the same shared pointer) — TODO confirm against the header, and
    /// consider storing a NeuralNetwork member directly to drop the cast.
    NeuralNetwork & Classifier::getNN()
    {
      return reinterpret_cast<NeuralNetwork&>(nn);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    const std::string & Classifier::getName() const
    {
      return name;
    }
    
    
    void Classifier::initNeuralNetwork(const std::string & topology)
    {
      static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
      {
        {
    
    Franck Dary's avatar
    Franck Dary committed
          std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
    
          "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
          [this,topology](auto sm)
          {
            this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
          }
        },
        {
    
    Franck Dary's avatar
    Franck Dary committed
          std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
          "ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.",
          [this,topology](auto sm)
    
    Franck Dary's avatar
    Franck Dary committed
            this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
          "CNN(leftBorder,rightBorder,nbStack,{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.",
    
    Franck Dary's avatar
    Franck Dary committed
          [this,topology](auto sm)
          {
    
            std::vector<long> focusedBuffer, focusedStack;
            std::vector<std::string> focusedColumns, columns;
            for (auto s : util::split(std::string(sm[4]), ','))
              columns.emplace_back(s);
            for (auto s : util::split(std::string(sm[5]), ','))
              focusedBuffer.push_back(std::stoi(std::string(s)));
            for (auto s : util::split(std::string(sm[6]), ','))
              focusedStack.push_back(std::stoi(std::string(s)));
            for (auto s : util::split(std::string(sm[7]), ','))
              focusedColumns.emplace_back(s);
            this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns));
    
    Franck Dary's avatar
    Franck Dary committed
          }
        },
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
          "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
    
    Franck Dary's avatar
    Franck Dary committed
          [this,topology](auto sm)
          {
    
            this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
    
      };
    
      for (auto & initializer : initializers)
        if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
    
        {
          this->nn->to(NeuralNetworkImpl::device);
    
    
      std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
      for (auto & initializer : initializers)
        errorMessage += std::get<1>(initializer) + "\n";
    
      util::myThrow(errorMessage);
    }