From 10d4492e3a7a705d069706ceaf0019204258860d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 30 Apr 2020 20:07:35 +0200 Subject: [PATCH] Added module StateName --- torch_modules/include/ModularNetwork.hpp | 4 +- torch_modules/include/StateNameModule.hpp | 30 +++++++++++++++ torch_modules/src/ModularNetwork.cpp | 2 + torch_modules/src/StateNameModule.cpp | 45 +++++++++++++++++++++++ 4 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 torch_modules/include/StateNameModule.hpp create mode 100644 torch_modules/src/StateNameModule.cpp diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 08ace90..41c8beb 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -7,15 +7,13 @@ #include "SplitTransModule.hpp" #include "FocusedColumnModule.hpp" #include "DepthLayerTreeEmbeddingModule.hpp" +#include "StateNameModule.hpp" #include "MLP.hpp" class ModularNetworkImpl : public NeuralNetworkImpl { private : - //torch::nn::Embedding wordEmbeddings{nullptr}; - //torch::nn::Dropout2d embeddingsDropout2d{nullptr}; - //torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout inputDropout{nullptr}; MLP mlp{nullptr}; diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp new file mode 100644 index 0000000..8a2ae71 --- /dev/null +++ b/torch_modules/include/StateNameModule.hpp @@ -0,0 +1,30 @@ +#ifndef STATENAMEMODULE__H +#define STATENAMEMODULE__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "MyModule.hpp" +#include "LSTM.hpp" +#include "GRU.hpp" + +class StateNameModuleImpl : public Submodule +{ + private : + + std::map<std::string,int> state2index; + torch::nn::Embedding embeddings{nullptr}; + int outSize; + + public : + + StateNameModuleImpl(const std::string & definition); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; + void registerEmbeddings(std::size_t nbElements) override; +}; +TORCH_MODULE(StateNameModule); + +#endif + diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 13b7ca4..82cebd7 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -23,6 +23,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu std::string name = fmt::format("{}_{}", modules.size(), splited.first); if (splited.first == "Context") modules.emplace_back(register_module(name, ContextModule(splited.second))); + else if (splited.first == "StateName") + modules.emplace_back(register_module(name, StateNameModule(splited.second))); else if (splited.first == "Focused") modules.emplace_back(register_module(name, FocusedColumnModule(splited.second))); else if (splited.first == "RawInput") diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp new file mode 100644 index 0000000..afc5721 --- /dev/null +++ b/torch_modules/src/StateNameModule.cpp @@ -0,0 +1,45 @@ +#include "StateNameModule.hpp" + +StateNameModuleImpl::StateNameModuleImpl(const std::string & definition) +{ + std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + auto states = util::split(sm.str(1), ' '); + outSize = std::stoi(sm.str(2)); + + for (auto & state : states) + state2index.emplace(state, state2index.size()); + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} + })) + util::myThrow(fmt::format("invalid definition '{}'", definition)); +} + +torch::Tensor StateNameModuleImpl::forward(torch::Tensor input) +{ + return embeddings(input.narrow(1,firstInputIndex,1).squeeze(1)); +} + +std::size_t StateNameModuleImpl::getOutputSize() +{ + return outSize; +} + +std::size_t StateNameModuleImpl::getInputSize() +{ + return 1; +} + +void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + for (auto & contextElement : context) + contextElement.emplace_back(state2index.at(config.getState())); +} + +void StateNameModuleImpl::registerEmbeddings(std::size_t) +{ + embeddings = register_module("embeddings", torch::nn::Embedding(state2index.size(), outSize)); +} + -- GitLab