Skip to content
Snippets Groups Projects
Commit b4228d7b authored by Franck Dary's avatar Franck Dary
Browse files

First draft of modular neural network

parent 5a2ea279
Branches
No related tags found
No related merge requests found
#include "LSTM.hpp"
LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options))
LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
{
auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
.batch_first(std::get<0>(options))
......
......@@ -2,117 +2,119 @@
// Assembles the full transition-based parser network from per-feature LSTM
// submodules (context, raw input, tree embedding, split-transitions, focused
// columns), a shared word-embedding table, dropout layers, an MLP trunk and
// one output layer per parser state.
// NOTE(review): this file is a merged diff view — the active code below is the
// pre-refactor LSTM-specific version, while the commented-out block after it is
// the new "modular" draft (ContextModule, RawInputModule, ...) introduced by
// this commit. Both are preserved byte-for-byte here.
LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{
// Option tuple {<0>,...,<4>}: <0> is batch_first and <4> is outputAll (both
// confirmed by LSTM.cpp above); the middle fields are presumably
// {bilstm, numLayers, lstmDropout} from the initializer — TODO confirm.
LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
// Variant with outputAll=true, used by submodules that need every timestep.
auto lstmOptionsAll = lstmOptions;
std::get<4>(lstmOptionsAll) = true;
// Running totals: concatenated feature size fed to the MLP, and the running
// column offset of each submodule's input slice. Starts at 1 because index 0
// of the extracted context holds the parser-state embedding (see
// extractContext below).
int currentOutputSize = embeddingsSize;
int currentInputSize = 1;
contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
contextLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += contextLSTM->getOutputSize();
currentInputSize += contextLSTM->getInputSize();
// Raw-input LSTM is optional: only instantiated when both windows are valid.
if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
{
rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
rawInputLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += rawInputLSTM->getOutputSize();
currentInputSize += rawInputLSTM->getInputSize();
}
// Tree embedding is optional: only instantiated when columns are configured.
if (!treeEmbeddingColumns.empty())
{
treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
treeEmbedding->setFirstInputIndex(currentInputSize);
currentOutputSize += treeEmbedding->getOutputSize();
currentInputSize += treeEmbedding->getInputSize();
}
splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
splitTransLSTM->setFirstInputIndex(currentInputSize);
currentOutputSize += splitTransLSTM->getOutputSize();
currentInputSize += splitTransLSTM->getInputSize();
// One LSTM per focused column, each registered under a unique name.
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
focusedLstms.back()->setFirstInputIndex(currentInputSize);
currentOutputSize += focusedLstms.back()->getOutputSize();
currentInputSize += focusedLstms.back()->getInputSize();
}
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
// Either 2d (channel-wise) or standard embedding dropout, never both.
if (drop2d)
embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
else
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
// One linear output head per parser state, sized by that state's action count.
for (auto & it : nbOutputsPerState)
outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
// NOTE(review): commented-out modular rewrite added by this commit; kept
// verbatim as it documents the intended ContextModule/RawInputModule design.
// MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false};
// auto moduleOptionsAll = moduleOptions;
// std::get<4>(moduleOptionsAll) = true;
//
// int currentOutputSize = embeddingsSize;
// int currentInputSize = 1;
//
// contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold));
// contextLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += contextLSTM->getOutputSize();
// currentInputSize += contextLSTM->getInputSize();
//
// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
// {
// rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll));
// rawInputLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += rawInputLSTM->getOutputSize();
// currentInputSize += rawInputLSTM->getInputSize();
// }
//
// if (!treeEmbeddingColumns.empty())
// {
// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions));
// treeEmbedding->setFirstInputIndex(currentInputSize);
// currentOutputSize += treeEmbedding->getOutputSize();
// currentInputSize += treeEmbedding->getInputSize();
// }
//
// splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll));
// splitTransLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += splitTransLSTM->getOutputSize();
// currentInputSize += splitTransLSTM->getInputSize();
//
// for (unsigned int i = 0; i < focusedColumns.size(); i++)
// {
// focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions)));
// focusedLstms.back()->setFirstInputIndex(currentInputSize);
// currentOutputSize += focusedLstms.back()->getOutputSize();
// currentInputSize += focusedLstms.back()->getInputSize();
// }
//
// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
// if (drop2d)
// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
// else
// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
//
// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
//
// for (auto & it : nbOutputsPerState)
// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
// Forward pass: embeds the extracted context indices, applies embedding
// dropout, runs every submodule on its slice, concatenates the results,
// applies input dropout and routes through the MLP to the current state's
// output layer.
// NOTE(review): merged diff view — the active body below is the pre-refactor
// version; the `return input;` after it is unreachable here and belongs to the
// new commit, which stubs this method out (see commented block). Resolve on
// merge: keep exactly one of the two versions.
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
// Promote a single example to a batch of one.
if (input.dim() == 1)
input = input.unsqueeze(0);
auto embeddings = wordEmbeddings(input);
// Exactly one of the two dropout modules was registered in the constructor.
if (embeddingsDropout2d.is_empty())
embeddings = embeddingsDropout(embeddings);
else
embeddings = embeddingsDropout2d(embeddings);
// Slot 0 of the context is the parser-state embedding; each submodule then
// consumes its own slice of `embeddings` via its firstInputIndex.
std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
outputs.emplace_back(contextLSTM(embeddings));
if (!rawInputLSTM.is_empty())
outputs.emplace_back(rawInputLSTM(embeddings));
if (!treeEmbedding.is_empty())
outputs.emplace_back(treeEmbedding(embeddings));
outputs.emplace_back(splitTransLSTM(embeddings));
for (auto & lstm : focusedLstms)
outputs.emplace_back(lstm(embeddings));
auto totalInput = inputDropout(torch::cat(outputs, 1));
return outputLayersPerState.at(getState())(mlp(totalInput));
// Unreachable: stub return introduced by the modular-network commit.
return input;
// if (input.dim() == 1)
// input = input.unsqueeze(0);
//
// auto embeddings = wordEmbeddings(input);
// if (embeddingsDropout2d.is_empty())
// embeddings = embeddingsDropout(embeddings);
// else
// embeddings = embeddingsDropout2d(embeddings);
//
// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
//
// outputs.emplace_back(contextLSTM(embeddings));
//
// if (!rawInputLSTM.is_empty())
// outputs.emplace_back(rawInputLSTM(embeddings));
//
// if (!treeEmbedding.is_empty())
// outputs.emplace_back(treeEmbedding(embeddings));
//
// outputs.emplace_back(splitTransLSTM(embeddings));
//
// for (auto & lstm : focusedLstms)
// outputs.emplace_back(lstm(embeddings));
//
// auto totalInput = inputDropout(torch::cat(outputs, 1));
//
// return outputLayersPerState.at(getState())(mlp(totalInput));
}
// Builds the dictionary-index context rows the network will embed: slot 0 is
// the parser state, then each submodule appends its own feature indices.
// Multiple rows ("variants") may be produced in splitUnknown mode.
// NOTE(review): merged diff view — the active body is the pre-refactor
// version; the commented block after `return context;` is the new commit's
// stubbed version. Resolve on merge.
std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
{
// Warn when the dictionary has reached capacity (message says ">" but the
// check is ">=" — minor wording mismatch, left as-is in this doc pass).
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<std::vector<long>> context;
context.emplace_back();
// Slot 0: embedding index of the current parser state.
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!rawInputLSTM.is_empty())
rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!treeEmbedding.is_empty())
treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config, mustSplitUnknown());
// Outside splitUnknown mode a single context row is expected.
if (!mustSplitUnknown() && context.size() > 1)
util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
return context;
// if (dict.size() >= maxNbEmbeddings)
// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
//
// context.emplace_back();
//
// context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
//
// contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!rawInputLSTM.is_empty())
// rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!treeEmbedding.is_empty())
// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
//
// splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
//
// for (auto & lstm : focusedLstms)
// lstm->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!mustSplitUnknown() && context.size() > 1)
// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
//
// return context;
}
#include "MLP.hpp"
#include "util.hpp"
#include "fmt/core.h"
#include <regex>
MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params)
MLPImpl::MLPImpl(int inputSize, std::string definition)
{
std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)");
std::vector<std::pair<int, float>> params;
if (!util::doIfNameMatch(regex, definition, [this,&definition,&params](auto sm)
{
try
{
auto splited = util::split(sm.str(1), ' ');
for (unsigned int i = 0; i < splited.size()/2; i++)
{
params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1]));
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
int inSize = inputSize;
for (auto & param : params)
......
#include "ModularNetwork.hpp"
// Builds the network from textual module definitions. Each line has the form
// "<Name> : <module-specific definition>". "MLP" and "InputDropout" configure
// the trunk (both mandatory); every other recognized name instantiates a
// feature submodule whose input/output slice sizes are accumulated.
// Throws (via util::myThrow) on unknown module names or missing trunk parts.
ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
{
  const std::string blanks = "(?:(?:\\s|\\t)*)";
  // Parse "Name : definition" into {Name, definition}; both empty on mismatch.
  auto parseLine = [blanks](std::string line)
  {
    std::pair<std::string,std::string> parsed;
    util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",blanks,blanks,blanks)), line, [&parsed](auto sm)
    {
      parsed.first = sm.str(1);
      parsed.second = sm.str(2);
    });
    return parsed;
  };

  int currentInputSize = 0;   // running offset of the next module's input slice
  int currentOutputSize = 0;  // total concatenated feature size fed to the MLP
  std::string mlpDef;

  for (auto & line : definitions)
  {
    const auto parsed = parseLine(line);
    const std::string name = fmt::format("{}_{}", modules.size(), parsed.first);

    if (parsed.first == "MLP")
    {
      mlpDef = parsed.second;  // deferred: needs the final currentOutputSize
      continue;
    }
    if (parsed.first == "InputDropout")
    {
      inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(parsed.second)));
      continue;
    }
    if (parsed.first == "Context")
      modules.emplace_back(register_module(name, ContextModule(parsed.second)));
    else
      util::myThrow(fmt::format("unknown module '{}' for line '{}'", parsed.first, line));

    // Reached only for feature submodules: assign its slice and grow totals.
    modules.back()->setFirstInputIndex(currentInputSize);
    currentInputSize += modules.back()->getInputSize();
    currentOutputSize += modules.back()->getOutputSize();
  }

  if (mlpDef.empty())
    util::myThrow("no MLP definition found");
  if (inputDropout.is_empty())
    util::myThrow("no InputDropout definition found");

  mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));

  // One linear output head per parser state.
  for (auto & it : nbOutputsPerState)
    outputLayersPerState.emplace(it.first, register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
// Forward pass: run every submodule on the (batched) input row, concatenate
// their outputs along dim 1, apply input dropout, then dispatch through the
// MLP to the output layer of the current parser state.
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  std::vector<torch::Tensor> moduleOutputs;
  moduleOutputs.reserve(modules.size());
  for (auto & submodule : modules)
    moduleOutputs.emplace_back(submodule->forward(input));

  auto fused = inputDropout(torch::cat(moduleOutputs, 1));
  return outputLayersPerState.at(getState())(mlp(fused));
}
// Builds the dictionary-index context rows the network will embed: starts
// with one empty row and lets each submodule append its feature indices.
std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config, Dict & dict) const
{
  std::vector<std::vector<long>> context;
  context.emplace_back();
  for (auto & submodule : modules)
    submodule->addToContext(context, dict, config, mustSplitUnknown());
  return context;
}
#include "RawInputLSTM.hpp"
#include "RawInputModule.hpp"
RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
RawInputModule::RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
{
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
// Runs the wrapped module on this module's slice of the embedded context.
// FIX(review): resolved merged old/new diff duplicates to the new
// RawInputModule version.
torch::Tensor RawInputModule::forward(torch::Tensor input)
{
  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
}
std::size_t RawInputLSTMImpl::getOutputSize()
std::size_t RawInputModule::getOutputSize()
{
return lstm->getOutputSize(leftWindow + rightWindow + 1);
return myModule->getOutputSize(leftWindow + rightWindow + 1);
}
std::size_t RawInputLSTMImpl::getInputSize()
std::size_t RawInputModule::getInputSize()
{
return leftWindow + rightWindow + 1;
}
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
void RawInputModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
if (leftWindow < 0 or rightWindow < 0)
return;
......
#include "SplitTransLSTM.hpp"
#include "SplitTransModule.hpp"
#include "Transition.hpp"
SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
{
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
// Runs the wrapped module on this module's slice of the embedded context.
// FIX(review): resolved merged old/new diff duplicates to the new
// SplitTransModule version.
torch::Tensor SplitTransModule::forward(torch::Tensor input)
{
  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
}
std::size_t SplitTransLSTMImpl::getOutputSize()
std::size_t SplitTransModule::getOutputSize()
{
return lstm->getOutputSize(maxNbTrans);
return myModule->getOutputSize(maxNbTrans);
}
std::size_t SplitTransLSTMImpl::getInputSize()
std::size_t SplitTransModule::getInputSize()
{
return maxNbTrans;
}
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (auto & contextElement : context)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.