Skip to content
Snippets Groups Projects
Commit cac30a69 authored by Franck Dary
Browse files

changed the way classifiers are initialized

parent 4165cd23
No related branches found
No related tags found
No related merge requests found
...@@ -15,11 +15,12 @@ class Classifier ...@@ -15,11 +15,12 @@ class Classifier
private : private :
void initNeuralNetwork(const std::string & topology); void initNeuralNetwork(const std::vector<std::string> & definition);
void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex);
public : public :
Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFile); Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
NeuralNetwork & getNN(); NeuralNetwork & getNN();
const std::string & getName() const; const std::string & getName() const;
......
...@@ -3,11 +3,21 @@ ...@@ -3,11 +3,21 @@
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp" #include "RandomNetwork.hpp"
Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles) Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
{ {
this->name = name; this->name = name;
this->transitionSet.reset(new TransitionSet(tsFiles)); if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
initNeuralNetwork(topology); {
std::vector<std::string> tsFiles;
for (auto & tsFilename : util::split(sm.str(1), ' '))
tsFiles.emplace_back(path.parent_path() / tsFilename);
this->transitionSet.reset(new TransitionSet(tsFiles));
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
initNeuralNetwork(definition);
} }
TransitionSet & Classifier::getTransitionSet() TransitionSet & Classifier::getTransitionSet()
...@@ -25,66 +35,177 @@ const std::string & Classifier::getName() const ...@@ -25,66 +35,177 @@ const std::string & Classifier::getName() const
return name; return name;
} }
void Classifier::initNeuralNetwork(const std::string & topology) void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
std::size_t curIndex = 1;
std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
{
networkType = sm.str(1);
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
if (networkType == "Random")
this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
else if (networkType == "LSTM")
initLSTM(definition, curIndex);
else
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
}
void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
{ {
static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers int unknownValueThreshold;
{ std::vector<int> bufferContext, stackContext;
{ std::vector<std::string> columns;
std::regex("Random"), std::vector<int> focusedBuffer, focusedStack;
"Random : Output is chosen at random.", std::vector<std::string> focusedColumns;
[this,topology](auto sm) std::vector<int> maxNbElements;
{ int rawInputLeftWindow, rawInputRightWindow;
this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); int embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
} bool bilstm;
}, float lstmDropout;
{
std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
"LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", {
[this,topology](auto sm) unknownValueThreshold = std::stoi(sm.str(1));
{ curIndex++;
std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext; }))
std::vector<std::string> focusedColumns, columns; util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
for (auto s : util::split(sm.str(2), ','))
bufferContext.emplace_back(std::stoi(s)); if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
for (auto s : util::split(sm.str(3), ',')) {
stackContext.emplace_back(std::stoi(s)); for (auto & index : util::split(sm.str(1), ' '))
for (auto s : util::split(sm.str(4), ',')) bufferContext.emplace_back(std::stoi(index));
columns.emplace_back(s); curIndex++;
for (auto s : util::split(sm.str(5), ',')) }))
focusedBuffer.push_back(std::stoi(s)); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
for (auto s : util::split(sm.str(6), ','))
focusedStack.push_back(std::stoi(s)); if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
for (auto s : util::split(sm.str(7), ',')) {
focusedColumns.emplace_back(s); for (auto & index : util::split(sm.str(1), ' '))
for (auto s : util::split(sm.str(8), ',')) stackContext.emplace_back(std::stoi(index));
maxNbElements.push_back(std::stoi(s)); curIndex++;
if (focusedColumns.size() != maxNbElements.size()) }))
util::myThrow("focusedColumns.size() != maxNbElements.size()"); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
} if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
}, {
}; columns = util::split(sm.str(1), ' ');
curIndex++;
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
for (auto & initializer : initializers)
try if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
{ {
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) for (auto & index : util::split(sm.str(1), ' '))
{ focusedBuffer.emplace_back(std::stoi(index));
this->nn->to(NeuralNetworkImpl::device); curIndex++;
return; }))
} util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
}
catch (std::exception & e) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
{ {
errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage); for (auto & index : util::split(sm.str(1), ' '))
break; focusedStack.emplace_back(std::stoi(index));
} curIndex++;
}))
for (auto & initializer : initializers) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
errorMessage += std::get<1>(initializer) + "\n";
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
util::myThrow(errorMessage); {
focusedColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
// Each statement below consumes one line of the classifier definition, in a fixed order.
// The textual label in front of the value is optional in every regex (the alternation
// with an empty branch), and leading whitespace is tolerated. On a successful match the
// lambda stores the parsed value and advances curIndex; otherwise (or when the
// definition has run out of lines) util::myThrow is called with the offending line
// (empty string when past the end) and a description of the expected format.

// Max nb elements : brace-delimited list, split on spaces, each item converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
maxNbElements.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
// Raw input left window : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
{
rawInputLeftWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
// Raw input right window : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
{
rawInputRightWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
// Embeddings size : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
{
embeddingsSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
// Hidden size : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Hidden size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&hiddenSize](auto sm)
{
hiddenSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Hidden size :) value"));
// Context LSTM size : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
{
contextLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
// Focused LSTM size : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
{
focusedLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
// Rawinput LSTM size : single non-whitespace token converted with std::stoi.
// The label is optional and leading whitespace is tolerated. On a match the value is
// stored and curIndex advances to the next definition line; otherwise util::myThrow
// reports the offending line (empty string when the definition has no more lines).
// The expected-format hint spells the label exactly as the regex accepts it
// ("Rawinput LSTM size :"), so a user who copies the hint gets a line that parses.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
{
rawInputLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Rawinput LSTM size :) value"));
// Remaining scalar hyper-parameters, one definition line each. As above: the label is
// optional, a match stores the value and advances curIndex, and a mismatch or a missing
// line ends in util::myThrow with the offending line and the expected format.

// Split trans LSTM size : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
{
splitTransLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
// Num layers : single non-whitespace token converted with std::stoi.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
{
nbLayers = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
// BiLSTM : only the literal words "true" or "false" are accepted by the regex.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
{
bilstm = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
// LSTM dropout : single non-whitespace token converted to float with std::stof.
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
{
lstmDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
// Build the network from everything parsed above; the output count is the size of the transition set.
// NOTE(review): the previous initNeuralNetwork called this->nn->to(NeuralNetworkImpl::device)
// after construction; no such call is visible on this path — confirm it is done elsewhere.
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
} }
...@@ -57,18 +57,23 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -57,18 +57,23 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];})) if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
util::myThrow("No name specified"); util::myThrow("No name specified");
while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) \\{(.+)\\}"), lines[curLine++], [this,path](auto sm) while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
{ {
std::vector<std::string> tsFiles = util::split(sm.str(3), ' '); std::vector<std::string> classifierDefinition;
for (auto & tsFile : tsFiles) if (lines[curLine] != "{")
tsFile = path.parent_path() / tsFile; util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
classifier.reset(new Classifier(sm.str(1), sm.str(2), tsFiles));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
classifierDefinition.emplace_back(lines[curLine]);
}
classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
})); }));
if (!classifier.get()) if (!classifier.get())
util::myThrow("No Classifier specified"); util::myThrow("No Classifier specified");
--curLine;
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm) util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
{ {
this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
......
...@@ -27,7 +27,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl ...@@ -27,7 +27,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
public : public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
}; };
......
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
{ {
constexpr int embeddingsSize = 256; LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
constexpr int hiddenSize = 8192;
constexpr int contextLSTMSize = 1024;
constexpr int focusedLSTMSize = 256;
constexpr int rawInputLSTMSize = 32;
LSTMImpl::LSTMOptions lstmOptions{true,true,2,0.3,false};
auto lstmOptionsAll = lstmOptions; auto lstmOptionsAll = lstmOptions;
std::get<4>(lstmOptionsAll) = true; std::get<4>(lstmOptionsAll) = true;
...@@ -29,7 +23,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: ...@@ -29,7 +23,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
currentInputSize += rawInputLSTM->getInputSize(); currentInputSize += rawInputLSTM->getInputSize();
} }
splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, embeddingsSize, lstmOptionsAll)); splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
splitTransLSTM->setFirstInputIndex(currentInputSize); splitTransLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += splitTransLSTM->getOutputSize(); currentOutputSize += splitTransLSTM->getOutputSize();
currentInputSize += splitTransLSTM->getInputSize(); currentInputSize += splitTransLSTM->getInputSize();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment