// Classifier.cpp
#include "Classifier.hpp"

#include <functional>
#include <regex>
#include <string>
#include <tuple>
#include <vector>

#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
#include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp"
#include "RLTNetwork.hpp"

/// Builds a classifier: stores its name, creates its TransitionSet from `tsFile`,
/// then builds the neural network described by `topology`.
///
/// @param name      Name of this classifier, later returned by getName().
/// @param topology  Textual network description handed to initNeuralNetwork(),
///                  e.g. "OneWord(0)".
/// @param tsFile    Passed to the TransitionSet constructor (presumably a file path).
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
  this->name = name;
  this->transitionSet.reset(new TransitionSet(tsFile));
  // Must come after the transition set is created: every network factory in
  // initNeuralNetwork() reads transitionSet->size().
  initNeuralNetwork(topology);
}
/// Gives mutable access to the transition set created by the constructor.
TransitionSet & Classifier::getTransitionSet()
{
  TransitionSet & owned = *(this->transitionSet);
  return owned;
}

/// Returns the network built by initNeuralNetwork().
NeuralNetwork & Classifier::getNN()
{
  // NOTE(review): `nn` is filled with `nn.reset(new XxxNetworkImpl(...))` and used as
  // `nn->to(...)`, so it looks like a smart pointer to NeuralNetworkImpl rather than a
  // NeuralNetwork object. This cast therefore assumes NeuralNetwork has exactly the same
  // object layout as that pointer (presumably a libtorch module holder wrapping a
  // shared_ptr) — TODO confirm against the member declaration in the header; declaring
  // the member as NeuralNetwork would remove the need for reinterpret_cast.
  return reinterpret_cast<NeuralNetwork&>(nn);
}

const std::string & Classifier::getName() const
{
  return name;
}

void Classifier::initNeuralNetwork(const std::string & topology)
{
  static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
  {
    {
Franck Dary's avatar
Franck Dary committed
      std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
      [this,topology](auto sm)
      {
        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1))));
      std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
      "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
Franck Dary's avatar
Franck Dary committed
      [this,topology](auto sm)
        std::vector<int> bufferContext, stackContext;
        for (auto s : util::split(sm.str(1), ','))
          bufferContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(2), ','))
          stackContext.emplace_back(std::stoi(s));
        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
Franck Dary's avatar
Franck Dary committed
    {
      std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
      "CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
Franck Dary's avatar
Franck Dary committed
      [this,topology](auto sm)
      {
        std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
        std::vector<std::string> focusedColumns, columns;
        for (auto s : util::split(sm.str(2), ','))
          bufferContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(3), ','))
          stackContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(4), ','))
        for (auto s : util::split(sm.str(5), ','))
          focusedBuffer.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(6), ','))
          focusedStack.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(7), ','))
          focusedColumns.emplace_back(s);
        for (auto s : util::split(sm.str(8), ','))
          maxNbElements.push_back(std::stoi(s));
        if (focusedColumns.size() != maxNbElements.size())
          util::myThrow("focusedColumns.size() != maxNbElements.size()");
        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
Franck Dary's avatar
Franck Dary committed
      }
    },
Franck Dary's avatar
Franck Dary committed
    {
      std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
      "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
Franck Dary's avatar
Franck Dary committed
      [this,topology](auto sm)
      {
        std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
Franck Dary's avatar
Franck Dary committed
        std::vector<std::string> focusedColumns, columns;
        for (auto s : util::split(sm.str(2), ','))
          bufferContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(3), ','))
          stackContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(4), ','))
Franck Dary's avatar
Franck Dary committed
          columns.emplace_back(s);
        for (auto s : util::split(sm.str(5), ','))
          focusedBuffer.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(6), ','))
          focusedStack.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(7), ','))
Franck Dary's avatar
Franck Dary committed
          focusedColumns.emplace_back(s);
        for (auto s : util::split(sm.str(8), ','))
          maxNbElements.push_back(std::stoi(s));
Franck Dary's avatar
Franck Dary committed
        if (focusedColumns.size() != maxNbElements.size())
          util::myThrow("focusedColumns.size() != maxNbElements.size()");
        this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
Franck Dary's avatar
Franck Dary committed
      }
    },
      std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
      "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
Franck Dary's avatar
Franck Dary committed
      [this,topology](auto sm)
      {
        this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);

  for (auto & initializer : initializers)
      if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
      {
        this->nn->to(NeuralNetworkImpl::device);
        return;
      }
    }
    catch (std::exception & e)
    {
      errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage);
      break;

  for (auto & initializer : initializers)
    errorMessage += std::get<1>(initializer) + "\n";

  util::myThrow(errorMessage);
}