Skip to content
Snippets Groups Projects
Commit 5aafa66a authored by Franck Dary
Browse files

Implemented RTLSTM

parent d145be52
No related branches found
No related tags found
No related merge requests found
...@@ -114,7 +114,7 @@ void Config::print(FILE * dest) const ...@@ -114,7 +114,7 @@ void Config::print(FILE * dest) const
void Config::printForDebug(FILE * dest) const void Config::printForDebug(FILE * dest) const
{ {
static constexpr int windowSize = 5; static constexpr int windowSize = 10;
static constexpr int lettersWindowSize = 40; static constexpr int lettersWindowSize = 40;
static constexpr int maxWordLength = 7; static constexpr int maxWordLength = 7;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
class NeuralNetworkImpl : public torch::nn::Module class NeuralNetworkImpl : public torch::nn::Module
{ {
private : protected :
int leftBorder{5}; int leftBorder{5};
int rightBorder{5}; int rightBorder{5};
...@@ -23,7 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module ...@@ -23,7 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module
public : public :
virtual torch::Tensor forward(torch::Tensor input) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0;
std::vector<long> extractContext(Config & config, Dict & dict) const; virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
int getContextSize() const; int getContextSize() const;
void setColumns(const std::vector<std::string> & columns); void setColumns(const std::vector<std::string> & columns);
}; };
......
...@@ -7,16 +7,23 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl ...@@ -7,16 +7,23 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
static constexpr long maxNbChilds{8};
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr}; torch::nn::Linear linear2{nullptr};
torch::nn::Dropout dropout{nullptr}; torch::nn::LSTM vectorBiLSTM{nullptr};
torch::nn::LSTM lstm{nullptr}; torch::nn::LSTM treeLSTM{nullptr};
torch::Tensor S;
torch::Tensor nullTree;
public : public :
RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;
}; };
#endif #endif
...@@ -3,31 +3,176 @@ ...@@ -3,31 +3,176 @@
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{ {
constexpr int embeddingsSize = 30; constexpr int embeddingsSize = 30;
constexpr int lstmOutputSize = 500; constexpr int lstmOutputSize = 128;
constexpr int treeEmbeddingsSize = 256;
constexpr int hiddenSize = 500; constexpr int hiddenSize = 500;
setLeftBorder(leftBorder); setLeftBorder(leftBorder);
setRightBorder(rightBorder); setRightBorder(rightBorder);
setNbStackElements(nbStackElements); setNbStackElements(nbStackElements);
setColumns({"FORM", "UPOS"}); setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize)); linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3)); vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
lstm = register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, lstmOutputSize).batch_first(true))); treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
S = register_parameter("S", torch::randn(treeEmbeddingsSize));
nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
} }
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{ {
// input dim = {batch, sequence, embeddings} input = input.squeeze();
auto wordsAsEmb = wordEmbeddings(input); if (input.dim() != 1)
if (wordsAsEmb.dim() == 2) util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
wordsAsEmb = torch::unsqueeze(wordsAsEmb, 0);
auto lstmOut = lstm(wordsAsEmb).output; auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
// reshaped dim = {sequence, batch, embeddings} auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
auto reshaped = lstmOut.permute({1,0,2}); auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
auto res = linear2(torch::relu(linear1(reshaped[-1]))); auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
return res; auto baseEmbeddings = wordEmbeddings(wordIndexes);
auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
for (unsigned int i = 0; i < computeOrder.size(0); i++)
{
int index = computeOrder[i].item<int>();
if (index == -1)
break;
std::vector<torch::Tensor> inputVector;
inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
{
int child = childs[index][childIndex].item<int>();
if (child == -1)
break;
inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
}
auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
treeRepresentations[index] = lstmOut;
}
std::vector<torch::Tensor> focusedTrees;
for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
{
int index = focusedIndexes[i].item<int>();
if (index == -1)
focusedTrees.emplace_back(nullTree);
else
focusedTrees.emplace_back(treeRepresentations[index]);
}
auto representation = torch::cat(focusedTrees, 0);
return linear2(torch::relu(linear1(representation)));
}
std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  // Flattens everything forward() needs into one vector of longs, in order :
  //  1. window positions of the focused buffer/stack elements,
  //  2. bottom-up tree computation order (padded with -1),
  //  3. for each context word, the window positions of up to maxNbChilds
  //     children (padded with -1),
  //  4. for each context word, the dict index of each column value.

  // Token indexes of the context window : leftBorder tokens on the left,
  // the current word, rightBorder tokens on the right, padded with -1.
  std::vector<long> contextIndexes;

  std::stack<int> leftContext;
  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
    if (config.isToken(index))
      leftContext.push(index);

  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
    contextIndexes.emplace_back(-1);
  while (!leftContext.empty())
  {
    contextIndexes.emplace_back(leftContext.top());
    leftContext.pop();
  }

  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
    if (config.isToken(index))
      contextIndexes.emplace_back(index);
  while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
    contextIndexes.emplace_back(-1);

  // Map each token index to its position inside the context window.
  // BUG FIX : the previous version inserted indexInContext.size() as the
  // value; since the window contains duplicated -1 padding entries (only
  // the first -1 actually inserts), every token after a second padding was
  // mapped to a position smaller than its true one. forward() indexes
  // vectorRepresentations/treeRepresentations by true window position (and
  // the focused buffer entries below push raw positions leftBorder+c), so
  // the map must hold the real position i.
  std::map<long, long> indexInContext;
  for (unsigned int i = 0; i < contextIndexes.size(); i++)
    indexInContext.emplace(contextIndexes[i], i);

  // For each context word : the token index of its governor when the
  // governor is also inside the window, -1 otherwise.
  std::vector<long> headOf;
  for (auto & l : contextIndexes)
  {
    if (l == -1)
      headOf.push_back(-1);
    else
    {
      auto & head = config.getAsFeature(Config::headColName, l);
      if (util::isEmpty(head) or head == "_")
        headOf.push_back(-1);
      else if (indexInContext.count(std::stoi(head)))
        headOf.push_back(std::stoi(head));
      else
        headOf.push_back(-1);
    }
  }

  // childs[window position of the head] = token indexes of its children.
  std::vector<std::vector<long>> childs(headOf.size());
  for (unsigned int i = 0; i < headOf.size(); i++)
    if (headOf[i] != -1)
      childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);

  // Children-first order in which forward() must run the tree LSTM over
  // the focused subtrees; positions never visited keep the null tree.
  std::vector<long> treeComputationOrder;
  std::vector<bool> treeIsComputed(contextIndexes.size(), false);
  std::function<void(long)> depthFirst;
  depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
  {
    if (!indexInContext.count(root))
      return;
    if (treeIsComputed[indexInContext[root]])
      return;
    for (auto child : childs[indexInContext[root]])
      depthFirst(child);
    treeIsComputed[indexInContext[root]] = true;
    treeComputationOrder.push_back(indexInContext[root]);
  };
  for (auto & l : focusedBufferIndexes)
    if (contextIndexes[leftBorder+l] != -1)
      depthFirst(contextIndexes[leftBorder+l]);
  for (auto & l : focusedStackIndexes)
    if (config.hasStack(l))
      depthFirst(config.getStack(l));

  std::vector<long> context;
  // 1a. window positions of the focused buffer words (always valid slots;
  //     an empty slot resolves to the null tree inside forward()).
  for (auto & c : focusedBufferIndexes)
    context.push_back(leftBorder+c);
  // 1b. window positions of the focused stack words, -1 when the stack is
  //     empty or the element is outside the window.
  for (auto & c : focusedStackIndexes)
    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
      context.push_back(indexInContext[config.getStack(c)]);
    else
      context.push_back(-1);
  // 2. computation order, padded with -1 to one entry per context word.
  for (auto & c : treeComputationOrder)
    context.push_back(c);
  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
    context.push_back(-1);
  // 3. per context word : positions of up to maxNbChilds children, -1 pad.
  for (auto & c : childs)
  {
    for (unsigned int i = 0; i < maxNbChilds; i++)
      if (i < c.size())
        context.push_back(indexInContext[c[i]]);
      else
        context.push_back(-1);
  }
  // 4. dict index of every column value of every context word; padding
  //    slots get the null value.
  for (auto & l : contextIndexes)
    for (auto & col : columns)
      if (l == -1)
        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
      else
        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));

  return context;
}
...@@ -18,7 +18,7 @@ class Trainer ...@@ -18,7 +18,7 @@ class Trainer
DataLoader dataLoader{nullptr}; DataLoader dataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0}; std::size_t epochNumber{0};
int batchSize{50}; int batchSize{1};
int nbExamples{0}; int nbExamples{0};
public : public :
......
...@@ -63,7 +63,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -63,7 +63,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
float Trainer::epoch(bool printAdvancement) float Trainer::epoch(bool printAdvancement)
{ {
constexpr int printInterval = 2000; constexpr int printInterval = 50;
int nbExamplesProcessed = 0; int nbExamplesProcessed = 0;
float totalLoss = 0.0; float totalLoss = 0.0;
float lossSoFar = 0.0; float lossSoFar = 0.0;
...@@ -81,6 +81,8 @@ float Trainer::epoch(bool printAdvancement) ...@@ -81,6 +81,8 @@ float Trainer::epoch(bool printAdvancement)
auto labels = batch.target.squeeze(); auto labels = batch.target.squeeze();
auto prediction = machine.getClassifier()->getNN()(data); auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment