Commit 8adb3462 authored by Franck Dary

code refactoring

parent c79e0a23
#include "Classifier.hpp"
#include "util.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RLTNetwork.hpp"
#include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp"
@@ -40,45 +37,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
        this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
      }
    },
    {
      std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
      "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
      [this,topology](auto sm)
      {
        std::vector<int> bufferContext, stackContext;
        for (auto s : util::split(sm.str(1), ','))
          bufferContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(2), ','))
          stackContext.emplace_back(std::stoi(s));
        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
      }
    },
    {
      std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
      "CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
      [this,topology](auto sm)
      {
        std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
        std::vector<std::string> focusedColumns, columns;
        for (auto s : util::split(sm.str(2), ','))
          bufferContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(3), ','))
          stackContext.emplace_back(std::stoi(s));
        for (auto s : util::split(sm.str(4), ','))
          columns.emplace_back(s);
        for (auto s : util::split(sm.str(5), ','))
          focusedBuffer.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(6), ','))
          focusedStack.push_back(std::stoi(s));
        for (auto s : util::split(sm.str(7), ','))
          focusedColumns.emplace_back(s);
        for (auto s : util::split(sm.str(8), ','))
          maxNbElements.push_back(std::stoi(s));
        if (focusedColumns.size() != maxNbElements.size())
          util::myThrow("focusedColumns.size() != maxNbElements.size()");
        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
      }
    },
    {
      std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
@@ -105,14 +63,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
        this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
      }
    },
    {
      std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
      "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
      [this,topology](auto sm)
      {
        this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
      }
    },
  };
  std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
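For reference, topology strings that the rules above would accept look like the following; the concrete values are illustrative only, not taken from any configuration in this commit:

  ConcatWords({-3,-2,-1,0,1},{0,1})
  CNN(5,{-3,-2,-1,0,1},{0,1},{FORM,UPOS},{0,1},{0},{LEMMA},{10},5,5)
  LSTM(5,{-3,-2,-1,0,1},{0,1},{FORM,UPOS},{0,1},{0},{LEMMA},{10},5,5)
  RLT(3,3,2)

Each captured brace group is split on ',' and converted with std::stoi where an integer is expected, and the number of entries in {focusedColumns} must match the number in {maxNbElements}.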
......
@@ -17,7 +17,7 @@ class CNNImpl : public torch::nn::Module
CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
torch::Tensor forward(torch::Tensor input);
int getOutputSize();
std::size_t getOutputSize();
};
TORCH_MODULE(CNN);
......
#ifndef CNNNETWORK__H
#define CNNNETWORK__H
#include "NeuralNetwork.hpp"
#include "CNN.hpp"
class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
int rightWindowRawInput;
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout cnnDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr};
CNN rawInputCNN{nullptr};
std::vector<CNN> cnns;
public :
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
#ifndef CONCATWORDSNETWORK__H
#define CONCATWORDSNETWORK__H
#include "NeuralNetwork.hpp"
class ConcatWordsNetworkImpl : public NeuralNetworkImpl
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::Dropout dropout{nullptr};
public :
ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext);
torch::Tensor forward(torch::Tensor input) override;
};
#endif
#ifndef CONTEXTLSTM__H
#define CONTEXTLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class ContextLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
std::vector<std::string> columns;
std::vector<int> bufferContext;
std::vector<int> stackContext;
int unknownValueThreshold;
std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"};
public :
ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(ContextLSTM);
#endif
#ifndef DEPTHLAYERTREEEMBEDDING__H
#define DEPTHLAYERTREEEMBEDDING__H
#include <torch/torch.h>
#include "fmt/core.h"
#include "LSTM.hpp"
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module
{
private :
std::vector<LSTM> depthLstm;
int maxDepth;
int maxElemPerDepth;
public :
DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth);
torch::Tensor forward(torch::Tensor input);
int getOutputSize();
};
TORCH_MODULE(DepthLayerTreeEmbedding);
#endif
#ifndef FOCUSEDCOLUMNLSTM__H
#define FOCUSEDCOLUMNLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
std::vector<int> focusedBuffer, focusedStack;
std::string column;
int maxNbElements;
public :
FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(FocusedColumnLSTM);
#endif
@@ -6,6 +6,10 @@
class LSTMImpl : public torch::nn::Module
{
public :
using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
private :
torch::nn::LSTM lstm{nullptr};
@@ -13,7 +17,7 @@ class LSTMImpl : public torch::nn::Module
public :
LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options);
LSTMImpl(int inputSize, int outputSize, LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength);
};
......
@@ -2,29 +2,28 @@
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
#include "LSTM.hpp"
#include "ContextLSTM.hpp"
#include "RawInputLSTM.hpp"
#include "SplitTransLSTM.hpp"
#include "FocusedColumnLSTM.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
private :
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
int rightWindowRawInput;
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout lstmDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
LSTM contextLSTM{nullptr};
LSTM rawInputLSTM{nullptr};
LSTM splitTransLSTM{nullptr};
std::vector<LSTM> lstms;
ContextLSTM contextLSTM{nullptr};
RawInputLSTM rawInputLSTM{nullptr};
SplitTransLSTM splitTransLSTM{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms;
bool hasRawInputLSTM{false};
public :
......
#ifndef MLP__H
#define MLP__H
#include <torch/torch.h>
class MLPImpl : public torch::nn::Module
{
public :
MLPImpl(const std::string & topology);
};
TORCH_MODULE(MLP);
#endif
@@ -15,31 +15,10 @@ class NeuralNetworkImpl : public torch::nn::Module
static constexpr int maxNbEmbeddings = 150000;
std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{};
std::vector<int> bufferFocused{};
std::vector<int> stackFocused{};
protected :
void setBufferContext(const std::vector<int> & bufferContext);
void setStackContext(const std::vector<int> & stackContext);
void setBufferFocused(const std::vector<int> & bufferFocused);
void setStackFocused(const std::vector<int> & stackFocused);
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
std::vector<long> extractContextIndexes(const Config & config) const;
std::vector<long> extractFocusedIndexes(const Config & config) const;
int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const;
void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const;
void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const;
void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
};
TORCH_MODULE(NeuralNetwork);
......
#ifndef RLTNETWORK__H
#define RLTNETWORK__H
#include "NeuralNetwork.hpp"
class RLTNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr long maxNbChilds{8};
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
int leftBorder, rightBorder;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::LSTM vectorBiLSTM{nullptr};
torch::nn::LSTM treeLSTM{nullptr};
torch::Tensor S;
torch::Tensor nullTree;
public :
RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
@@ -13,6 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(long outputSize);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
};
#endif
#ifndef RAWINPUTLSTM__H
#define RAWINPUTLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class RawInputLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
int leftWindow, rightWindow;
public :
RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(RawInputLSTM);
#endif
#ifndef SPLITTRANSLSTM__H
#define SPLITTRANSLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
int maxNbTrans;
public :
SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(SplitTransLSTM);
#endif
#ifndef SUBMODULE__H
#define SUBMODULE__H
#include "Dict.hpp"
#include "Config.hpp"
class Submodule
{
protected :
std::size_t firstInputIndex{0};
public :
void setFirstInputIndex(std::size_t firstInputIndex);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
};
#endif
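One plausible wiring of these submodules, inferred only from the interface above (the loop and the 'submodules' container are hypothetical names, not shown in this excerpt): each submodule is told where its slice of the flattened context row begins and reserves getInputSize() positions, and feature extraction then lets every submodule append its own dict indexes.

// Sketch only; 'submodules' and 'curIndex' are hypothetical.
// std::size_t curIndex = 0;
// for (auto & submodule : submodules)
// {
//   submodule->setFirstInputIndex(curIndex);   // where this submodule's inputs start in the row
//   curIndex += submodule->getInputSize();     // reserve its share of the row
// }
// ...
// for (auto & submodule : submodules)
//   submodule->addToContext(context, dict, config);   // each appends its own indexes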
@@ -26,7 +26,7 @@ torch::Tensor CNNImpl::forward(torch::Tensor input)
  return cnnOut;
}
int CNNImpl::getOutputSize()
std::size_t CNNImpl::getOutputSize()
{
  return windowSizes.size()*nbFilters;
}
......
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
  constexpr int embeddingsSize = 64;
  constexpr int hiddenSize = 1024;
  constexpr int nbFiltersContext = 512;
  constexpr int nbFiltersFocused = 64;
  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns(columns);
  setBufferFocused(focusedBufferIndexes);
  setStackFocused(focusedStackIndexes);
  rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
    rawInputSize = 0;
  else
    rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
  int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
  cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
  contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
  for (auto & col : focusedColumns)
  {
    std::vector<int> windows{2,3,4};
    cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
    totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
  }
  linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
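A worked example of the sizes computed above, using CNNImpl::getOutputSize() = windowSizes.size()*nbFilters and the constants in this constructor (the focused counts are illustrative):

// contextCNN  : 3 windows * 512 filters = 1536
// rawInputCNN : 3 windows *  64 filters =  192  (only when both raw windows are >= 0)
// each focused column adds 3 * 64 * (focusedBufferIndexes.size() + focusedStackIndexes.size())
// e.g. one focused column with 2 buffer indexes and 1 stack index:
//   totalCnnOutputSize = 1536 + 192 + 192*3 = 2304, the input size of linear1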
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  auto embeddings = embeddingsDropout(wordEmbeddings(input));
  auto context = embeddings.narrow(1, rawInputSize, getContextSize());
  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
  std::vector<torch::Tensor> cnnOutputs;
  if (rawInputSize != 0)
  {
    auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
    cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
  }
  auto curIndex = 0;
  for (unsigned int i = 0; i < focusedColumns.size(); i++)
  {
    long nbElements = maxNbElements[i];
    for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
    {
      auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
      curIndex += nbElements;
      cnnOutputs.emplace_back(cnns[i](cnnInput));
    }
  }
  cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
  auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
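For reference, the narrow() offsets above expect each input row in the order built by extractContext below (the focused part falls outside this excerpt):

// [ raw letters: rawInputSize | context: one dict index per (word, column) pair | focused: maxNbElements[i] indexes per focused position and column (presumably; truncated below) ]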
std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  if (dict.size() >= maxNbEmbeddings)
    util::warning(fmt::format("dict.size()={} >= maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
  std::vector<long> contextIndexes = extractContextIndexes(config);
  std::vector<std::vector<long>> context;
  context.emplace_back();
  if (rawInputSize > 0)
  {
    for (int i = 0; i < leftWindowRawInput; i++)
      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
    for (int i = 0; i <= rightWindowRawInput; i++)
      if (config.hasCharacter(config.getCharacterIndex()+i))
        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
      else
        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
  }
  for (auto index : contextIndexes)
    for (auto & col : columns)
      if (index == -1)
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
      {
        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
        for (auto & contextElement : context)
          contextElement.push_back(dictIndex);
        if (is_training())