// Classifier.cpp
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"

/// @brief Creates a classifier: loads its transition set, then builds its neural network.
///
/// @param name     Identifier stored in the classifier and returned by getName().
/// @param topology Network description string handed to initNeuralNetwork()
///                 (e.g. "OneWord(0)" or "ConcatWords").
/// @param tsFile   Forwarded to the TransitionSet constructor
///                 (NOTE(review): presumably a path to a transition-set file — confirm).
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
  this->name = name;
  this->transitionSet.reset(new TransitionSet(tsFile));
  // Must run after transitionSet is set: the network size is taken from it.
  initNeuralNetwork(topology);
}
/// @brief Gives mutable access to the TransitionSet held by this classifier.
/// @return A reference to the object pointed to by the transitionSet member.
TransitionSet & Classifier::getTransitionSet()
{
  auto & heldSet = *this->transitionSet;
  return heldSet;
}

/// @brief Gives access to this classifier's neural network as a NeuralNetwork.
///
/// NOTE(review): this reinterprets the storage of the `nn` member as a
/// NeuralNetwork. That is only valid if nn's declared type is layout-compatible
/// with NeuralNetwork (e.g. both wrap the same pointer to the module
/// implementation). Confirm against Classifier.hpp; a static_cast, or returning
/// nn directly, would be safer if the declared types allow it.
NeuralNetwork & Classifier::getNN()
{
  return reinterpret_cast<NeuralNetwork&>(nn);
}

Franck Dary's avatar
Franck Dary committed
const std::string & Classifier::getName() const
{
  return name;
}

/// @brief Builds the neural network described by @p topology and stores it in `nn`.
///
/// Each table entry pairs three things:
/// - a regex over the topology string;
/// - a human-readable description, listed in the error message when nothing matches;
/// - a callback that constructs the matching network from the regex captures.
///
/// The first entry for which util::doIfNameMatch reports a match wins.
///
/// @param topology Network description, e.g. "OneWord(0)" or "ConcatWords".
/// @note If no entry matches, util::myThrow is called with a message listing the
///       available networks (NOTE(review): the exception type is defined in util.hpp).
void Classifier::initNeuralNetwork(const std::string & topology)
{
  // Deliberately NOT static: the callbacks capture `this`. A function-local
  // static table would be built once, bound to the first Classifier that ever
  // called this method. Every later instance would then reset that (possibly
  // destroyed) object's network instead of its own.
  std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
  {
    {
      std::regex("OneWord\\((\\d+)\\)"),
      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
      [this](const std::smatch & sm)
      {
        // sm[1] holds the digits captured between the parentheses (the focused index).
        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
      }
    },
    {
      std::regex("ConcatWords"),
      "ConcatWords : Concatenate embeddings of words in context.",
      [this](const std::smatch &)
      {
        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size()));
      }
    },
  };

  for (auto & initializer : initializers)
    if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
      return;

  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
  for (auto & initializer : initializers)
  {
    errorMessage += std::get<1>(initializer);
    errorMessage += '\n';
  }

  util::myThrow(errorMessage);
}