Skip to content
Snippets Groups Projects
Commit 5aafa66a authored by Franck Dary
Browse files

Implemented RTLSTM

parent d145be52
No related branches found
No related tags found
No related merge requests found
......@@ -114,7 +114,7 @@ void Config::print(FILE * dest) const
void Config::printForDebug(FILE * dest) const
{
static constexpr int windowSize = 5;
static constexpr int windowSize = 10;
static constexpr int lettersWindowSize = 40;
static constexpr int maxWordLength = 7;
......
......@@ -7,7 +7,7 @@
class NeuralNetworkImpl : public torch::nn::Module
{
private :
protected :
int leftBorder{5};
int rightBorder{5};
......@@ -23,7 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
std::vector<long> extractContext(Config & config, Dict & dict) const;
virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
};
......
......@@ -7,16 +7,23 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr long maxNbChilds{8};
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::Dropout dropout{nullptr};
torch::nn::LSTM lstm{nullptr};
torch::nn::LSTM vectorBiLSTM{nullptr};
torch::nn::LSTM treeLSTM{nullptr};
torch::Tensor S;
torch::Tensor nullTree;
public :
RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;
};
#endif
......@@ -3,31 +3,176 @@
// Builds all sub-modules and learned parameters of the RTLSTM network.
// nbOutputs: number of transition classes scored by the final linear layer.
// leftBorder/rightBorder: number of context words kept on each side of the
// current word; nbStackElements: stack depth exposed to the network.
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
constexpr int embeddingsSize = 30;       // per-column word embedding size
constexpr int lstmOutputSize = 128;      // per-direction hidden size of the vector BiLSTM
constexpr int treeEmbeddingsSize = 256;  // size of one (sub)tree representation
constexpr int hiddenSize = 500;          // MLP hidden layer size
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
// MLP input = concatenation of one tree embedding per focused buffer/stack element.
linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3));
// Bidirectional LSTM over the concatenated column embeddings of each context word.
vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
// Unidirectional LSTM folding [node vector ; child tree] pairs into a tree embedding
// (input = tree embedding + both directions of the BiLSTM output).
treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
S = register_parameter("S", torch::randn(treeEmbeddingsSize));
nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
}
// Scores the transitions for one configuration. Batched input is not
// supported: `input` must squeeze to a single 1-D feature vector whose
// layout matches extractContext():
//   [focused elements][tree computation order][childs matrix, flattened]
//   [dict indexes of word features (columns.size() per context word)]
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
input = input.squeeze();
if (input.dim() != 1)
util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
// Slice the flat feature vector back into its four sections.
auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
// One embedding per (word, column); regroup so each context word gets the
// concatenation of its column embeddings, then run the BiLSTM over the window.
auto baseEmbeddings = wordEmbeddings(wordIndexes);
auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
// Bottom-up tree construction: children are guaranteed (by extractContext)
// to appear in computeOrder before their head. -1 terminates the order.
std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
for (unsigned int i = 0; i < computeOrder.size(0); i++)
{
int index = computeOrder[i].item<int>();
if (index == -1)
break;
// Sequence fed to the tree LSTM: [node ; S] then [node ; child_k] for each child.
std::vector<torch::Tensor> inputVector;
inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
{
int child = childs[index][childIndex].item<int>();
if (child == -1)
break;
inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
}
auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
// Keep only the last timestep of the tree LSTM as this node's embedding.
auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
treeRepresentations[index] = lstmOut;
}
// Concatenate the tree embeddings of the focused elements (-1 -> nullTree)
// and score them with the MLP.
std::vector<torch::Tensor> focusedTrees;
for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
{
int index = focusedIndexes[i].item<int>();
if (index == -1)
focusedTrees.emplace_back(nullTree);
else
focusedTrees.emplace_back(treeRepresentations[index]);
}
auto representation = torch::cat(focusedTrees, 0);
return linear2(torch::relu(linear1(representation)));
}
// Builds the flat feature vector consumed by forward(), in four sections:
//   [focused elements][tree computation order][childs matrix][word feature dict indexes].
// Positions inside the context window are encoded via indexInContext; -1 is
// the universal "absent" padding value.
std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
// Collect up to leftBorder token indexes left of the current word; the stack
// restores left-to-right order, and missing slots are padded with -1 on the left.
std::vector<long> contextIndexes;
std::stack<int> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
if (config.isToken(index))
leftContext.push(index);
while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
contextIndexes.emplace_back(-1);
while (!leftContext.empty())
{
contextIndexes.emplace_back(leftContext.top());
leftContext.pop();
}
// Current word and up to rightBorder tokens to its right, padded with -1.
for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
if (config.isToken(index))
contextIndexes.emplace_back(index);
while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
contextIndexes.emplace_back(-1);
// Map each sentence-level index to its position inside the context window.
// NOTE(review): duplicate -1 entries collapse to a single map key here — the
// first -1 gets a position, later ones are ignored by emplace; confirm intended.
std::map<long, long> indexInContext;
for (auto & l : contextIndexes)
indexInContext.emplace(std::make_pair(l, indexInContext.size()));
// headOf[i] = sentence index of the head of context word i, or -1 when absent,
// unset ("_"), or outside the context window.
std::vector<long> headOf;
for (auto & l : contextIndexes)
{
if (l == -1)
headOf.push_back(-1);
else
{
auto & head = config.getAsFeature(Config::headColName, l);
if (util::isEmpty(head) or head == "_")
headOf.push_back(-1);
else if (indexInContext.count(std::stoi(head)))
headOf.push_back(std::stoi(head));
else
headOf.push_back(-1);
}
}
// childs[p] = sentence indexes of the dependents of the word at context position p.
std::vector<std::vector<long>> childs(headOf.size());
for (unsigned int i = 0; i < headOf.size(); i++)
if (headOf[i] != -1)
childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);
// Post-order traversal from each focused element so that every node appears
// in treeComputationOrder after all of its children (as forward() requires).
std::vector<long> treeComputationOrder;
std::vector<bool> treeIsComputed(contextIndexes.size(), false);
std::function<void(long)> depthFirst;
depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
{
if (!indexInContext.count(root))
return;
if (treeIsComputed[indexInContext[root]])
return;
for (auto child : childs[indexInContext[root]])
depthFirst(child);
treeIsComputed[indexInContext[root]] = true;
treeComputationOrder.push_back(indexInContext[root]);
};
for (auto & l : focusedBufferIndexes)
if (contextIndexes[leftBorder+l] != -1)
depthFirst(contextIndexes[leftBorder+l]);
for (auto & l : focusedStackIndexes)
if (config.hasStack(l))
depthFirst(config.getStack(l));
// Section 1: context positions of the focused elements.
// NOTE(review): buffer slots push position leftBorder+c even when that slot
// holds -1 (empty), unlike stack slots which push -1 — confirm this asymmetry
// is intended (forward() maps such positions to nullTree only via value -1).
std::vector<long> context;
for (auto & c : focusedBufferIndexes)
context.push_back(leftBorder+c);
for (auto & c : focusedStackIndexes)
if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
context.push_back(indexInContext[config.getStack(c)]);
else
context.push_back(-1);
// Section 2: computation order, -1-padded to the full window size.
for (auto & c : treeComputationOrder)
context.push_back(c);
while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
context.push_back(-1);
// Section 3: per-position child list, -1-padded to maxNbChilds entries
// (children beyond maxNbChilds are silently dropped).
for (auto & c : childs)
{
for (unsigned int i = 0; i < maxNbChilds; i++)
if (i < c.size())
context.push_back(indexInContext[c[i]]);
else
context.push_back(-1);
}
// Section 4: dictionary index of each column feature for every context word.
for (auto & l : contextIndexes)
for (auto & col : columns)
if (l == -1)
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
return context;
}
......@@ -18,7 +18,7 @@ class Trainer
DataLoader dataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0};
int batchSize{50};
int batchSize{1};
int nbExamples{0};
public :
......
......@@ -63,7 +63,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
float Trainer::epoch(bool printAdvancement)
{
constexpr int printInterval = 2000;
constexpr int printInterval = 50;
int nbExamplesProcessed = 0;
float totalLoss = 0.0;
float lossSoFar = 0.0;
......@@ -81,6 +81,8 @@ float Trainer::epoch(bool printAdvancement)
auto labels = batch.target.squeeze();
auto prediction = machine.getClassifier()->getNN()(data);
if (prediction.dim() == 1)
prediction = prediction.unsqueeze(0);
labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment