#include "Classifier.hpp"
#include "util.hpp"
#include "OneWordNetwork.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RLTNetwork.hpp"
#include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp"

namespace
{

/// Splits `list` on ',' with util::split and converts every element with std::stoi.
/// Any exception raised by std::stoi propagates to the caller.
std::vector<int> parseIntList(const std::string & list)
{
  std::vector<int> values;

  for (const auto & element : util::split(list, ','))
    values.emplace_back(std::stoi(element));

  return values;
}

/// Splits `list` on ',' with util::split and stores every element as a std::string.
std::vector<std::string> parseStringList(const std::string & list)
{
  std::vector<std::string> values;

  for (const auto & element : util::split(list, ','))
    values.emplace_back(element);

  return values;
}

/// Arguments shared by the CNN(...) and LSTM(...) topology strings, in the
/// order they appear in the topology (capture groups 1 to 10).
struct ContextNetworkArgs
{
  int unknownValueThreshold;
  std::vector<int> bufferContext;
  std::vector<int> stackContext;
  std::vector<std::string> columns;
  std::vector<int> focusedBuffer;
  std::vector<int> focusedStack;
  std::vector<std::string> focusedColumns;
  std::vector<int> maxNbElements;
  int leftBorderRawInput;
  int rightBorderRawInput;
};

/// Extracts the ten capture groups of a CNN(...) / LSTM(...) match.
/// The lists are parsed first, then the focusedColumns/maxNbElements size check
/// is performed (reported through util::myThrow), then the scalar arguments are
/// converted -- the same sequence the two initializers originally used inline.
ContextNetworkArgs parseContextNetworkArgs(const std::smatch & sm)
{
  ContextNetworkArgs args;

  args.bufferContext = parseIntList(sm.str(2));
  args.stackContext = parseIntList(sm.str(3));
  args.columns = parseStringList(sm.str(4));
  args.focusedBuffer = parseIntList(sm.str(5));
  args.focusedStack = parseIntList(sm.str(6));
  args.focusedColumns = parseStringList(sm.str(7));
  args.maxNbElements = parseIntList(sm.str(8));

  if (args.focusedColumns.size() != args.maxNbElements.size())
    util::myThrow("focusedColumns.size() != maxNbElements.size()");

  args.unknownValueThreshold = std::stoi(sm.str(1));
  args.leftBorderRawInput = std::stoi(sm.str(9));
  args.rightBorderRawInput = std::stoi(sm.str(10));

  return args;
}

} // namespace

/// Builds a classifier: stores `name`, loads a TransitionSet from `tsFile`
/// and instantiates the neural network described by `topology`.
Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
{
  this->name = name;
  this->transitionSet.reset(new TransitionSet(tsFile));
  initNeuralNetwork(topology);
}

/// Returns the transition set owned by this classifier.
TransitionSet & Classifier::getTransitionSet()
{
  return *transitionSet;
}

/// Returns the network member viewed as a NeuralNetwork.
// NOTE(review): this reinterpret_cast is only valid if NeuralNetwork and the
// declared type of `nn` are layout-compatible -- confirm against Classifier.hpp.
NeuralNetwork & Classifier::getNN()
{
  return reinterpret_cast<NeuralNetwork&>(nn);
}

/// Returns the name given at construction.
const std::string & Classifier::getName() const
{
  return name;
}

/// Instantiates `nn` from a textual topology description.
///
/// Each known network is described by a (regex, help text, factory) triple.
/// The triples are tried in order through util::doIfNameMatch; on the first one
/// that reports a match the freshly created network is moved to
/// NeuralNetworkImpl::device and the function returns.
///
/// If no entry matches, or if a factory throws a std::exception (for instance
/// std::stoi on a malformed number), an error listing every available network
/// is reported through util::myThrow; in the exception case the message is
/// prefixed with "Caught(<what>)".
void Classifier::initNeuralNetwork(const std::string & topology)
{
  using Initializer = std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>;

  // Deliberately NOT static: the factories capture `this`, so a table shared
  // across calls would keep acting on the first Classifier that built it.
  const std::vector<Initializer> initializers
  {
    {
      std::regex("Random"),
      "Random : Output is chosen at random.",
      [this](const std::smatch &)
      {
        this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
      }
    },
    {
      std::regex("OneWord\\(([+\\-]?\\d+)\\)"),
      "OneWord(focusedIndex) : Only use the word embedding of the focused word.",
      [this](const std::smatch & sm)
      {
        this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1))));
      }
    },
    {
      std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
      "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
      [this](const std::smatch & sm)
      {
        std::vector<int> bufferContext = parseIntList(sm.str(1));
        std::vector<int> stackContext = parseIntList(sm.str(2));

        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
      }
    },
    {
      std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
      "CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
      [this](const std::smatch & sm)
      {
        ContextNetworkArgs args = parseContextNetworkArgs(sm);

        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), args.unknownValueThreshold, args.bufferContext, args.stackContext, args.columns, args.focusedBuffer, args.focusedStack, args.focusedColumns, args.maxNbElements, args.leftBorderRawInput, args.rightBorderRawInput));
      }
    },
    {
      std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
      "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : LSTM to capture context.",
      [this](const std::smatch & sm)
      {
        ContextNetworkArgs args = parseContextNetworkArgs(sm);

        this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), args.unknownValueThreshold, args.bufferContext, args.stackContext, args.columns, args.focusedBuffer, args.focusedStack, args.focusedColumns, args.maxNbElements, args.leftBorderRawInput, args.rightBorderRawInput));
      }
    },
    {
      std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
      "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
      [this](const std::smatch & sm)
      {
        this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
      }
    },
  };

  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);

  for (const auto & initializer : initializers)
  {
    try
    {
      if (util::doIfNameMatch(std::get<0>(initializer), topology, std::get<2>(initializer)))
      {
        this->nn->to(NeuralNetworkImpl::device);
        return;
      }
    }
    catch (const std::exception & e)
    {
      // A factory failed: stop trying other topologies and report
      // the cause together with the list of valid syntaxes.
      errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage);
      break;
    }
  }

  for (const auto & initializer : initializers)
    errorMessage += std::get<1>(initializer) + "\n";

  util::myThrow(errorMessage);
}