Newer
Older
Franck Dary
committed
#include "Classifier.hpp"
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RLTNetwork.hpp"
Franck Dary
committed
/// Creates a classifier.
///
/// @param name      Name reported by getName().
/// @param topology  Textual description of the neural network, forwarded to initNeuralNetwork().
/// @param tsFile    File the TransitionSet is constructed from.
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
  this->name = name;
  this->transitionSet.reset(new TransitionSet(tsFile));
  // Must come after the transition set is created: every network builder in
  // initNeuralNetwork() reads transitionSet->size().
  initNeuralNetwork(topology);
}
/// Gives mutable access to the transition set owned by this classifier.
TransitionSet & Classifier::getTransitionSet()
{
  TransitionSet & owned = *this->transitionSet;
  return owned;
}
return reinterpret_cast<NeuralNetwork&>(nn);
const std::string & Classifier::getName() const
{
return name;
}
void Classifier::initNeuralNetwork(const std::string & topology)
{
static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
{
{
"OneWord(focusedIndex) : Only use the word embedding of the focused word.",
[this,topology](auto sm)
{
this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
}
},
{
std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.",
[this,topology](auto sm)
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
"CNN(leftBorder,rightBorder,nbStack,{focusedBuffer},{focusedStack},{focusedColumns}) : CNN to capture context.",
std::vector<long> focusedBuffer, focusedStack;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(std::string(sm[4]), ','))
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[5]), ','))
focusedBuffer.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[6]), ','))
focusedStack.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[7]), ','))
focusedColumns.emplace_back(s);
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns));
std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
};
for (auto & initializer : initializers)
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
{
this->nn->to(NeuralNetworkImpl::device);
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
for (auto & initializer : initializers)
errorMessage += std::get<1>(initializer) + "\n";
util::myThrow(errorMessage);
}