    // Classifier.cpp
    #include "util.hpp"
    #include "LSTMNetwork.hpp"
    #include "RandomNetwork.hpp"
    
    /// Builds a classifier from its textual definition.
    ///
    /// @param name        Stored in the `name` member (the value getName() returns).
    /// @param path        Transition-set filenames found in the definition are
    ///                    resolved against path.parent_path().
    /// @param definition  Lines of the definition. Line 0 must match
    ///                    "(Transitions :) {tsFile1.ts tsFile2.ts...}"; the whole
    ///                    vector is then handed to initNeuralNetwork().
    ///
    /// A missing or malformed first line is reported through util::myThrow.
    ///
    /// NOTE(review): the braces and the storing of `name` were absent from the
    /// text this block was restored from — confirm against the class declaration.
    Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition) : name(name)
    {
      // The emptiness check comes first so definition[0] is never read on an empty vector.
      if (definition.empty() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
            {
              // Capture group 1 is the space-separated list of transition-set files.
              std::vector<std::string> tsFiles;

              for (auto & tsFilename : util::split(sm.str(1), ' '))
                tsFiles.emplace_back(path.parent_path() / tsFilename);

              this->transitionSet.reset(new TransitionSet(tsFiles));
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition.empty() ? "" : definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));

      initNeuralNetwork(definition);
    }
    
    /// Sums torch::numel over every tensor returned by nn->parameters().
    ///
    /// @return Total number of elements across the network's parameter tensors.
    int Classifier::getNbParameters() const
    {
      int total = 0;
      auto tensors = nn->parameters();

      for (std::size_t i = 0; i < tensors.size(); ++i)
        total += torch::numel(tensors[i]);

      return total;
    }
    
    
    /// Returns a reference to the object held by the `transitionSet` member.
    TransitionSet & Classifier::getTransitionSet()
    {
      return *(this->transitionSet);
    }
    
    
    /// Exposes the `nn` member as a NeuralNetwork reference.
    ///
    /// NOTE(review): `nn` is reinterpreted in place, not converted. This presumably
    /// relies on NeuralNetwork having the same layout as the type of `nn` — confirm
    /// against the class declarations before touching either type.
    NeuralNetwork & Classifier::getNN()
    {
      return reinterpret_cast<NeuralNetwork&>(nn);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    const std::string & Classifier::getName() const
    {
      return name;
    }
    
    
    void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
    {
      std::size_t curIndex = 1;
    
      std::string networkType;
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
            {
              networkType = sm.str(1);
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
    
      if (networkType == "Random")
        this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
      else if (networkType == "LSTM")
        initLSTM(definition, curIndex);
      else
        util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
    
    
      this->nn->to(NeuralNetworkImpl::device);
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
            {
              std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
              if (sm.str(1) == "Adam")
              {
                auto splited = util::split(sm.str(2), ' ');
                if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
                  util::myThrow(expected);
    
                optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
              }
              else
                util::myThrow(expected);
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
    
    }
    
    void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
    
      int unknownValueThreshold;
      std::vector<int> bufferContext, stackContext;
    
      std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
    
      std::vector<int> focusedBuffer, focusedStack;
    
      std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
    
      std::vector<int> maxNbElements;
    
      std::vector<int> treeEmbeddingNbElems;
    
    Franck Dary's avatar
    Franck Dary committed
      std::vector<std::pair<int, float>> mlp;
    
      int rawInputLeftWindow, rawInputRightWindow;
    
      int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
    
      float lstmDropout, embeddingsDropout, totalInputDropout;
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
            {
              unknownValueThreshold = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                bufferContext.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                stackContext.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
            {
              columns = util::split(sm.str(1), ' ');
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                focusedBuffer.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                focusedStack.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
            {
              focusedColumns = util::split(sm.str(1), ' ');
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                maxNbElements.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
            {
              rawInputLeftWindow = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
            {
              rawInputRightWindow = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
            {
              embeddingsSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
    
    
    Franck Dary's avatar
    Franck Dary committed
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
    
    Franck Dary's avatar
    Franck Dary committed
              auto params = util::split(sm.str(1), ' ');
              if (params.size() % 2)
                util::myThrow("MLP must have even number of parameters");
              for (unsigned int i = 0; i < params.size()/2; i++)
    
    Franck Dary's avatar
    Franck Dary committed
                mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1])));
    
    Franck Dary's avatar
    Franck Dary committed
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
            {
              contextLSTMSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
            {
              focusedLSTMSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
            {
              rawInputLSTMSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
            {
              splitTransLSTMSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
            {
              nbLayers = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
            {
              bilstm = sm.str(1) == "true";
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
            {
              lstmDropout = std::stof(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
            {
              totalInputDropout = std::stof(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
            {
              embeddingsDropout = std::stof(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
            {
              drop2d = sm.str(1) == "true";
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
            {
              treeEmbeddingColumns = util::split(sm.str(1), ' ');
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
    
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                treeEmbeddingBuffer.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                treeEmbeddingStack.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
            {
              for (auto & index : util::split(sm.str(1), ' '))
                treeEmbeddingNbElems.emplace_back(std::stoi(index));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
            {
              treeEmbeddingSize = std::stoi(sm.str(1));
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
    
    
      this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
    
    /// Calls torch::load to read the file at `path` into the optimizer.
    ///
    /// @param path  File passed to torch::load.
    void Classifier::loadOptimizer(std::filesystem::path path)
    {
      auto & target = *optimizer;
      torch::load(target, path);
    }
    
    /// Calls torch::save to write the optimizer to the file at `path`.
    ///
    /// @param path  File passed to torch::save.
    void Classifier::saveOptimizer(std::filesystem::path path)
    {
      auto & source = *optimizer;
      torch::save(source, path);
    }
    
    /// Returns a reference to the Adam optimizer held by the `optimizer` member.
    torch::optim::Adam & Classifier::getOptimizer()
    {
      return *(this->optimizer);
    }