Skip to content
Snippets Groups Projects
Classifier.cpp 11.98 KiB
#include "Classifier.hpp"
#include "util.hpp"
#include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp"

/// Builds a classifier from its textual definition.
/// @param name       Name returned by getName().
/// @param path       Path whose parent directory is used to resolve the transition set file names.
/// @param definition Definition lines. Line 0 must be "(Transitions :) {tsFile1.ts tsFile2.ts...}".
///                   The remaining lines are handed to initNeuralNetwork().
/// An empty definition or a malformed first line is reported through util::myThrow.
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
{
  this->name = name;

  // Test emptiness first so definition[0] is never read out of bounds.
  if (definition.empty() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
        {
          std::vector<std::string> tsFiles;

          // Transition set files are given relative to the directory that contains `path`.
          for (auto & tsFilename : util::split(sm.str(1), ' '))
            tsFiles.emplace_back((path.parent_path() / tsFilename).string());

          this->transitionSet.reset(new TransitionSet(tsFiles));
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition.empty() ? "" : definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));

  initNeuralNetwork(definition);
}

/// Total number of scalar elements summed over every parameter tensor of the network.
/// NOTE(review): torch::numel presumably yields a 64-bit count; accumulating into `int`
/// could overflow for very large models — confirm whether the return type should widen.
int Classifier::getNbParameters() const
{
  const auto tensors = nn->parameters();

  int total = 0;
  for (const auto & tensor : tensors)
    total += torch::numel(tensor);

  return total;
}

/// Mutable access to the transition set built from the definition's first line.
TransitionSet & Classifier::getTransitionSet()
{
  TransitionSet & transitions = *this->transitionSet;
  return transitions;
}

/// Returns the classifier's network viewed as a NeuralNetwork.
/// NOTE(review): this reinterprets the storage of the `nn` member itself (not the object it
/// points to) as a NeuralNetwork. That is only sound if NeuralNetwork is a thin holder whose
/// layout matches `nn` (e.g. a torch module holder wrapping a shared_ptr) — confirm against
/// Classifier.hpp; storing an actual NeuralNetwork member would remove the type pun.
NeuralNetwork & Classifier::getNN()
{
  NeuralNetwork & network = reinterpret_cast<NeuralNetwork&>(this->nn);
  return network;
}

const std::string & Classifier::getName() const
{
  return name;
}

/// Reads the network type from definition[1] and builds the corresponding network into `nn`,
/// then moves it to NeuralNetworkImpl::device.
/// Supported types: "Random" (no extra lines) and "LSTM" (hyper-parameters on the following lines,
/// parsed by initLSTM). A missing/malformed type line or an unknown type is reported through
/// util::myThrow.
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
  // Index 0 is the transitions line handled by the constructor; the network description starts at 1.
  std::size_t lineIndex = 1;
  std::string networkType;

  const bool hasTypeLine = lineIndex < definition.size();
  const bool typeParsed = hasTypeLine && util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[lineIndex], [&lineIndex,&networkType](auto sm)
        {
          networkType = sm.str(1);
          lineIndex++;
        });

  if (!typeParsed)
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", lineIndex < definition.size() ? definition[lineIndex] : "", "(Network type :) networkType"));

  if (networkType == "LSTM")
  {
    // initLSTM consumes the following lines and advances lineIndex past them.
    initLSTM(definition, lineIndex);
  }
  else if (networkType == "Random")
  {
    this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
  }
  else
  {
    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
  }

  this->nn->to(NeuralNetworkImpl::device);
}

void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
{
  int unknownValueThreshold;
  std::vector<int> bufferContext, stackContext;
  std::vector<std::string> columns;
  std::vector<int> focusedBuffer, focusedStack;
  std::vector<std::string> focusedColumns;
  std::vector<int> maxNbElements;
  std::vector<std::pair<int, float>> mlp;
  int rawInputLeftWindow, rawInputRightWindow;
  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
  bool bilstm;
  float lstmDropout;

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
        {
          unknownValueThreshold = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
        {
          for (auto & index : util::split(sm.str(1), ' '))
            bufferContext.emplace_back(std::stoi(index));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
        {
          for (auto & index : util::split(sm.str(1), ' '))
            stackContext.emplace_back(std::stoi(index));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
        {
          columns = util::split(sm.str(1), ' ');
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
        {
          for (auto & index : util::split(sm.str(1), ' '))
            focusedBuffer.emplace_back(std::stoi(index));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
        {
          for (auto & index : util::split(sm.str(1), ' '))
            focusedStack.emplace_back(std::stoi(index));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
        {
          focusedColumns = util::split(sm.str(1), ' ');
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
        {
          for (auto & index : util::split(sm.str(1), ' '))
            maxNbElements.emplace_back(std::stoi(index));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
        {
          rawInputLeftWindow = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
        {
          rawInputRightWindow = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
        {
          embeddingsSize = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
        {
          auto params = util::split(sm.str(1), ' ');
          if (params.size() % 2)
            util::myThrow("MLP must have even number of parameters");
          for (unsigned int i = 0; i < params.size()/2; i++)
            mlp.emplace_back(std::make_pair(std::stoi(params[i]), std::stof(params[i+1])));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
        {
          contextLSTMSize = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
        {
          focusedLSTMSize = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
        {
          rawInputLSTMSize = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
        {
          splitTransLSTMSize = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
        {
          nbLayers = std::stoi(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
        {
          bilstm = sm.str(1) == "true";
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
        {
          lstmDropout = std::stof(sm.str(1));
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));

  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
}