Commit 4bfe6147 authored by Franck Dary

Updated LSTMNetwork and attach action to work with tree embedding networks

parent ff3894a3
@@ -570,12 +570,20 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
lineIndex = config.getWordIndex() + governorIndex;
else
lineIndex = config.getStack(governorIndex);
int depIndex = 0;
if (dependentObject == Object::Buffer)
depIndex = config.getWordIndex() + dependentIndex;
else
depIndex = config.getStack(dependentIndex);
addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a);
addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a);
};
auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
{
addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a);
addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, "").apply(config, a);
};
auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action)
@@ -71,9 +71,8 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
{
int unknownValueThreshold;
std::vector<int> bufferContext, stackContext;
std::vector<std::string> columns;
std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
std::vector<int> focusedBuffer, focusedStack;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
@@ -223,6 +222,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
{
treeEmbeddingColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns));
}
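
The new branch above accepts an optional "Tree embedding columns :" line in the classifier definition and splits the braced content into column names. Below is a minimal, self-contained sketch of that parsing step reusing the same regular expression; the example line and the column names FORM and UPOS are placeholders, and the std::istringstream loop stands in for util::split.

#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <vector>

int main()
{
  // Placeholder definition line; real column names depend on the corpus format.
  std::string line = "Tree embedding columns : {FORM UPOS}";

  // Same pattern as in Classifier::initLSTM above: the prefix is optional,
  // the braces are mandatory, and capture group 1 holds the column list.
  std::regex re("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}");

  std::smatch sm;
  std::vector<std::string> treeEmbeddingColumns;
  if (std::regex_match(line, sm, re))
  {
    // Stand-in for util::split(sm.str(1), ' '): split the capture on whitespace.
    std::istringstream iss(sm.str(1));
    for (std::string col; iss >> col; )
      treeEmbeddingColumns.push_back(col);
  }

  for (auto & col : treeEmbeddingColumns)
    std::cout << col << "\n"; // prints FORM then UPOS
}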
@@ -7,6 +7,7 @@
#include "SplitTransLSTM.hpp"
#include "FocusedColumnLSTM.hpp"
#include "MLP.hpp"
#include "DepthLayerTreeEmbedding.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
@@ -20,13 +21,15 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
ContextLSTM contextLSTM{nullptr};
RawInputLSTM rawInputLSTM{nullptr};
SplitTransLSTM splitTransLSTM{nullptr};
DepthLayerTreeEmbedding treeEmbedding{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms;
bool hasRawInputLSTM{false};
bool hasTreeEmbedding{false};
public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
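
DepthLayerTreeEmbedding itself is not part of this commit; only its header is included above. The declaration below is a rough sketch of the interface implied by the call sites in LSTMNetwork.cpp further down. Every parameter name, the placeholder option type and the return types are guesses, not the contents of the real DepthLayerTreeEmbedding.hpp.

// Hypothetical interface reconstructed from the calls in this diff;
// the actual DepthLayerTreeEmbedding.hpp may differ.
#include <torch/torch.h>
#include <string>
#include <tuple>
#include <vector>

class Dict;    // provided by the project
class Config;  // provided by the project
using LSTMOptionsSketch = std::tuple<bool, bool, int, float, bool>; // stand-in for LSTMImpl::LSTMOptions

class DepthLayerTreeEmbeddingSketchImpl : public torch::nn::Module
{
  public :

  // Mirrors the call DepthLayerTreeEmbedding(1, 3, embeddingsSize, 128,
  // treeEmbeddingColumns, focusedBufferIndexes, focusedStackIndexes, lstmOptionsAll).
  // The meaning and order of the two leading integers (tree depth / number of layers?)
  // and of 128 (LSTM hidden size?) are assumptions.
  DepthLayerTreeEmbeddingSketchImpl(int maxDepth, int nbLayers, int embeddingsSize,
      int lstmSize, std::vector<std::string> columns,
      std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes,
      LSTMOptionsSketch options);

  torch::Tensor forward(torch::Tensor input);                 // used as treeEmbedding(embeddings)
  std::size_t getOutputSize() const;                          // added to currentOutputSize
  std::size_t getInputSize() const;                           // added to currentInputSize
  void setFirstInputIndex(std::size_t firstInputIndex);       // offset of this module's context slice
  void addToContext(std::vector<std::vector<long>> & context, // appends this module's feature
      Dict & dict, Config & config) const;                    // indices to each context row
};
TORCH_MODULE(DepthLayerTreeEmbeddingSketch);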
#include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns)
{
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
auto lstmOptionsAll = lstmOptions;
@@ -23,6 +23,15 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
currentInputSize += rawInputLSTM->getInputSize();
}
if (!treeEmbeddingColumns.empty())
{
hasTreeEmbedding = true;
treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll));
treeEmbedding->setFirstInputIndex(currentInputSize);
currentOutputSize += treeEmbedding->getOutputSize();
currentInputSize += treeEmbedding->getInputSize();
}
splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
splitTransLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += splitTransLSTM->getOutputSize();
@@ -56,6 +65,9 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
if (hasRawInputLSTM)
outputs.emplace_back(rawInputLSTM(embeddings));
if (hasTreeEmbedding)
outputs.emplace_back(treeEmbedding(embeddings));
outputs.emplace_back(splitTransLSTM(embeddings));
for (auto & lstm : focusedLstms)
@@ -79,6 +91,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
contextLSTM->addToContext(context, dict, config);
if (hasRawInputLSTM)
rawInputLSTM->addToContext(context, dict, config);
fmt::print(stderr, "before {}\n", context.back().size());
if (hasTreeEmbedding)
treeEmbedding->addToContext(context, dict, config);
fmt::print(stderr, "after {}\n", context.back().size());
splitTransLSTM->addToContext(context, dict, config);
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config);
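
As a reading aid, here is a self-contained toy illustration of the contract the tree embedding now also follows, as suggested by the calls above: extractContext() lets each submodule append its own feature indices to every context row, setFirstInputIndex() records where that block starts, and forward() reads back the matching slice of the embedded input. This is an assumed pattern for illustration, not the project's code, and the slice dimension is a guess.

// Toy illustration (assumed pattern, not the project's actual classes) of the
// firstInputIndex / addToContext / forward contract shared by the submodules.
#include <torch/torch.h>
#include <cstdint>
#include <iostream>
#include <vector>

struct ToySubModuleImpl : torch::nn::Module
{
  std::size_t firstInputIndex{0};
  long nbInputs; // how many feature indices this submodule adds to a context row

  ToySubModuleImpl(long nbInputs) : nbInputs(nbInputs) {}

  void setFirstInputIndex(std::size_t index) { firstInputIndex = index; }
  std::size_t getInputSize() const { return nbInputs; }

  // The real addToContext would look feature values up in the Dict from the Config;
  // here we just append dummy indices so every row gets the right width.
  void addToContext(std::vector<std::vector<long>> & context) const
  {
    for (auto & row : context)
      row.insert(row.end(), nbInputs, 0);
  }

  // Read back exactly the slice of the embedded context that belongs to this module.
  torch::Tensor forward(torch::Tensor embeddings) const
  {
    return embeddings.narrow(1, firstInputIndex, nbInputs);
  }
};
TORCH_MODULE(ToySubModule);

int main()
{
  ToySubModule a(3), b(2);              // e.g. a contextLSTM-like and a treeEmbedding-like part
  std::size_t currentInputSize = 0;
  a->setFirstInputIndex(currentInputSize); currentInputSize += a->getInputSize();
  b->setFirstInputIndex(currentInputSize); currentInputSize += b->getInputSize();

  std::vector<std::vector<long>> context{{}};
  a->addToContext(context);
  b->addToContext(context);             // each row now holds currentInputSize indices

  auto embeddings = torch::rand({1, static_cast<int64_t>(currentInputSize), 4}); // (batch, features, embSize)
  std::cout << a->forward(embeddings).sizes() << " "
            << b->forward(embeddings).sizes() << std::endl;   // [1, 3, 4] [1, 2, 4]
}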