Skip to content
Snippets Groups Projects
Commit 10d4492e authored by Franck Dary's avatar Franck Dary
Browse files

Added module StateName

parent 47cbed2e
No related branches found
No related tags found
No related merge requests found
@@ -7,15 +7,13 @@
 #include "SplitTransModule.hpp"
 #include "FocusedColumnModule.hpp"
 #include "DepthLayerTreeEmbeddingModule.hpp"
+#include "StateNameModule.hpp"
 #include "MLP.hpp"
 class ModularNetworkImpl : public NeuralNetworkImpl
 {
 private :
-//torch::nn::Embedding wordEmbeddings{nullptr};
-//torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-//torch::nn::Dropout embeddingsDropout{nullptr};
 torch::nn::Dropout inputDropout{nullptr};
 MLP mlp{nullptr};
...
#ifndef STATENAMEMODULE__H
#define STATENAMEMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
// Submodule that embeds the parser's current state name (e.g. which sub-task
// the transition system is in) as a learned vector appended to the context.
class StateNameModuleImpl : public Submodule
{
private :

// Maps each state name declared in the definition string to a dense index
// used as input to the embedding table.
std::map<std::string,int> state2index;
// Embedding table of size (nbStates x outSize); created lazily in
// registerEmbeddings(), hence the {nullptr} placeholder.
torch::nn::Embedding embeddings{nullptr};
// Dimension of the produced embedding vector, parsed from the Out{...} part
// of the definition.
int outSize;

public :

// Parses a definition of the form "States{s1 s2 ...} Out{N}"; throws via
// util::myThrow on malformed input.
StateNameModuleImpl(const std::string & definition);
// Looks up the embedding of the state index stored at firstInputIndex in
// each row of the input batch.
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
// Appends the current state's index (one value) to every context row.
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
// Instantiates the embedding table; the nbElements argument is unused since
// the vocabulary size is fixed by state2index.
void registerEmbeddings(std::size_t nbElements) override;
};
TORCH_MODULE(StateNameModule);
#endif
@@ -23,6 +23,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
 std::string name = fmt::format("{}_{}", modules.size(), splited.first);
 if (splited.first == "Context")
 modules.emplace_back(register_module(name, ContextModule(splited.second)));
+else if (splited.first == "StateName")
+modules.emplace_back(register_module(name, StateNameModule(splited.second)));
 else if (splited.first == "Focused")
 modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
 else if (splited.first == "RawInput")
...
#include "StateNameModule.hpp"
// Parses a module definition of the form "States{s1 s2 ...} Out{N}":
// the States clause lists the legal state names, the Out clause gives the
// embedding dimension. Throws (via util::myThrow) if the definition does not
// match or a clause fails to parse.
StateNameModuleImpl::StateNameModuleImpl(const std::string & definition)
{
// Capture group 1 = space-separated state names, group 2 = output size.
std::regex regex("(?:(?:\\s|\\t)*)States\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
// doIfNameMatch runs the lambda only when the regex matches; presumably it
// returns false on no match — TODO confirm against util.
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
auto states = util::split(sm.str(1), ' ');
// std::stoi may throw; the catch below rewraps it with the definition text.
outSize = std::stoi(sm.str(2));
// Assign indices 0..n-1 in declaration order; duplicates keep their first index.
for (auto & state : states)
state2index.emplace(state, state2index.size());
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
// Embeds the state index carried by each row of the batched context.
// The state index occupies exactly one column, at firstInputIndex.
torch::Tensor StateNameModuleImpl::forward(torch::Tensor input)
{
  // Select the single state-index column, then drop the size-1 dimension so
  // the embedding lookup receives a 1-D batch of indices.
  auto stateColumn = input.narrow(1, firstInputIndex, 1);
  auto stateIndices = stateColumn.squeeze(1);
  return embeddings(stateIndices);
}
// The module contributes one embedding vector of dimension outSize
// (parsed from the Out{...} clause of the definition).
std::size_t StateNameModuleImpl::getOutputSize()
{
  return static_cast<std::size_t>(outSize);
}
// This module consumes exactly one column of the context row:
// the index of the current state.
std::size_t StateNameModuleImpl::getInputSize()
{
  constexpr std::size_t nbInputColumns = 1;
  return nbInputColumns;
}
void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
for (auto & contextElement : context)
contextElement.emplace_back(state2index.at(config.getState()));
}
// Creates and registers the embedding table. The vocabulary size is fixed by
// the number of declared states, so the nbElements argument is ignored.
void StateNameModuleImpl::registerEmbeddings(std::size_t)
{
  const auto nbStates = state2index.size();
  embeddings = register_module("embeddings", torch::nn::Embedding(nbStates, outSize));
}
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment