diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 08ace9018e1920acf96fb65e3deebd4eb57592db..41c8beb4eb19fadcb1f9eaca652e8b310d35cf20 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -7,15 +7,13 @@
 #include "SplitTransModule.hpp"
 #include "FocusedColumnModule.hpp"
 #include "DepthLayerTreeEmbeddingModule.hpp"
+#include "StateNameModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  //torch::nn::Embedding wordEmbeddings{nullptr};
-  //torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-  //torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout inputDropout{nullptr};
 
   MLP mlp{nullptr};
diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8a2ae71682202ba5fe1b070e46cf0d5e8a428063
--- /dev/null
+++ b/torch_modules/include/StateNameModule.hpp
@@ -0,0 +1,30 @@
+#ifndef STATENAMEMODULE__H
+#define STATENAMEMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class StateNameModuleImpl : public Submodule
+{
+  private :
+
+  std::map<std::string,int> state2index;
+  torch::nn::Embedding embeddings{nullptr};
+  int outSize;
+
+  public :
+
+  StateNameModuleImpl(const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void registerEmbeddings(std::size_t nbElements) override;
+};
+TORCH_MODULE(StateNameModule);
+
+#endif
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 13b7ca4936a3b1371387d35cabeb5fdb8d4b3795..82cebd7ad62203d92fd19a9e1046dd24ed6e0ebf 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -23,6 +23,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
     if (splited.first == "Context")
       modules.emplace_back(register_module(name, ContextModule(splited.second)));
+    else if (splited.first == "StateName")
+      modules.emplace_back(register_module(name, StateNameModule(splited.second)));
     else if (splited.first == "Focused")
       modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
     else if (splited.first == "RawInput")
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..afc572142fb4725753c0c8a1eece8613c33b36ef
--- /dev/null
+++ b/torch_modules/src/StateNameModule.cpp
@@ -0,0 +1,45 @@
+#include "StateNameModule.hpp"
+
+StateNameModuleImpl::StateNameModuleImpl(const std::string & definition)
+{
+  std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            auto states = util::split(sm.str(1), ' ');
+            outSize = std::stoi(sm.str(2));
+
+            for (auto & state : states)
+              state2index.emplace(state, state2index.size());
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor StateNameModuleImpl::forward(torch::Tensor input)
+{
+  return embeddings(input.narrow(1,firstInputIndex,1).squeeze(1));
+}
+
+std::size_t StateNameModuleImpl::getOutputSize()
+{
+  return outSize;
+}
+
+std::size_t StateNameModuleImpl::getInputSize()
+{
+  return 1;
+}
+
+void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+{
+  for (auto & contextElement : context)
+    contextElement.emplace_back(state2index.at(config.getState()));
+}
+
+void StateNameModuleImpl::registerEmbeddings(std::size_t)
+{
+  embeddings = register_module("embeddings", torch::nn::Embedding(state2index.size(), outSize));
+}
+
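Note (not part of the patch): the new module is configured by a definition string of the form "States{...} Out{...}", which the constructor above parses with a regex; addToContext then appends state2index.at(config.getState()) to each context row, and forward reads that column back (via firstInputIndex) as an index into an embedding table of size outSize. The sketch below is a minimal standalone illustration of that parsing-plus-lookup flow; it replaces the project's util::doIfNameMatch / util::split helpers with plain std::regex and std::istringstream, and the state names "tagger parser" and size 32 are hypothetical examples, not values taken from the repository.

#include <torch/torch.h>
#include <iostream>
#include <map>
#include <regex>
#include <sstream>
#include <string>

int main()
{
  // Hypothetical definition string; real state names depend on the machine configuration.
  std::string definition = "States{tagger parser} Out{32}";

  std::map<std::string,int> state2index;
  int outSize = 0;

  // Same States{...} Out{...} shape as the regex in StateNameModuleImpl's constructor.
  std::regex regex("\\s*States\\{(.*)\\}\\s*Out\\{(.*)\\}\\s*");
  std::smatch sm;
  if (std::regex_match(definition, sm, regex))
  {
    std::istringstream iss(sm.str(1));
    std::string state;
    while (iss >> state)                      // whitespace-split the declared state list
      state2index.emplace(state, state2index.size());
    outSize = std::stoi(sm.str(2));
  }

  // One embedding vector of size outSize per declared state name,
  // mirroring what registerEmbeddings() sets up with torch::nn::Embedding.
  torch::nn::Embedding embeddings(state2index.size(), outSize);

  // Look up the embedding of the current state, e.g. "parser".
  auto index = torch::full({1}, state2index.at("parser"), torch::kLong);
  std::cout << embeddings(index).sizes() << std::endl; // prints [1, 32]
  return 0;
}

Because state2index.at() throws on unknown keys, a state name that was never listed in States{...} would abort with an exception, which matches the strict behaviour of addToContext in the patch.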