Skip to content
Snippets Groups Projects
Commit 8adb3462 authored by Franck Dary's avatar Franck Dary
Browse files

code refactoring

parent c79e0a23
No related branches found
No related tags found
No related merge requests found
Showing
with 239 additions and 400 deletions
#include "Classifier.hpp" #include "Classifier.hpp"
#include "util.hpp" #include "util.hpp"
#include "ConcatWordsNetwork.hpp"
#include "RLTNetwork.hpp"
#include "CNNNetwork.hpp"
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
#include "RandomNetwork.hpp" #include "RandomNetwork.hpp"
...@@ -40,45 +37,6 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -40,45 +37,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
} }
}, },
{
std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
"ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
[this,topology](auto sm)
{
std::vector<int> bufferContext, stackContext;
for (auto s : util::split(sm.str(1), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(2), ','))
stackContext.emplace_back(std::stoi(s));
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
}
},
{
std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
[this,topology](auto sm)
{
std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(sm.str(2), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(3), ','))
stackContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(4), ','))
columns.emplace_back(s);
for (auto s : util::split(sm.str(5), ','))
focusedBuffer.push_back(std::stoi(s));
for (auto s : util::split(sm.str(6), ','))
focusedStack.push_back(std::stoi(s));
for (auto s : util::split(sm.str(7), ','))
focusedColumns.emplace_back(s);
for (auto s : util::split(sm.str(8), ','))
maxNbElements.push_back(std::stoi(s));
if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
}
},
{ {
std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
...@@ -105,14 +63,6 @@ void Classifier::initNeuralNetwork(const std::string & topology) ...@@ -105,14 +63,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10)))); this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
} }
}, },
{
std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
[this,topology](auto sm)
{
this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
}
},
}; };
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
......
...@@ -17,7 +17,7 @@ class CNNImpl : public torch::nn::Module ...@@ -17,7 +17,7 @@ class CNNImpl : public torch::nn::Module
CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize); CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
int getOutputSize(); std::size_t getOutputSize();
}; };
TORCH_MODULE(CNN); TORCH_MODULE(CNN);
......
#ifndef CNNNETWORK__H
#define CNNNETWORK__H
#include "NeuralNetwork.hpp"
#include "CNN.hpp"
class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
int rightWindowRawInput;
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout cnnDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr};
CNN rawInputCNN{nullptr};
std::vector<CNN> cnns;
public :
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
#ifndef CONCATWORDSNETWORK__H
#define CONCATWORDSNETWORK__H
#include "NeuralNetwork.hpp"
class ConcatWordsNetworkImpl : public NeuralNetworkImpl
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::Dropout dropout{nullptr};
public :
ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext);
torch::Tensor forward(torch::Tensor input) override;
};
#endif
#ifndef CONTEXTLSTM__H
#define CONTEXTLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class ContextLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
std::vector<std::string> columns;
std::vector<int> bufferContext;
std::vector<int> stackContext;
int unknownValueThreshold;
std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"};
public :
ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(ContextLSTM);
#endif
#ifndef DEPTHLAYERTREEEMBEDDING__H
#define DEPTHLAYERTREEEMBEDDING__H
#include <torch/torch.h>
#include "fmt/core.h"
#include "LSTM.hpp"
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module
{
private :
std::vector<LSTM> depthLstm;
int maxDepth;
int maxElemPerDepth;
public :
DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth);
torch::Tensor forward(torch::Tensor input);
int getOutputSize();
};
TORCH_MODULE(DepthLayerTreeEmbedding);
#endif
#ifndef FOCUSEDCOLUMNLSTM__H
#define FOCUSEDCOLUMNLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
std::vector<int> focusedBuffer, focusedStack;
std::string column;
int maxNbElements;
public :
FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(FocusedColumnLSTM);
#endif
...@@ -6,6 +6,10 @@ ...@@ -6,6 +6,10 @@
class LSTMImpl : public torch::nn::Module class LSTMImpl : public torch::nn::Module
{ {
public :
using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
private : private :
torch::nn::LSTM lstm{nullptr}; torch::nn::LSTM lstm{nullptr};
...@@ -13,7 +17,7 @@ class LSTMImpl : public torch::nn::Module ...@@ -13,7 +17,7 @@ class LSTMImpl : public torch::nn::Module
public : public :
LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options); LSTMImpl(int inputSize, int outputSize, LSTMOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength); int getOutputSize(int sequenceLength);
}; };
......
...@@ -2,29 +2,28 @@ ...@@ -2,29 +2,28 @@
#define LSTMNETWORK__H #define LSTMNETWORK__H
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
#include "LSTM.hpp" #include "ContextLSTM.hpp"
#include "RawInputLSTM.hpp"
#include "SplitTransLSTM.hpp"
#include "FocusedColumnLSTM.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl class LSTMNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
int rightWindowRawInput;
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout lstmDropout{nullptr}; torch::nn::Dropout lstmDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr}; torch::nn::Linear linear2{nullptr};
LSTM contextLSTM{nullptr};
LSTM rawInputLSTM{nullptr}; ContextLSTM contextLSTM{nullptr};
LSTM splitTransLSTM{nullptr}; RawInputLSTM rawInputLSTM{nullptr};
std::vector<LSTM> lstms; SplitTransLSTM splitTransLSTM{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms;
bool hasRawInputLSTM{false};
public : public :
......
#ifndef MLP__H
#define MLP__H
#include <torch/torch.h>
class MLPImpl : torch::nn::Module
{
public :
MLPImpl(const std::string & topology);
};
TORCH_MODULE(MLP);
#endif
...@@ -15,31 +15,10 @@ class NeuralNetworkImpl : public torch::nn::Module ...@@ -15,31 +15,10 @@ class NeuralNetworkImpl : public torch::nn::Module
static constexpr int maxNbEmbeddings = 150000; static constexpr int maxNbEmbeddings = 150000;
std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{};
std::vector<int> bufferFocused{};
std::vector<int> stackFocused{};
protected :
void setBufferContext(const std::vector<int> & bufferContext);
void setStackContext(const std::vector<int> & stackContext);
void setBufferFocused(const std::vector<int> & bufferFocused);
void setStackFocused(const std::vector<int> & stackFocused);
public : public :
virtual torch::Tensor forward(torch::Tensor input) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const; virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
std::vector<long> extractContextIndexes(const Config & config) const;
std::vector<long> extractFocusedIndexes(const Config & config) const;
int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const;
void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const;
void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const;
void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const;
}; };
TORCH_MODULE(NeuralNetwork); TORCH_MODULE(NeuralNetwork);
......
#ifndef RLTNETWORK__H
#define RLTNETWORK__H
#include "NeuralNetwork.hpp"
class RLTNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr long maxNbChilds{8};
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
int leftBorder, rightBorder;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::LSTM vectorBiLSTM{nullptr};
torch::nn::LSTM treeLSTM{nullptr};
torch::Tensor S;
torch::Tensor nullTree;
public :
RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
...@@ -13,6 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl ...@@ -13,6 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(long outputSize); RandomNetworkImpl(long outputSize);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
}; };
#endif #endif
#ifndef RAWINPUTLSTM__H
#define RAWINPUTLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class RawInputLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
int leftWindow, rightWindow;
public :
RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(RawInputLSTM);
#endif
#ifndef SPLITTRANSLSTM__H
#define SPLITTRANSLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
{
private :
LSTM lstm{nullptr};
int maxNbTrans;
public :
SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(SplitTransLSTM);
#endif
#ifndef SUBMODULE__H
#define SUBMODULE__H
#include "Dict.hpp"
#include "Config.hpp"
class Submodule
{
protected :
std::size_t firstInputIndex{0};
public :
void setFirstInputIndex(std::size_t firstInputIndex);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
};
#endif
...@@ -26,7 +26,7 @@ torch::Tensor CNNImpl::forward(torch::Tensor input) ...@@ -26,7 +26,7 @@ torch::Tensor CNNImpl::forward(torch::Tensor input)
return cnnOut; return cnnOut;
} }
int CNNImpl::getOutputSize() std::size_t CNNImpl::getOutputSize()
{ {
return windowSizes.size()*nbFilters; return windowSizes.size()*nbFilters;
} }
......
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 1024;
constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64;
setBufferContext(bufferContext);
setStackContext(stackContext);
setColumns(columns);
setBufferFocused(focusedBufferIndexes);
setStackFocused(focusedStackIndexes);
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
rawInputSize = 0;
else
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns)
{
std::vector<int> windows{2,3,4};
cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
}
linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto context = embeddings.narrow(1, rawInputSize, getContextSize());
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
std::vector<torch::Tensor> cnnOutputs;
if (rawInputSize != 0)
{
auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
}
auto curIndex = 0;
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
long nbElements = maxNbElements[i];
for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
{
auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
curIndex += nbElements;
cnnOutputs.emplace_back(cnns[i](cnnInput));
}
}
cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<std::vector<long>> context;
context.emplace_back();
if (rawInputSize > 0)
{
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
std::vector<long> focusedIndexes = extractFocusedIndexes(config);
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
for (auto index : focusedIndexes)
{
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (col == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "FEATS")
{
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(col, index));
}
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
return context;
}
#include "ConcatWordsNetwork.hpp"
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 500;
setBufferContext(bufferContext);
setStackContext(stackContext);
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3));
}
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1}));
return linear2(torch::relu(linear1(wordsAsEmb)));
}
#include "ContextLSTM.hpp"
ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold)
{
lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options));
}
std::size_t ContextLSTMImpl::getOutputSize()
{
return lstm->getOutputSize(bufferContext.size()+stackContext.size());
}
std::size_t ContextLSTMImpl::getInputSize()
{
return columns.size()*(bufferContext.size()+stackContext.size());
}
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
std::vector<long> contextIndexes;
for (int index : bufferContext)
contextIndexes.emplace_back(config.getRelativeWordIndex(index));
for (int index : stackContext)
if (config.hasStack(index))
contextIndexes.emplace_back(config.getStack(index));
else
contextIndexes.emplace_back(-1);
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
for (auto & targetCol : unknownValueColumns)
if (col == targetCol)
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
}
torch::Tensor ContextLSTMImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
return lstm(context);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment