diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index fb3738f9117e1e35e672e35827f474095fde71e8..273e9a555d350433e8adbf7e3196ae0505f15acb 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -29,6 +29,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     auto context = machine.getClassifier()->getNN()->extractContext(config).back();
 
@@ -45,7 +47,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       for (unsigned int i = 0; i < softmaxed.size(0); i++)
       {
         float score = softmaxed[i].item<float>();
-        std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
+        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
         toPrint.emplace_back(std::make_pair(score,nicePrint));
       }
       std::sort(toPrint.rbegin(), toPrint.rend());
@@ -58,7 +60,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
     for (unsigned int i = 0; i < prediction.size(0); i++)
     {
       float score = prediction[i].item<float>();
-      if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
+      if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[i])
       {
         chosenTransition = i;
         bestScore = score;
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index aef9932fec06c2ce2a4663f4c8ed00726f362b3f..a047a8a0b2327cf94adf28f550c1333caf2abf66 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -47,6 +47,7 @@ class Config
   int lastPoppedStack{-1};
   int currentWordId{0};
   std::vector<Transition *> appliableSplitTransitions;
+  std::vector<int> appliableTransitions;
 
   protected :
 
@@ -145,7 +146,9 @@ class Config
   void addMissingColumns();
   void addComment();
   void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions);
+  void setAppliableTransitions(const std::vector<int> & appliableTransitions);
   const std::vector<Transition *> & getAppliableSplitTransitions() const;
+  const std::vector<int> & getAppliableTransitions() const;
   bool isExtraColumn(const std::string & colName) const;
 };
 
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index d0c7c1fd219af24fbea2879e65c4765ea8ebc321..8f7b733b50dc6d8feba2cbd88b585fbb2b52e692 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -23,6 +23,7 @@ class TransitionSet
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
   std::vector<Transition *> getNAppliableTransitions(const Config & c, int n);
+  std::vector<int> getAppliableTransitions(const Config & c);
   std::size_t getTransitionIndex(const Transition * transition) const;
   Transition * getTransition(std::size_t index);
   Transition * getTransition(const std::string & name);
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 1ea995cb7e4bd33a176665e02eb58621470ae488..f68baf3b4a56b739da01aa8d50383133e16b4536 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -662,11 +662,21 @@ void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appl
   this->appliableSplitTransitions = appliableSplitTransitions;
 }
 
+void Config::setAppliableTransitions(const std::vector<int> & appliableTransitions)
+{
+  this->appliableTransitions = appliableTransitions;
+}
+
 const std::vector<Transition *> & Config::getAppliableSplitTransitions() const
 {
   return appliableSplitTransitions;
 }
 
+const std::vector<int> & Config::getAppliableTransitions() const
+{
+  return appliableTransitions;
+}
+
 Config::Object Config::str2object(const std::string & s)
 {
   if (s == "b")
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index a6ed1b0c50b64f55c0572b89cf9e0829951e3180..5701c70a46a9dabc04cbe53b8058ff5676ce62f7 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -67,6 +67,19 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config &
   return result;
 }
 
+std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
+{
+  std::vector<int> result;
+
+  for (unsigned int i = 0; i < transitions.size(); i++)
+    if (transitions[i].appliable(c))
+      result.emplace_back(1);
+    else
+      result.emplace_back(0);
+
+  return result;
+}
+
 Transition * TransitionSet::getBestAppliableTransition(const Config & c)
 {
   Transition * result = nullptr;
diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5e6f9e461109eac691920e9763106681f1461f38
--- /dev/null
+++ b/torch_modules/include/AppliableTransModule.hpp
@@ -0,0 +1,28 @@
+#ifndef APPLIABLETRANSRANSMODULE__H
+#define APPLIABLETRANSRANSMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+class AppliableTransModuleImpl : public Submodule
+{
+  private :
+
+  int nbTrans;
+
+  public :
+
+  AppliableTransModuleImpl(std::string name, int nbTrans);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(AppliableTransModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 7e721b923cadb307ff86016792b6573a0df6cbf2..40b1919186dab44e5da506419d9aae60505d76fb 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -5,6 +5,7 @@
 #include "ContextModule.hpp"
 #include "RawInputModule.hpp"
 #include "SplitTransModule.hpp"
+#include "AppliableTransModule.hpp"
 #include "FocusedColumnModule.hpp"
 #include "DepthLayerTreeEmbeddingModule.hpp"
 #include "StateNameModule.hpp"
@@ -33,6 +34,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   void setDictsState(Dict::State state) override;
   void setCountOcc(bool countOcc) override;
   void removeRareDictElements(float rarityThreshold) override;
+  void setState(const std::string & state);
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 3cbfe47443d98de4166f551442df31befff1db9c..ee32d2b2eadc666ef7e38ac70b8ed9f64055d3e4 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -5,8 +5,9 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "NameHolder.hpp"
+#include "StateHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
 {
   public :
 
@@ -21,8 +22,6 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
-  void setState(const std::string & state);
-  const std::string & getState() const;
   virtual void saveDicts(std::filesystem::path path) = 0;
   virtual void loadDicts(std::filesystem::path path) = 0;
   virtual void setDictsState(Dict::State state) = 0;
diff --git a/torch_modules/include/StateHolder.hpp b/torch_modules/include/StateHolder.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8712e550056b94ac32a378b441fc71b46e159942
--- /dev/null
+++ b/torch_modules/include/StateHolder.hpp
@@ -0,0 +1,19 @@
+#ifndef STATEHOLDER__H
+#define STATEHOLDER__H
+
+#include <string>
+
+class StateHolder
+{
+  private :
+
+  std::string state;
+
+  public :
+
+  const std::string & getState() const;
+  void setState(const std::string & state);
+};
+
+#endif
+
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 135b0f9a781f90e7def6d5453181ffa6d9ce735f..f773d70194231a4d6b4ec2be6bcda43079f5441b 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -4,8 +4,9 @@
 #include <torch/torch.h>
 #include "Config.hpp"
 #include "DictHolder.hpp"
+#include "StateHolder.hpp"
 
-class Submodule : public torch::nn::Module, public DictHolder
+class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
 {
   protected :
 
diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c50586f1e49ed7001d0ecc6643b2f6af36047226
--- /dev/null
+++ b/torch_modules/src/AppliableTransModule.cpp
@@ -0,0 +1,37 @@
+#include "AppliableTransModule.hpp"
+
+AppliableTransModuleImpl::AppliableTransModuleImpl(std::string name, int nbTrans) : nbTrans(nbTrans)
+{
+  setName(name);
+}
+
+torch::Tensor AppliableTransModuleImpl::forward(torch::Tensor input)
+{
+  return input.narrow(1, firstInputIndex, getInputSize()).to(torch::kFloat);
+}
+
+std::size_t AppliableTransModuleImpl::getOutputSize()
+{
+  return nbTrans;
+}
+
+std::size_t AppliableTransModuleImpl::getInputSize()
+{
+  return nbTrans;
+}
+
+void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & appliableTrans = config.getAppliableTransitions();
+  for (auto & contextElement : context)
+    for (int i = 0; i < nbTrans; i++)
+      if (i < (int)appliableTrans.size())
+        contextElement.emplace_back(appliableTrans[i]);
+      else
+        contextElement.emplace_back(0);
+}
+
+void AppliableTransModuleImpl::registerEmbeddings()
+{
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index b4e23cc1a4effcce03ca6f261bfaacc030512c8c..c79791cc61109edbf0623741351a821af60d0f8d 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -15,6 +15,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     return result;
   };
 
+  std::size_t maxNbOutputs = 0;
+  for (auto & it : nbOutputsPerState)
+    maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);
+
   int currentInputSize = 0;
   int currentOutputSize = 0;
   std::string mlpDef;
@@ -37,6 +41,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
     else if (splited.first == "SplitTrans")
       modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
+    else if (splited.first == "AppliableTrans")
+      modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
     else if (splited.first == "DepthLayerTree")
       modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
     else if (splited.first == "MLP")
@@ -134,3 +140,10 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
   }
 }
 
+void ModularNetworkImpl::setState(const std::string & state)
+{
+  NeuralNetworkImpl::setState(state);
+  for (auto & mod : modules)
+    mod->setState(state);
+}
+
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index aa149fa00bf82210021569bf06da946bae6002c6..02e8a191bfb4b2bc718b6e815a266bec252fb24b 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,13 +2,3 @@
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
-void NeuralNetworkImpl::setState(const std::string & state)
-{
-  this->state = state;
-}
-
-const std::string & NeuralNetworkImpl::getState() const
-{
-  return state;
-}
-
diff --git a/torch_modules/src/StateHolder.cpp b/torch_modules/src/StateHolder.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2209bade4b32782936621750682219b2c0995e58
--- /dev/null
+++ b/torch_modules/src/StateHolder.cpp
@@ -0,0 +1,16 @@
+#include "StateHolder.hpp"
+#include "util.hpp"
+
+const std::string & StateHolder::getState() const
+{
+  if (state.empty())
+    util::myThrow("state is empty");
+
+  return state;
+}
+
+void StateHolder::setState(const std::string & state)
+{
+  this->state = state;
+}
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 328b00c141085d23dec63a6d23ea8e3f4eb09118..d40794bb9944b492ecd9718b918b0df991eba6c7 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -70,6 +70,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     if (machine.hasSplitWordTransitionSet())
       config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     std::vector<std::vector<long>> context;
 
@@ -300,6 +302,8 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
 
     if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+    config.setAppliableTransitions(appliableTransitions);
 
     try
     {
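
Note (not part of the patch): the standalone sketch below illustrates what the new AppliableTrans feature feeds to the network — addToContext() appends a 0/1 appliability mask to each context row, and forward() slices that range back out and casts it to float. The batch size, offset (firstInputIndex) and transition count used here are made-up values for illustration, not taken from the diff.

// sketch.cpp -- illustrative only; constants are assumptions, not values from the patch
#include <torch/torch.h>
#include <iostream>

int main()
{
  const int64_t batchSize = 2;
  const int64_t firstInputIndex = 3; // column where the mask was appended to the context
  const int64_t nbTrans = 4;         // maximum number of outputs over all states

  // Fake batched context: 3 leading feature columns followed by the 0/1 appliability mask.
  auto context = torch::zeros({batchSize, firstInputIndex + nbTrans}, torch::kLong);
  context[0].narrow(0, firstInputIndex, nbTrans).copy_(torch::tensor({1, 0, 1, 0}));
  context[1].narrow(0, firstInputIndex, nbTrans).copy_(torch::tensor({0, 1, 1, 1}));

  // Same operation as AppliableTransModuleImpl::forward(): keep only the mask columns
  // and cast them to float so they can be concatenated with the other module outputs.
  auto mask = context.narrow(1, firstInputIndex, nbTrans).to(torch::kFloat);
  std::cout << mask << std::endl; // prints a 2x4 float tensor of zeros and ones

  return 0;
}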