Commit 58a388d4 authored by Franck Dary

Integrated DepthLayerTreeEmbedding

parent 4bfe6147
@@ -73,10 +73,12 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
std::vector<int> bufferContext, stackContext;
std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
std::vector<int> focusedBuffer, focusedStack;
+ std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
std::vector<int> maxNbElements;
+ std::vector<int> treeEmbeddingNbElems;
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
- int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
+ int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
bool bilstm;
float lstmDropout;
@@ -229,6 +231,37 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
- this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns));
+ if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
+ {
+ for (auto & index : util::split(sm.str(1), ' '))
+ treeEmbeddingBuffer.emplace_back(std::stoi(index));
+ curIndex++;
+ }))
+ util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
+ if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
+ {
+ for (auto & index : util::split(sm.str(1), ' '))
+ treeEmbeddingStack.emplace_back(std::stoi(index));
+ curIndex++;
+ }))
+ util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
+ if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
+ {
+ for (auto & index : util::split(sm.str(1), ' '))
+ treeEmbeddingNbElems.emplace_back(std::stoi(index));
+ curIndex++;
+ }))
+ util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
+ if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
+ {
+ treeEmbeddingSize = std::stoi(sm.str(1));
+ curIndex++;
+ }))
+ util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
+ this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize));
}
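These four new blocks read optional lines from the classifier definition file, in the same format as the existing "Tree embedding columns" line above them. A definition excerpt accepted by these regexes could look like the following (the column, indexes and sizes are illustrative values, not taken from this commit):

  Tree embedding columns : {DEPREL}
  Tree embedding buffer : {0}
  Tree embedding stack : {0 1}
  Tree embedding nb : {3 2}
  Tree embedding size : 64

Here "Tree embedding nb" gives one value per tree depth (the maximum number of children kept at that depth) and "Tree embedding size" the output size of each per-depth LSTM; both are forwarded, together with the buffer and stack indexes, to the extended LSTMNetworkImpl constructor call above.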
@@ -9,17 +9,15 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
{
private :
- std::vector<std::string> columns{"DEPREL"};
- std::vector<int> focusedBuffer{0};
- std::vector<int> focusedStack{0};
- std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"};
+ std::vector<int> maxElemPerDepth;
+ std::vector<std::string> columns;
+ std::vector<int> focusedBuffer;
+ std::vector<int> focusedStack;
std::vector<LSTM> depthLstm;
- int maxDepth;
- int maxElemPerDepth;
public :
- DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+ DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
@@ -29,7 +29,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
public :
- LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns);
+ LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#include "DepthLayerTreeEmbedding.hpp"
- DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
+ DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) :
+ maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
{
- for (int i = 0; i < maxDepth; i++)
- depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options)));
+ for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
+ depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
}
torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
@@ -12,9 +13,13 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
std::vector<torch::Tensor> outputs;
- for (unsigned int i = 0; i < depthLstm.size(); i++)
- for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++)
- outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth)));
+ int offset = 0;
+ for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
+ for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+ {
+ outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+ offset += maxElemPerDepth[depth]*columns.size();
+ }
return torch::cat(outputs, 1);
}
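The new forward pass walks the flattened row one focused element at a time and one depth at a time: each slice of maxElemPerDepth[depth]*columns.size() embeddings is regrouped into a sequence of maxElemPerDepth[depth] children, each child contributing its columns.size() column embeddings concatenated, and is fed to the LSTM dedicated to that depth (context is presumably the portion of input owned by this submodule, obtained before the shown lines). The following standalone libtorch sketch reproduces only the narrow/view arithmetic with made-up sizes; it is an illustration, not code from the commit.

#include <torch/torch.h>
#include <vector>

// Sketch of the slicing in the new forward pass, with made-up sizes:
// batch of 2, one column, embeddings of size 4, maxElemPerDepth = {3, 2},
// and two focused elements (buffer + stack).
int main()
{
  const int64_t batch = 2, nbColumns = 1, embeddingsSize = 4, nbFocused = 2;
  const std::vector<int64_t> maxElemPerDepth{3, 2};

  // One embedded context row per batch element:
  // nbFocused * (3 + 2) * nbColumns embeddings of size embeddingsSize.
  auto context = torch::randn({batch, nbFocused * (3 + 2) * nbColumns, embeddingsSize});

  int64_t offset = 0;
  for (int64_t focused = 0; focused < nbFocused; focused++)
    for (std::size_t depth = 0; depth < maxElemPerDepth.size(); depth++)
    {
      // Same narrow + view as the commit: take this depth's children and regroup
      // the column embeddings of each child into one vector per child.
      auto slice = context.narrow(1, offset, maxElemPerDepth[depth] * nbColumns)
                       .view({batch, maxElemPerDepth[depth], nbColumns * embeddingsSize});
      // In the real module, slice would be fed to depthLstm[depth].
      offset += maxElemPerDepth[depth] * nbColumns;
    }
  return 0;
}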
@@ -23,15 +28,18 @@ std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
{
std::size_t outputSize = 0;
- for (auto & lstm : depthLstm)
- outputSize += lstm->getOutputSize(maxElemPerDepth);
+ for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+ outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]);
- return outputSize;
+ return outputSize*(focusedBuffer.size()+focusedStack.size());
}
std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
{
- return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth;
+ int inputSize = 0;
+ for (int maxElem : maxElemPerDepth)
+ inputSize += (focusedBuffer.size()+focusedStack.size())*maxElem*columns.size();
+ return inputSize;
}
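Concretely, with the illustrative configuration used above (two focused elements, one column, maxElemPerDepth = {3, 2}), the new getInputSize() returns 2*3*1 + 2*2*1 = 10 embedded elements per context row, and getOutputSize() returns (h0 + h1) * 2, where h_d is the output size reported by the LSTM handling depth d. The old version assumed a single depth count and a single element count shared by every depth.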
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
@@ -48,11 +56,27 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> &
focusedIndexes.emplace_back(-1);
for (auto & contextElement : context)
{
for (auto index : focusedIndexes)
{
std::vector<std::string> childs{std::to_string(index)};
+ for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+ {
+ std::vector<std::string> newChilds;
+ for (auto & child : childs)
+ if (config.has(Config::childsColName, std::stoi(child), 0))
+ {
+ auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
+ newChilds.insert(newChilds.end(), val.begin(), val.end());
+ }
+ childs = newChilds;
+ for (int i = 0; i < maxElemPerDepth[depth]; i++)
+ for (auto & col : columns)
+ if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0))
+ contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i]))));
+ else
+ contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+ }
}
}
}
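The extended addToContext materialises, for every focused word, its descendants level by level: at each depth the current frontier is replaced by the concatenation of each node's children, read from the '|'-separated Config::childsColName feature, and maxElemPerDepth[depth] * columns.size() dictionary indexes are emitted, padded with Dict::nullValueStr when fewer children exist. Below is a self-contained sketch of that breadth-first expansion on plain strings; the children map and all values are made up and only stand in for the Config feature lookups.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for util::split on the '|'-separated children feature.
std::vector<std::string> splitChildren(const std::string & s)
{
  std::vector<std::string> out;
  std::string cur;
  for (char c : s)
    if (c == '|') { out.push_back(cur); cur.clear(); }
    else cur.push_back(c);
  if (!cur.empty()) out.push_back(cur);
  return out;
}

int main()
{
  // Hypothetical tree: word 3 has children 1 and 5; word 5 has children 4, 6, 7.
  std::map<std::string, std::string> children{{"3", "1|5"}, {"5", "4|6|7"}};
  std::vector<int> maxElemPerDepth{3, 2};

  std::vector<std::string> frontier{"3"}; // one focused word
  for (std::size_t depth = 0; depth < maxElemPerDepth.size(); depth++)
  {
    // Replace the frontier by all children of its members, as the commit does.
    std::vector<std::string> next;
    for (auto & node : frontier)
      if (children.count(node))
        for (auto & c : splitChildren(children.at(node)))
          next.push_back(c);
    frontier = next;

    // Emit exactly maxElemPerDepth[depth] slots, padding when there are fewer
    // children (the real code inserts Dict::nullValueStr indexes instead).
    for (int i = 0; i < maxElemPerDepth[depth]; i++)
      std::cout << depth << ": " << (i < (int)frontier.size() ? frontier[i] : "<null>") << "\n";
  }
  return 0;
}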
#include "LSTMNetwork.hpp"
- LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns)
+ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize)
{
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
auto lstmOptionsAll = lstmOptions;
@@ -26,7 +26,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
if (!treeEmbeddingColumns.empty())
{
hasTreeEmbedding = true;
- treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll));
+ treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptionsAll));
treeEmbedding->setFirstInputIndex(currentInputSize);
currentOutputSize += treeEmbedding->getOutputSize();
currentInputSize += treeEmbedding->getInputSize();
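With this change the tree embedding is configured entirely from the definition file parsed in Classifier::initLSTM: the previously hardcoded 1, 3 and 128 give way to the configured treeEmbeddingNbElems and treeEmbeddingSize, and the focused positions now come from treeEmbeddingBuffer and treeEmbeddingStack instead of reusing focusedBufferIndexes and focusedStackIndexes.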
@@ -89,13 +89,15 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
contextLSTM->addToContext(context, dict, config);
if (hasRawInputLSTM)
rawInputLSTM->addToContext(context, dict, config);
fmt::print(stderr, "before {}\n", context.back().size());
if (hasTreeEmbedding)
treeEmbedding->addToContext(context, dict, config);
fmt::print(stderr, "after {}\n", context.back().size());
splitTransLSTM->addToContext(context, dict, config);
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config);