From 0c86cb53315b168fc479310660cce1b8d7072b5b Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 3 Mar 2021 15:03:27 +0100 Subject: [PATCH] Removed state from neuralnetwork --- decoder/src/Beam.cpp | 7 +------ decoder/src/Decoder.cpp | 1 - reading_machine/include/Classifier.hpp | 6 ++---- reading_machine/src/Classifier.cpp | 10 ++-------- reading_machine/src/ReadingMachine.cpp | 2 +- torch_modules/include/ModularNetwork.hpp | 3 +-- torch_modules/include/NeuralNetwork.hpp | 9 ++------- torch_modules/include/RandomNetwork.hpp | 2 +- torch_modules/include/Submodule.hpp | 3 +-- torch_modules/src/ModularNetwork.cpp | 11 ++--------- torch_modules/src/RandomNetwork.cpp | 4 ++-- trainer/src/Trainer.cpp | 12 +++--------- 12 files changed, 18 insertions(+), 52 deletions(-) diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index e39593c..9aad4f0 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -39,8 +39,6 @@ void Beam::update(ReadingMachine & machine, bool debug) auto & classifier = *machine.getClassifier(elements[index].config.getState()); - classifier.setState(elements[index].config.getState()); - if (machine.hasSplitWordTransitionSet()) elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions)); @@ -50,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); - auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0); float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++) @@ -123,9 +121,6 @@ void Beam::update(ReadingMachine & machine, bool debug) continue; auto & config = element.config; - auto & classifier = *machine.getClassifier(config.getState()); - - classifier.setState(config.getState()); auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition); diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 474db07..ad2e6a6 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -39,7 +39,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh } catch(std::exception & e) {util::myThrow(e.what());} baseConfig = beam[0].config; - machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState()); if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 17bfc84..e4b2208 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -22,7 +22,6 @@ class Classifier std::shared_ptr<NeuralNetworkImpl> nn; std::unique_ptr<torch::optim::Optimizer> optimizer; std::string optimizerType, optimizerParameters; - std::string state; std::vector<std::string> states; std::filesystem::path path; bool regression{false}; @@ -39,7 +38,7 @@ class Classifier public : Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train); - TransitionSet & getTransitionSet(); + TransitionSet & getTransitionSet(const std::string & state); NeuralNetwork & getNN(); const std::string & getName() const; int getNbParameters() const; @@ -47,8 +46,7 @@ class Classifier void loadOptimizer(); void saveOptimizer(); torch::optim::Optimizer & getOptimizer(); - void setState(const std::string & state); - float getLossMultiplier(); + float getLossMultiplier(const std::string & state); const std::vector<std::string> & getStates() const; void saveDicts(); void saveBest(); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index b5e929b..68e89b7 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -110,7 +110,7 @@ int Classifier::getNbParameters() const return nbParameters; } -TransitionSet & Classifier::getTransitionSet() +TransitionSet & Classifier::getTransitionSet(const std::string & state) { if (!transitionSets.count(state)) util::myThrow(fmt::format("cannot find transition set for state '{}'", state)); @@ -196,12 +196,6 @@ torch::optim::Optimizer & Classifier::getOptimizer() return *optimizer; } -void Classifier::setState(const std::string & state) -{ - this->state = state; - nn->setState(state); -} - void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; @@ -244,7 +238,7 @@ void Classifier::resetOptimizer() util::myThrow(expected); } -float Classifier::getLossMultiplier() +float Classifier::getLossMultiplier(const std::string & state) { return lossMultipliers.at(state); } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 4336d3a..582acff 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) TransitionSet & ReadingMachine::getTransitionSet(const std::string & state) { - return classifiers[state2classifier.at(state)]->getTransitionSet(); + return classifiers[state2classifier.at(state)]->getTransitionSet(state); } bool ReadingMachine::hasSplitWordTransitionSet() const diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 9b7efae..ed73c30 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl public : ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); - torch::Tensor forward(torch::Tensor input) override; + torch::Tensor forward(torch::Tensor input, const std::string & state) override; std::vector<std::vector<long>> extractContext(Config & config) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; @@ -37,7 +37,6 @@ class ModularNetworkImpl : public NeuralNetworkImpl void setDictsState(Dict::State state) override; void setCountOcc(bool countOcc) override; void removeRareDictElements(float rarityThreshold) override; - void setState(const std::string & state); }; #endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 6058ceb..8215ad2 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -5,21 +5,16 @@ #include <filesystem> #include "Config.hpp" #include "NameHolder.hpp" -#include "StateHolder.hpp" -class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder +class NeuralNetworkImpl : public torch::nn::Module, public NameHolder { public : static torch::Device device; - private : - - std::string state; - public : - virtual torch::Tensor forward(torch::Tensor input) = 0; + virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; virtual void registerEmbeddings() = 0; virtual void saveDicts(std::filesystem::path path) = 0; diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index b20a779..3c559e9 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -12,7 +12,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl public : RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); - torch::Tensor forward(torch::Tensor input) override; + torch::Tensor forward(torch::Tensor input, const std::string & state) override; std::vector<std::vector<long>> extractContext(Config &) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 553da4f..1dbbdc7 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -5,9 +5,8 @@ #include <filesystem> #include "Config.hpp" #include "DictHolder.hpp" -#include "StateHolder.hpp" -class Submodule : public torch::nn::Module, public DictHolder, public StateHolder +class Submodule : public torch::nn::Module, public DictHolder { private : diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index e2e225c..c936f85 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -80,7 +80,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } -torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) +torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state) { if (input.dim() == 1) input = input.unsqueeze(0); @@ -92,7 +92,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) auto totalInput = inputDropout(torch::cat(outputs, 1)); - return outputLayersPerState.at(getState())(mlp(totalInput)); + return outputLayersPerState.at(state)(mlp(totalInput)); } std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config) @@ -149,10 +149,3 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold) } } -void ModularNetworkImpl::setState(const std::string & state) -{ - NeuralNetworkImpl::setState(state); - for (auto & mod : modules) - mod->setState(state); -} - diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7a6491b..87a6046 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -5,12 +5,12 @@ RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std: setName(name); } -torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) +torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state) { if (input.dim() == 1) input = input.unsqueeze(0); - return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true)); + return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true)); } std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 7f1aaec..56ecc44 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -53,7 +53,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); - machine.getClassifier(config.getState())->setState(config.getState()); while (true) { @@ -94,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: { auto & classifier = *machine.getClassifier(config.getState()); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); - auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates; @@ -176,7 +175,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: break; config.setState(movement.first); - machine.getClassifier(config.getState())->setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) @@ -217,9 +215,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) machine.getClassifier(state)->getOptimizer().zero_grad(); - machine.getClassifier(state)->setState(state); - - auto prediction = machine.getClassifier(state)->getNN()(data); + auto prediction = machine.getClassifier(state)->getNN()->forward(data, state); if (prediction.dim() == 1) prediction = prediction.unsqueeze(0); @@ -229,7 +225,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance labels /= util::float2longScale; } - auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels); + auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels); float lossAsFloat = 0.0; try { @@ -340,7 +336,6 @@ void Trainer::extractActionSequence(BaseConfig & config) config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); - machine.getClassifier(config.getState())->setState(config.getState()); int curSeq = 0; int curSeqStartIndex = -1; @@ -403,7 +398,6 @@ void Trainer::extractActionSequence(BaseConfig & config) break; config.setState(movement.first); - machine.getClassifier(config.getState())->setState(movement.first); config.moveWordIndexRelaxed(movement.second); } -- GitLab