Newer
Older
Franck Dary
committed
#include "Classifier.hpp"
Franck Dary
committed
/// Constructs a classifier and stores its name (later returned by getName()).
/// @param name      Identifier for this classifier; copied into `this->name`.
/// @param topology  NOTE(review): not referenced in the visible body. This view of the file
///                  appears to have dropped lines (possibly a call to initNeuralNetwork(topology))
///                  — confirm against the repository.
/// @param tsFiles   NOTE(review): likewise not referenced in the visible body — confirm.
// NOTE(review): the "Franck Dary" / "committed" lines inside this block look like blame-view
// residue rather than C++; they must be removed when restoring the real file.
Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles)
Franck Dary
committed
{
this->name = name;
Franck Dary
committed
}
/// Gives mutable access to the transition set this classifier points at.
/// No null check is performed: `transitionSet` is dereferenced as-is, so it must
/// already have been set before this is called.
/// @return Reference to the object referenced by `transitionSet`.
TransitionSet & Classifier::getTransitionSet()
{
  TransitionSet & current = *(this->transitionSet);
  return current;
}
return reinterpret_cast<NeuralNetwork&>(nn);
const std::string & Classifier::getName() const
{
return name;
}
/// Builds `this->nn` from a textual topology description.
///
/// `topology` is tried against a table of (regex, help text, initializer) entries via
/// util::doIfNameMatch. On a match (presumably the helper also runs the initializer with
/// the regex match — confirm in util), the network is moved to NeuralNetworkImpl::device
/// and the function returns. Otherwise, or when an exception is caught, util::myThrow is
/// called with a message that lists every entry's help text.
///
/// NOTE(review): this view of the function appears to be missing lines: the LSTM entry's
/// braces and lambda header, the declarations of `columns` / `focusedColumns`, the bodies
/// of the loops that should fill them, and the enclosing `for` / `try`. Restore from the
/// repository before changing any logic.
///
/// NOTE(review): `initializers` is a function-local `static`, so it is constructed only on
/// the first call. The visible lambda captures `this` and `topology` at that moment, so a
/// later call (another Classifier instance, or another topology string) would run lambdas
/// still bound to the first call's `this`/`topology`. This looks like a bug; consider making
/// the table non-static or passing the target object/topology as arguments.
void Classifier::initNeuralNetwork(const std::string & topology)
{
// Each entry: (pattern matched against `topology`, help text for error messages, initializer).
static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers
{
{
std::regex("Random"),
"Random : Output is chosen at random.",
[this,topology](auto sm)
{
// Output size is the number of transitions in the transition set.
this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
}
},
// Capture groups: 1 = unknownValueThreshold, 2..8 = the seven brace-delimited lists
// (bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns,
// maxNbElements), 9 = leftBorderRawInput, 10 = rightBorderRawInput.
std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
// NOTE(review): the help text below says "CNN" for the LSTM entry — confirm intent.
"LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
// Parse each comma-separated list into its vector (std::stoi throws on non-integers).
for (auto s : util::split(sm.str(2), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(3), ','))
stackContext.emplace_back(std::stoi(s));
// NOTE(review): loop body (presumably filling `columns`) is missing from this view.
for (auto s : util::split(sm.str(4), ','))
for (auto s : util::split(sm.str(5), ','))
focusedBuffer.push_back(std::stoi(s));
for (auto s : util::split(sm.str(6), ','))
focusedStack.push_back(std::stoi(s));
// NOTE(review): loop body (presumably filling `focusedColumns`) is missing from this view.
for (auto s : util::split(sm.str(7), ','))
for (auto s : util::split(sm.str(8), ','))
maxNbElements.push_back(std::stoi(s));
// Each focused column must have a matching maxNbElements entry.
if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
// Message used if no initializer succeeds; help texts are appended below.
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
{
// A network was built: move it to the configured device and stop searching.
this->nn->to(NeuralNetworkImpl::device);
return;
}
}
catch (std::exception & e)
{
// Keep the original exception text in front of the generic message, then stop trying.
errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage);
break;
// List every known topology's help text so the user can see the valid forms.
for (auto & initializer : initializers)
errorMessage += std::get<1>(initializer) + "\n";
util::myThrow(errorMessage);
}