Commit 0089639f authored by Franck Dary

Added module HistoryModule

parent caaca1a7
HistoryModule.hpp (new file):

#ifndef HISTORYMODULE__H
#define HISTORYMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
class HistoryModuleImpl : public Submodule
{
  private :

  torch::nn::Embedding wordEmbeddings{nullptr};
  std::shared_ptr<MyModule> myModule{nullptr};
  int maxNbElements;
  int inSize;

  public :

  HistoryModuleImpl(std::string name, const std::string & definition);
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  void registerEmbeddings() override;
};
TORCH_MODULE(HistoryModule);
#endif
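
For orientation, here is a minimal sketch of how the wrapped module could be constructed directly. The definition string and every value in it are hypothetical; in practice the string comes from the network configuration parsed by ModularNetworkImpl (see the dispatch further down):

#include "HistoryModule.hpp"

// Hypothetical definition: 10 history elements fed through a unidirectional,
// single-layer LSTM (dropout 0.3, complete flag 1), with embeddings of width 64
// and an output of width 128.
HistoryModule history("history", "NbElem{10} LSTM{0 1 0.3 1} In{64} Out{128}");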

ModularNetwork.hpp:

@@ -11,6 +11,7 @@
 #include "StateNameModule.hpp"
 #include "UppercaseRateModule.hpp"
 #include "NumericColumnModule.hpp"
+#include "HistoryModule.hpp"
 #include "MLP.hpp"
 class ModularNetworkImpl : public NeuralNetworkImpl

HistoryModule.cpp (new file):

#include "HistoryModule.hpp"
HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & definition)
{
  setName(name);
  std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
    {
      try
      {
        maxNbElements = std::stoi(sm.str(1));
        auto subModuleType = sm.str(2);
        auto subModuleArguments = util::split(sm.str(3), ' ');
        auto options = MyModule::ModuleOptions(true)
          .bidirectional(std::stoi(subModuleArguments[0]))
          .num_layers(std::stoi(subModuleArguments[1]))
          .dropout(std::stof(subModuleArguments[2]))
          .complete(std::stoi(subModuleArguments[3]));
        inSize = std::stoi(sm.str(4));
        int outSize = std::stoi(sm.str(5));
        if (subModuleType == "LSTM")
          myModule = register_module("myModule", LSTM(inSize, outSize, options));
        else if (subModuleType == "GRU")
          myModule = register_module("myModule", GRU(inSize, outSize, options));
        else
          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'", e.what(), definition));}
    }))
    util::myThrow(fmt::format("invalid definition '{}'", definition));
}
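
Reading the parse above back out: the four space-separated values inside LSTM{...} or GRU{...} map, in order, to the bidirectional, num_layers, dropout, and complete options of MyModule::ModuleOptions. A definition like the following (values illustrative only) would therefore request a bidirectional, two-layer LSTM with dropout 0.3 over a history of 10 elements:

// Illustrative only; the four submodule arguments are, in order:
//   bidirectional  num_layers  dropout  complete
const std::string definition = "NbElem{10} LSTM{1 2 0.3 1} In{64} Out{128}";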
torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
{
  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements)));
}
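
A note on the shapes flowing through forward, assuming input is a batch of context rows of shape {B, contextSize} and firstInputIndex is the offset this submodule was assigned inside each row:

// input.narrow(1, firstInputIndex, maxNbElements) -> {B, maxNbElements} dictionary indices
// wordEmbeddings(...)                             -> {B, maxNbElements, inSize}
// myModule->forward(...)                          -> recurrent summary whose width is
//                                                    what getOutputSize() reports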
std::size_t HistoryModuleImpl::getOutputSize()
{
  return myModule->getOutputSize(maxNbElements);
}

std::size_t HistoryModuleImpl::getInputSize()
{
  return maxNbElements;
}
void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
{
  auto & dict = getDict();
  for (auto & contextElement : context)
    for (int i = 0; i < maxNbElements; i++)
      if (config.hasHistory(i))
        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
      else
        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
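
To make addToContext concrete: every context row receives exactly maxNbElements dictionary indices, one per history entry, padded with Dict::nullValueStr when fewer entries exist. A hypothetical trace with maxNbElements = 3 and only two entries in the history:

// config history (hypothetical values): "SHIFT", "REDUCE"
// appended to the row: index("SHIFT"), index("REDUCE"), index(Dict::nullValueStr)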
void HistoryModuleImpl::registerEmbeddings()
{
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
}
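
registerEmbeddings sizes the lookup table directly from the code above: one row per dictionary entry, inSize columns. A self-contained libtorch sketch of the same call, with made-up stand-ins for the real sizes:

#include <torch/torch.h>

// 50000 and 64 stand in for getDict().size() and inSize.
torch::nn::Embedding emb(torch::nn::EmbeddingOptions(50000, 64));
torch::Tensor rows = emb(torch::tensor({{3L, 7L, 42L}}));  // shape {1, 3, 64}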

ModularNetwork.cpp:

@@ -31,6 +31,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
+    else if (splited.first == "History")
+      modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
     else if (splited.first == "NumericColumn")
       modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
     else if (splited.first == "UppercaseRate")
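
Given the dispatch above, a network-description line whose key is History ends up in the new module, with everything after the separator handed to HistoryModuleImpl as the definition string. A hypothetical line (separator and values assumed, by analogy with the other module types):

History : NbElem{10} LSTM{1 1 0.3 1} In{64} Out{128}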