From 0089639fe6a542da173a4fe773a5b67401b7fb4e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 17 May 2020 18:34:31 +0200
Subject: [PATCH] Added module HistoryModule

---
 torch_modules/include/HistoryModule.hpp  | 31 +++++++++++
 torch_modules/include/ModularNetwork.hpp |  1 +
 torch_modules/src/HistoryModule.cpp      | 68 ++++++++++++++++++++++++
 torch_modules/src/ModularNetwork.cpp     |  2 +
 4 files changed, 102 insertions(+)
 create mode 100644 torch_modules/include/HistoryModule.hpp
 create mode 100644 torch_modules/src/HistoryModule.cpp

diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
new file mode 100644
index 0000000..abcd26f
--- /dev/null
+++ b/torch_modules/include/HistoryModule.hpp
@@ -0,0 +1,31 @@
+#ifndef HISTORYMODULE__H
+#define HISTORYMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class HistoryModuleImpl : public Submodule
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
+  int maxNbElements;
+  int inSize;
+
+  public :
+
+  HistoryModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(HistoryModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 40b1919..f49ba3f 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -11,6 +11,7 @@
 #include "StateNameModule.hpp"
 #include "UppercaseRateModule.hpp"
 #include "NumericColumnModule.hpp"
+#include "HistoryModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
new file mode 100644
index 0000000..bc9434b
--- /dev/null
+++ b/torch_modules/src/HistoryModule.cpp
@@ -0,0 +1,68 @@
+#include "HistoryModule.hpp"
+
+HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & definition)
+{
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            maxNbElements = std::stoi(sm.str(1));
+
+            auto subModuleType = sm.str(2);
+            auto subModuleArguments = util::split(sm.str(3), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            inSize = std::stoi(sm.str(4));
+            int outSize = std::stoi(sm.str(5));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
+{
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements)));
+}
+
+std::size_t HistoryModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(maxNbElements);
+}
+
+std::size_t HistoryModuleImpl::getInputSize()
+{
+  return maxNbElements;
+}
+
+void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & dict = getDict();
+
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbElements; i++)
+      if (config.hasHistory(i))
+        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+}
+
+void HistoryModuleImpl::registerEmbeddings()
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index c79791c..11b6962 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -31,6 +31,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
+    else if (splited.first == "History")
+      modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
     else if (splited.first == "NumericColumn")
       modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
     else if (splited.first == "UppercaseRate")
-- 
GitLab
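
Note: from the regex parsed in HistoryModuleImpl's constructor and the "History"
dispatch added to ModularNetwork.cpp, a definition line for this module should look
like the sketch below. The values are illustrative, not taken from any real
configuration, and the "History :" prefix assumes the usual name/definition split
that ModularNetwork performs on each module line:

  History : NbElem{10} LSTM{1 2 0.3 1} In{64} Out{128}

NbElem{10} keeps the 10 most recent history items; LSTM{1 2 0.3 1} supplies the four
positional arguments read in the constructor (bidirectional=1, num_layers=2,
dropout=0.3, complete=1); In{64} is the embedding dimension handed to
torch::nn::EmbeddingOptions; Out{128} is the recurrent layer's output size.
Replacing LSTM with GRU selects the GRU submodule; any other type name throws.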
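
Note: a hypothetical sketch (not part of the patch) of how addToContext fills the
history columns that forward() later consumes; the action names are invented:

  // With maxNbElements == 3 and two items in the configuration's history,
  // each context row receives three dictionary indices.
  // dict stands for the module's Dict, as obtained via getDict() in addToContext.
  std::vector<long> row;
  std::vector<std::string> history{"SHIFT", "REDUCE"};  // invented action names
  for (int i = 0; i < 3; i++)
    if (i < (int)history.size())
      row.emplace_back(dict.getIndexOrInsert(history[i]));
    else
      row.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));  // pad missing slots

forward() then reads back exactly this slice with input.narrow(1, firstInputIndex,
maxNbElements), where firstInputIndex is inherited from the Submodule base class,
embeds it through wordEmbeddings, and feeds the result to the LSTM/GRU.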