Commit 8adb3462 authored by Franck Dary

code refactoring

parent c79e0a23
#include "DepthLayerTreeEmbedding.hpp"
DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
{
}
torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
{
}
int DepthLayerTreeEmbeddingImpl::getOutputSize()
{
}
#include "FocusedColumnLSTM.hpp"
FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
{
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input)
{
std::vector<torch::Tensor> outputs;
for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
return torch::cat(outputs, 1);
}
std::size_t FocusedColumnLSTMImpl::getOutputSize()
{
return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements);
}
std::size_t FocusedColumnLSTMImpl::getInputSize()
{
return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
}
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
std::vector<long> focusedIndexes;
for (int index : focusedBuffer)
focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
for (int index : focusedStack)
if (config.hasStack(index))
focusedIndexes.emplace_back(config.getStack(index));
else
focusedIndexes.emplace_back(-1);
for (auto & contextElement : context)
{
for (auto index : focusedIndexes)
{
if (index == -1)
{
for (int i = 0; i < maxNbElements; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (column == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(column, index).get());
for (int i = 0; i < maxNbElements; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (column == "FEATS")
{
auto splited = util::split(config.getAsFeature(column, index).get(), '|');
for (int i = 0; i < maxNbElements; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (column == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(column, index));
}
if ((int)elements.size() != maxNbElements)
util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
}
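To make the slice arithmetic above concrete: addToContext() appends exactly maxNbElements dictionary indexes per focused word, and forward() recovers each word's block with the same offsets. A minimal standalone sketch, where firstInputIndex, the number of focused words and maxNbElements are invented for illustration:

#include <torch/torch.h>
#include <cstdio>

int main()
{
  // Invented layout: 3 focused words, 4 elements per word, slice starting at column 5.
  const int firstInputIndex = 5, nbFocused = 3, maxNbElements = 4;
  auto input = torch::arange(0, 20).unsqueeze(0); // one row of 20 feature columns
  for (int i = 0; i < nbFocused; i++)
  {
    // Same narrow() arithmetic as FocusedColumnLSTMImpl::forward
    auto slice = input.narrow(1, firstInputIndex + i * maxNbElements, maxNbElements);
    std::printf("word %d reads %ld columns starting at %d\n", i, (long)slice.size(1), firstInputIndex + i * maxNbElements);
  }
  return 0;
}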
#include "LSTM.hpp" #include "LSTM.hpp"
LSTMImpl::LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options) : outputAll(std::get<4>(options)) LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options))
{ {
auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize) auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
.batch_first(std::get<0>(options)) .batch_first(std::get<0>(options))
......
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput)
{ {
constexpr int embeddingsSize = 256; constexpr int embeddingsSize = 256;
constexpr int hiddenSize = 8192; constexpr int hiddenSize = 8192;
...@@ -8,41 +8,45 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: ...@@ -8,41 +8,45 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
constexpr int focusedLSTMSize = 256; constexpr int focusedLSTMSize = 256;
constexpr int rawInputLSTMSize = 32; constexpr int rawInputLSTMSize = 32;
std::tuple<bool,bool,int,float,bool> lstmOptions{true,true,2,0.3,false}; LSTMImpl::LSTMOptions lstmOptions{true,true,2,0.3,false};
auto lstmOptionsAll = lstmOptions; auto lstmOptionsAll = lstmOptions;
std::get<4>(lstmOptionsAll) = true; std::get<4>(lstmOptionsAll) = true;
setBufferContext(bufferContext); int currentOutputSize = embeddingsSize;
setStackContext(stackContext); int currentInputSize = 1;
setColumns(columns);
setBufferFocused(focusedBufferIndexes); contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
setStackFocused(focusedStackIndexes); contextLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += contextLSTM->getOutputSize();
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; currentInputSize += contextLSTM->getInputSize();
int rawInputLSTMOutSize = 0;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0) if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
rawInputSize = 0;
else
{ {
rawInputLSTM = register_module("rawInputLSTM", LSTM(embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); hasRawInputLSTM = true;
rawInputLSTMOutSize = rawInputLSTM->getOutputSize(rawInputSize); rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
rawInputLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += rawInputLSTM->getOutputSize();
currentInputSize += rawInputLSTM->getInputSize();
} }
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, embeddingsSize, lstmOptionsAll));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); splitTransLSTM->setFirstInputIndex(currentInputSize);
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); currentOutputSize += splitTransLSTM->getOutputSize();
contextLSTM = register_module("contextLSTM", LSTM(columns.size()*embeddingsSize, contextLSTMSize, lstmOptions)); currentInputSize += splitTransLSTM->getInputSize();
splitTransLSTM = register_module("splitTransLSTM", LSTM(embeddingsSize, embeddingsSize, lstmOptionsAll));
int totalLSTMOutputSize = rawInputLSTMOutSize + contextLSTM->getOutputSize(getContextSize()) + splitTransLSTM->getOutputSize(Config::maxNbAppliableSplitTransitions);
for (unsigned int i = 0; i < focusedColumns.size(); i++) for (unsigned int i = 0; i < focusedColumns.size(); i++)
{ {
lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), LSTM(embeddingsSize, focusedLSTMSize, lstmOptions))); focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
totalLSTMOutputSize += (bufferFocused.size()+stackFocused.size())*lstms.back()->getOutputSize(maxNbElements[i]); focusedLstms.back()->setFirstInputIndex(currentInputSize);
currentOutputSize += focusedLstms.back()->getOutputSize();
currentInputSize += focusedLstms.back()->getInputSize();
} }
linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize)); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
} }
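The constructor above threads two cursors through the submodules: currentInputSize marks the first column of the flattened context row that each submodule will read (column 0 holds the parser state), and currentOutputSize accumulates the width of the concatenated vector fed to linear1. A standalone sketch of that bookkeeping, with invented widths standing in for each submodule's getInputSize():

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main()
{
  // Invented widths; the real ones come from each submodule's getInputSize().
  std::vector<std::pair<std::string, int>> modules{
      {"contextLSTM", 12}, {"rawInputLSTM", 11}, {"splitTransLSTM", 4}};
  int currentInputSize = 1; // column 0 holds the parser state
  for (auto & [name, inputSize] : modules)
  {
    // Equivalent of module->setFirstInputIndex(currentInputSize)
    std::printf("%s reads columns [%d, %d)\n", name.c_str(), currentInputSize, currentInputSize + inputSize);
    currentInputSize += inputSize; // next submodule starts right after
  }
  std::printf("extractContext() must produce rows of length %d\n", currentInputSize);
  return 0;
}

In forward(), every submodule then receives the full embeddings tensor and narrows it to its own [firstInputIndex, firstInputIndex + getInputSize()) slice.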
...
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
...
auto embeddings = embeddingsDropout(wordEmbeddings(input));
std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
outputs.emplace_back(contextLSTM(embeddings));
if (hasRawInputLSTM)
outputs.emplace_back(rawInputLSTM(embeddings));
outputs.emplace_back(splitTransLSTM(embeddings));
for (auto & lstm : focusedLstms)
outputs.emplace_back(lstm(embeddings));
auto totalInput = torch::cat(outputs, 1);
return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
...
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
...
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
contextLSTM->addToContext(context, dict, config);
if (hasRawInputLSTM)
rawInputLSTM->addToContext(context, dict, config);
splitTransLSTM->addToContext(context, dict, config);
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config);
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
...
#include "MLP.hpp"
#include <regex>
MLPImpl::MLPImpl(const std::string & topology)
{
}
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
#include "Transition.hpp"
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
{
std::vector<long> context;
for (int index : bufferContext)
context.emplace_back(config.getRelativeWordIndex(index));
for (int index : stackContext)
if (config.hasStack(index))
context.emplace_back(config.getStack(index));
else
context.emplace_back(-1);
return context;
}
std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const
{
std::vector<long> context;
for (int index : bufferFocused)
context.emplace_back(config.getRelativeWordIndex(index));
for (int index : stackFocused)
if (config.hasStack(index))
context.emplace_back(config.getStack(index));
else
context.emplace_back(-1);
return context;
}
std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::vector<long> indexes = extractContextIndexes(config);
std::vector<long> context;
for (auto & col : columns)
for (auto index : indexes)
if (index == -1)
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
return {context};
}
int NeuralNetworkImpl::getContextSize() const
{
return columns.size()*(bufferContext.size()+stackContext.size());
}
void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext)
{
this->bufferContext = bufferContext;
}
void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext)
{
this->stackContext = stackContext;
}
void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused)
{
this->bufferFocused = bufferFocused;
}
void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused)
{
this->stackFocused = stackFocused;
}
void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
{
this->columns = columns;
}
void NeuralNetworkImpl::addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++)
if (i < (int)splitTransitions.size())
context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
else
context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
void NeuralNetworkImpl::addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const
{
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
return;
for (int i = 0; i < leftWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindowRawInput; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
void NeuralNetworkImpl::addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const
{
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
for (auto & targetCol : unknownValueColumns)
if (col == targetCol)
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
{
context.emplace_back(context.back());
context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
}
}
}
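The training-time fork above is easy to miss: when a column listed in unknownValueColumns holds a rare value (dictionary occurrence count at or below unknownValueThreshold), the current context row is duplicated and only the last pushed index is swapped for Dict::unknownValueStr, producing an extra training variant. A toy sketch of just that forking, with faked ids and a faked occurrence test:

#include <cstdio>
#include <vector>

int main()
{
  // Ids stand in for dictionary indexes; 99 stands in for Dict::unknownValueStr.
  std::vector<std::vector<long>> context{{7, 3}};
  bool rareValue = true; // pretend dict.getNbOccs(3) <= unknownValueThreshold
  if (rareValue)
  {
    context.emplace_back(context.back()); // duplicate the current variant
    context.back().back() = 99;           // replace the rare id by <unk>
  }
  for (auto & variant : context)
  {
    for (auto id : variant)
      std::printf("%ld ", id); // prints "7 3" then "7 99"
    std::printf("\n");
  }
  return 0;
}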
void NeuralNetworkImpl::addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const
{
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
for (auto index : focusedIndexes)
{
if (index == -1)
{
for (int i = 0; i < maxNbElements[colIndex]; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
continue;
}
std::vector<std::string> elements;
if (col == "FORM")
{
auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "FEATS")
{
auto splited = util::split(config.getAsFeature(col, index).get(), '|');
for (int i = 0; i < maxNbElements[colIndex]; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(col, index));
}
if ((int)elements.size() != maxNbElements[colIndex])
util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
}
}
}
#include "RLTNetwork.hpp"
RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
constexpr int embeddingsSize = 30;
constexpr int lstmOutputSize = 128;
constexpr int treeEmbeddingsSize = 256;
constexpr int hiddenSize = 500;
//TODO: handle these contexts
this->leftBorder = leftBorder;
this->rightBorder = rightBorder;
setBufferContext({});
setStackContext({});
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
S = register_parameter("S", torch::randn(treeEmbeddingsSize));
nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
}
torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
{
if (input.dim() == 1)
input = input.unsqueeze(0);
auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
auto computeOrder = input.narrow(1, focusedIndexes.size(1), getContextSize()/columns.size());
auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(getContextSize()/columns.size()));
auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), getContextSize());
auto baseEmbeddings = wordEmbeddings(wordIndexes);
auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
std::vector<std::map<int, torch::Tensor>> treeRepresentations;
for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
{
treeRepresentations.emplace_back();
for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
{
int index = computeOrder[batch][i].item<int>();
if (index == -1)
break;
std::vector<torch::Tensor> inputVector;
inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
{
int child = childs[batch][index][childIndex].item<int>();
if (child == -1)
break;
inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0));
}
auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
treeRepresentations[batch][index] = lstmOut;
}
}
std::vector<torch::Tensor> focusedTrees;
std::vector<torch::Tensor> representations;
for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
{
focusedTrees.clear();
for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
{
int index = focusedIndexes[batch][i].item<int>();
if (index == -1)
focusedTrees.emplace_back(nullTree);
else
focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree);
}
representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
}
auto representation = torch::cat(representations, 0);
return linear2(torch::relu(linear1(representation)));
}
std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::vector<long> contextIndexes;
std::stack<int> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
if (config.isToken(index))
leftContext.push(index);
while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
contextIndexes.emplace_back(-1);
while (!leftContext.empty())
{
contextIndexes.emplace_back(leftContext.top());
leftContext.pop();
}
for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
if (config.isToken(index))
contextIndexes.emplace_back(index);
while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
contextIndexes.emplace_back(-1);
std::map<long, long> indexInContext;
for (auto & l : contextIndexes)
indexInContext.emplace(std::make_pair(l, indexInContext.size()));
std::vector<long> headOf;
for (auto & l : contextIndexes)
{
if (l == -1)
headOf.push_back(-1);
else
{
auto & head = config.getAsFeature(Config::headColName, l);
if (util::isEmpty(head) or head == "_")
headOf.push_back(-1);
else if (indexInContext.count(std::stoi(head)))
headOf.push_back(std::stoi(head));
else
headOf.push_back(-1);
}
}
std::vector<std::vector<long>> childs(headOf.size());
for (unsigned int i = 0; i < headOf.size(); i++)
if (headOf[i] != -1)
childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);
std::vector<long> treeComputationOrder;
std::vector<bool> treeIsComputed(contextIndexes.size(), false);
std::function<void(long)> depthFirst;
depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
{
if (!indexInContext.count(root))
return;
if (treeIsComputed[indexInContext[root]])
return;
for (auto child : childs[indexInContext[root]])
depthFirst(child);
treeIsComputed[indexInContext[root]] = true;
treeComputationOrder.push_back(indexInContext[root]);
};
for (auto & l : focusedBufferIndexes)
if (contextIndexes[leftBorder+l] != -1)
depthFirst(contextIndexes[leftBorder+l]);
for (auto & l : focusedStackIndexes)
if (config.hasStack(l))
depthFirst(config.getStack(l));
std::vector<long> context;
for (auto & c : focusedBufferIndexes)
context.push_back(leftBorder+c);
for (auto & c : focusedStackIndexes)
if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
context.push_back(indexInContext[config.getStack(c)]);
else
context.push_back(-1);
for (auto & c : treeComputationOrder)
context.push_back(c);
while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
context.push_back(-1);
for (auto & c : childs)
{
for (unsigned int i = 0; i < maxNbChilds; i++)
if (i < c.size())
context.push_back(indexInContext[c[i]]);
else
context.push_back(-1);
}
for (auto & l : contextIndexes)
for (auto & col : columns)
if (l == -1)
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
return {context};
}
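The treeComputationOrder built above is a post-order traversal: depthFirst() recurses into every child before marking the head itself computed, which is what guarantees that forward() finds treeRepresentations[batch][child] already filled when it processes the parent. A toy version of the same lambda on a four-node tree:

#include <cstdio>
#include <functional>
#include <vector>

int main()
{
  // Toy dependency tree as node -> children; node 0 heads 1 and 2, node 2 heads 3.
  std::vector<std::vector<long>> childs{{1, 2}, {}, {3}, {}};
  std::vector<bool> computed(childs.size(), false);
  std::vector<long> order;
  std::function<void(long)> depthFirst = [&](long root)
  {
    if (computed[root])
      return;
    for (auto child : childs[root])
      depthFirst(child); // children are embedded before their head
    computed[root] = true;
    order.push_back(root);
  };
  depthFirst(0);
  for (auto node : order)
    std::printf("%ld ", node); // prints: 1 3 2 0
  return 0;
}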
...
RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize)
{
setBufferContext({0});
setStackContext({});
setBufferFocused({});
setStackFocused({});
setColumns({"FORM"});
}
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
...
return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true));
}
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
{
return std::vector<std::vector<long>>();
}
#include "RawInputLSTM.hpp"
RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
{
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
torch::Tensor RawInputLSTMImpl::forward(torch::Tensor input)
{
return lstm(input.narrow(1, firstInputIndex, getInputSize()));
}
std::size_t RawInputLSTMImpl::getOutputSize()
{
return lstm->getOutputSize(leftWindow + rightWindow + 1);
}
std::size_t RawInputLSTMImpl::getInputSize()
{
return leftWindow + rightWindow + 1;
}
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
if (leftWindow < 0 or rightWindow < 0)
return;
for (int i = 0; i < leftWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
#include "SplitTransLSTM.hpp"
#include "Transition.hpp"
SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
{
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
{
return lstm(input.narrow(1, firstInputIndex, getInputSize()));
}
std::size_t SplitTransLSTMImpl::getOutputSize()
{
return lstm->getOutputSize(maxNbTrans);
}
std::size_t SplitTransLSTMImpl::getInputSize()
{
return maxNbTrans;
}
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (int i = 0; i < maxNbTrans; i++)
if (i < (int)splitTransitions.size())
context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
else
context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
#include "Submodule.hpp"
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
this->firstInputIndex = firstInputIndex;
}
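Submodule.cpp defines only the index setter; the rest of the contract is inferred from the call sites in this commit (the real Submodule.hpp may declare more). A sketch of that implied interface:

// Sketch of the contract implied by this commit: every feature submodule owns
// one slice of the flattened input row and can append its own columns to a
// context being built.
#include <cstddef>
#include <vector>

class Dict;   // provides getIndexOrInsert(...)
class Config; // the parser configuration being featurized

class Submodule
{
  protected :
  std::size_t firstInputIndex{0};

  public :
  void setFirstInputIndex(std::size_t firstInputIndex)
  {
    this->firstInputIndex = firstInputIndex;
  }
  // Width of the slice this module reads from the input row.
  virtual std::size_t getInputSize() = 0;
  // Width of the vector this module contributes to the concatenated output.
  virtual std::size_t getOutputSize() = 0;
  // Appends this module's dictionary indexes to every context variant.
  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
  virtual ~Submodule() = default;
};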