From 1b96cca2e679d400bb5a2f9f4b57ef504b4f3ca8 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 10 May 2020 18:04:05 +0200 Subject: [PATCH] Added AppliableTransModule --- decoder/src/Decoder.cpp | 6 ++- reading_machine/include/Config.hpp | 3 ++ reading_machine/include/TransitionSet.hpp | 1 + reading_machine/src/Config.cpp | 10 +++++ reading_machine/src/TransitionSet.cpp | 13 +++++++ .../include/AppliableTransModule.hpp | 28 ++++++++++++++ torch_modules/include/ModularNetwork.hpp | 2 + torch_modules/include/NeuralNetwork.hpp | 5 +-- torch_modules/include/StateHolder.hpp | 19 ++++++++++ torch_modules/include/Submodule.hpp | 3 +- torch_modules/src/AppliableTransModule.cpp | 37 +++++++++++++++++++ torch_modules/src/ModularNetwork.cpp | 13 +++++++ torch_modules/src/NeuralNetwork.cpp | 10 ----- torch_modules/src/StateHolder.cpp | 16 ++++++++ trainer/src/Trainer.cpp | 4 ++ 15 files changed, 154 insertions(+), 16 deletions(-) create mode 100644 torch_modules/include/AppliableTransModule.hpp create mode 100644 torch_modules/include/StateHolder.hpp create mode 100644 torch_modules/src/AppliableTransModule.cpp create mode 100644 torch_modules/src/StateHolder.cpp diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index fb3738f..273e9a5 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -29,6 +29,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); + config.setAppliableTransitions(appliableTransitions); auto context = machine.getClassifier()->getNN()->extractContext(config).back(); @@ -45,7 +47,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool for (unsigned int i = 0; i < 
softmaxed.size(0); i++) { float score = softmaxed[i].item<float>(); - std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); + std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); toPrint.emplace_back(std::make_pair(score,nicePrint)); } std::sort(toPrint.rbegin(), toPrint.rend()); @@ -58,7 +60,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); - if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) + if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[i]) { chosenTransition = i; bestScore = score; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index aef9932..a047a8a 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -47,6 +47,7 @@ class Config int lastPoppedStack{-1}; int currentWordId{0}; std::vector<Transition *> appliableSplitTransitions; + std::vector<int> appliableTransitions; protected : @@ -145,7 +146,9 @@ class Config void addMissingColumns(); void addComment(); void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions); + void setAppliableTransitions(const std::vector<int> & appliableTransitions); const std::vector<Transition *> & getAppliableSplitTransitions() const; + const std::vector<int> & getAppliableTransitions() const; bool isExtraColumn(const std::string & colName) const; }; diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index d0c7c1f..8f7b733 100644 --- a/reading_machine/include/TransitionSet.hpp +++ 
b/reading_machine/include/TransitionSet.hpp @@ -23,6 +23,7 @@ class TransitionSet std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); Transition * getBestAppliableTransition(const Config & c); std::vector<Transition *> getNAppliableTransitions(const Config & c, int n); + std::vector<int> getAppliableTransitions(const Config & c); std::size_t getTransitionIndex(const Transition * transition) const; Transition * getTransition(std::size_t index); Transition * getTransition(const std::string & name); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 1ea995c..f68baf3 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -662,11 +662,21 @@ void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appl this->appliableSplitTransitions = appliableSplitTransitions; } +void Config::setAppliableTransitions(const std::vector<int> & appliableTransitions) +{ + this->appliableTransitions = appliableTransitions; +} + const std::vector<Transition *> & Config::getAppliableSplitTransitions() const { return appliableSplitTransitions; } +const std::vector<int> & Config::getAppliableTransitions() const +{ + return appliableTransitions; +} + Config::Object Config::str2object(const std::string & s) { if (s == "b") diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index a6ed1b0..5701c70 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -67,6 +67,19 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config & return result; } +std::vector<int> TransitionSet::getAppliableTransitions(const Config & c) +{ + std::vector<int> result; + + for (unsigned int i = 0; i < transitions.size(); i++) + if (transitions[i].appliable(c)) + result.emplace_back(1); + else + result.emplace_back(0); + + return result; +} + Transition * TransitionSet::getBestAppliableTransition(const 
Config & c) { Transition * result = nullptr; diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp new file mode 100644 index 0000000..5e6f9e4 --- /dev/null +++ b/torch_modules/include/AppliableTransModule.hpp @@ -0,0 +1,28 @@ +#ifndef APPLIABLETRANSRANSMODULE__H +#define APPLIABLETRANSRANSMODULE__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "MyModule.hpp" +#include "LSTM.hpp" +#include "GRU.hpp" + +class AppliableTransModuleImpl : public Submodule +{ + private : + + int nbTrans; + + public : + + AppliableTransModuleImpl(std::string name, int nbTrans); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void registerEmbeddings() override; +}; +TORCH_MODULE(AppliableTransModule); + +#endif + diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 7e721b9..40b1919 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -5,6 +5,7 @@ #include "ContextModule.hpp" #include "RawInputModule.hpp" #include "SplitTransModule.hpp" +#include "AppliableTransModule.hpp" #include "FocusedColumnModule.hpp" #include "DepthLayerTreeEmbeddingModule.hpp" #include "StateNameModule.hpp" @@ -33,6 +34,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl void setDictsState(Dict::State state) override; void setCountOcc(bool countOcc) override; void removeRareDictElements(float rarityThreshold) override; + void setState(const std::string & state); }; #endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 3cbfe47..ee32d2b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -5,8 +5,9 @@ #include <filesystem> #include "Config.hpp" #include "NameHolder.hpp" 
+#include "StateHolder.hpp" -class NeuralNetworkImpl : public torch::nn::Module, public NameHolder +class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder { public : @@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; virtual void registerEmbeddings() = 0; - void setState(const std::string & state); - const std::string & getState() const; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; diff --git a/torch_modules/include/StateHolder.hpp b/torch_modules/include/StateHolder.hpp new file mode 100644 index 0000000..8712e55 --- /dev/null +++ b/torch_modules/include/StateHolder.hpp @@ -0,0 +1,19 @@ +#ifndef STATEHOLDER__H +#define STATEHOLDER__H + +#include <string> + +class StateHolder +{ + private : + + std::string state; + + public : + + const std::string & getState() const; + void setState(const std::string & state); +}; + +#endif + diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 135b0f9..f773d70 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -4,8 +4,9 @@ #include <torch/torch.h> #include "Config.hpp" #include "DictHolder.hpp" +#include "StateHolder.hpp" -class Submodule : public torch::nn::Module, public DictHolder +class Submodule : public torch::nn::Module, public DictHolder, public StateHolder { protected : diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp new file mode 100644 index 0000000..c50586f --- /dev/null +++ b/torch_modules/src/AppliableTransModule.cpp @@ -0,0 +1,37 @@ +#include "AppliableTransModule.hpp" + +AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans) : 
nbTrans(nbTrans) +{ + setName(name); +} + +torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input) +{ + return input.narrow(1, firstInputIndex, getInputSize()).to(torch::kFloat); +} + +std::size_t AppliableTransModuleImpl::getOutputSize() +{ + return nbTrans; +} + +std::size_t AppliableTransModuleImpl::getInputSize() +{ + return nbTrans; +} + +void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +{ + auto & appliableTrans = config.getAppliableTransitions(); + for (auto & contextElement : context) + for (int i = 0; i < nbTrans; i++) + if (i < (int)appliableTrans.size()) + contextElement.emplace_back(appliableTrans[i]); + else + contextElement.emplace_back(0); +} + +void AppliableTransModuleImpl::registerEmbeddings() +{ +} + diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index b4e23cc..c79791c 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -15,6 +15,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st return result; }; + std::size_t maxNbOutputs = 0; + for (auto & it : nbOutputsPerState) + maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs); + int currentInputSize = 0; int currentOutputSize = 0; std::string mlpDef; @@ -37,6 +41,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second))); else if (splited.first == "SplitTrans") modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second))); + else if (splited.first == "AppliableTrans") + modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs))); else if (splited.first == "DepthLayerTree") modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second))); else if (splited.first == "MLP") @@ 
-134,3 +140,10 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold) } } +void ModularNetworkImpl::setState(const std::string & state) +{ + NeuralNetworkImpl::setState(state); + for (auto & mod : modules) + mod->setState(state); +} + diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index aa149fa..02e8a19 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -2,13 +2,3 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); -void NeuralNetworkImpl::setState(const std::string & state) -{ - this->state = state; -} - -const std::string & NeuralNetworkImpl::getState() const -{ - return state; -} - diff --git a/torch_modules/src/StateHolder.cpp b/torch_modules/src/StateHolder.cpp new file mode 100644 index 0000000..2209bad --- /dev/null +++ b/torch_modules/src/StateHolder.cpp @@ -0,0 +1,16 @@ +#include "StateHolder.hpp" +#include "util.hpp" + +const std::string & StateHolder::getState() const +{ + if (state.empty()) + util::myThrow("state is empty"); + + return state; +} + +void StateHolder::setState(const std::string & state) +{ + this->state = state; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 328b00c..d40794b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -70,6 +70,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); + config.setAppliableTransitions(appliableTransitions); std::vector<std::vector<long>> context; @@ -300,6 +302,8 @@ void Trainer::fillDicts(SubConfig & config, bool debug) if (machine.hasSplitWordTransitionSet()) 
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); + config.setAppliableTransitions(appliableTransitions); try { -- GitLab