diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index e39593c8aa3189cc4cbb0ba3a0be82043f0fff37..9aad4f048b3e065e554732bdf1b56326bbff78c3 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -39,8 +39,6 @@ void Beam::update(ReadingMachine & machine, bool debug) auto & classifier = *machine.getClassifier(elements[index].config.getState()); - classifier.setState(elements[index].config.getState()); - if (machine.hasSplitWordTransitionSet()) elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions)); @@ -50,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); - auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0); float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++) @@ -123,9 +121,6 @@ void Beam::update(ReadingMachine & machine, bool debug) continue; auto & config = element.config; - auto & classifier = *machine.getClassifier(config.getState()); - - classifier.setState(config.getState()); auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition); diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 474db0788b45b047b03b16905bfab0f9ed9f8dce..ad2e6a66bbc9eb5fdbe874bbc1760895294f3e42 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -39,7 +39,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh } catch(std::exception & e) {util::myThrow(e.what());} baseConfig = beam[0].config; - machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState()); if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 17bfc8487da18fa2079b295a32a69c5a80edb849..e4b22080c791147f700e88a4dd2a50cbea3cd207 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -22,7 +22,6 @@ class Classifier std::shared_ptr<NeuralNetworkImpl> nn; std::unique_ptr<torch::optim::Optimizer> optimizer; std::string optimizerType, optimizerParameters; - std::string state; std::vector<std::string> states; std::filesystem::path path; bool regression{false}; @@ -39,7 +38,7 @@ class Classifier public : Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train); - TransitionSet & getTransitionSet(); + TransitionSet & getTransitionSet(const std::string & state); NeuralNetwork & getNN(); const std::string & getName() const; int getNbParameters() const; @@ -47,8 +46,7 @@ class Classifier void loadOptimizer(); void saveOptimizer(); torch::optim::Optimizer & getOptimizer(); - void setState(const std::string & state); - float getLossMultiplier(); + float getLossMultiplier(const std::string & state); const std::vector<std::string> & getStates() const; void saveDicts(); void saveBest(); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index b5e929bfee0b6245bbee111e2ba1353909afa2f3..68e89b7ec1d94ddc22ee4671aceabc003321ce53 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -110,7 +110,7 @@ int Classifier::getNbParameters() const return nbParameters; } -TransitionSet & Classifier::getTransitionSet() +TransitionSet & Classifier::getTransitionSet(const std::string & state) { if (!transitionSets.count(state)) util::myThrow(fmt::format("cannot find transition set for state '{}'", state)); @@ -196,12 +196,6 @@ torch::optim::Optimizer & Classifier::getOptimizer() return *optimizer; } -void Classifier::setState(const std::string & state) -{ - this->state = state; - nn->setState(state); -} - void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; @@ -244,7 +238,7 @@ void Classifier::resetOptimizer() util::myThrow(expected); } -float Classifier::getLossMultiplier() +float Classifier::getLossMultiplier(const std::string & state) { return lossMultipliers.at(state); } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 4336d3adab6957dfb50dd9c1e7fd9e2f5221cda9..582acff8d4cf6f1311897210b368ed52bb8dabf8 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) TransitionSet & ReadingMachine::getTransitionSet(const std::string & state) { - return classifiers[state2classifier.at(state)]->getTransitionSet(); + return classifiers[state2classifier.at(state)]->getTransitionSet(state); } bool ReadingMachine::hasSplitWordTransitionSet() const diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 9b7efaec8e94d55f745aaf4d16ab7c5f5877c811..ed73c301bd90134f50b98a4778d5a6539b54f9aa 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl public : ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); - torch::Tensor forward(torch::Tensor input) override; + torch::Tensor forward(torch::Tensor input, const std::string & state) override; std::vector<std::vector<long>> extractContext(Config & config) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; @@ -37,7 +37,6 @@ class ModularNetworkImpl : public NeuralNetworkImpl void setDictsState(Dict::State state) override; void setCountOcc(bool countOcc) override; void removeRareDictElements(float rarityThreshold) override; - void setState(const std::string & state); }; #endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 6058cebf9c39dff266622c656aefff66dd094f7b..8215ad2fff9438ab0b6e133f1591b9cddc7c369e 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -5,21 +5,16 @@ #include <filesystem> #include "Config.hpp" #include "NameHolder.hpp" -#include "StateHolder.hpp" -class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder +class NeuralNetworkImpl : public torch::nn::Module, public NameHolder { public : static torch::Device device; - private : - - std::string state; - public : - virtual torch::Tensor forward(torch::Tensor input) = 0; + virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; virtual void registerEmbeddings() = 0; virtual void saveDicts(std::filesystem::path path) = 0; diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index b20a779eb5f47979eecd7d67f64af3193492ff53..3c559e9c393146a86bb9ffc1c6f3dd42f0a89ac2 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -12,7 +12,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl public : RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); - torch::Tensor forward(torch::Tensor input) override; + torch::Tensor forward(torch::Tensor input, const std::string & state) override; std::vector<std::vector<long>> extractContext(Config &) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 553da4f7163b9e5e702213f5b74b42c5c8c9bbcc..1dbbdc7e46844a910a5d0884c46e8f6e62f192ae 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -5,9 +5,8 @@ #include <filesystem> #include "Config.hpp" #include "DictHolder.hpp" -#include "StateHolder.hpp" -class Submodule : public torch::nn::Module, public DictHolder, public StateHolder +class Submodule : public torch::nn::Module, public DictHolder { private : diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index e2e225c4ea35fe7fce904bb5e3d52ec138d28288..c936f85b75ebfc4d6ed9686b091a183d53bc5adc 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -80,7 +80,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } -torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) +torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state) { if (input.dim() == 1) input = input.unsqueeze(0); @@ -92,7 +92,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) auto totalInput = inputDropout(torch::cat(outputs, 1)); - return outputLayersPerState.at(getState())(mlp(totalInput)); + return outputLayersPerState.at(state)(mlp(totalInput)); } std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config) @@ -149,10 +149,3 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold) } } -void ModularNetworkImpl::setState(const std::string & state) -{ - NeuralNetworkImpl::setState(state); - for (auto & mod : modules) - mod->setState(state); -} - diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 7a6491b6351a9a20e6d56d6423db44fb268067c2..87a604636595062008ecbd5d442111b4101c8b39 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -5,12 +5,12 @@ RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std: setName(name); } -torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) +torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state) { if (input.dim() == 1) input = input.unsqueeze(0); - return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true)); + return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true)); } std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 7f1aaec54deb294576221006fc9b5469f824f87c..56ecc44cbfdd3fdf069856b4da3963665ae9c068 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -53,7 +53,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); - machine.getClassifier(config.getState())->setState(config.getState()); while (true) { @@ -94,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: { auto & classifier = *machine.getClassifier(config.getState()); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); - auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates; @@ -176,7 +175,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: break; config.setState(movement.first); - machine.getClassifier(config.getState())->setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) @@ -217,9 +215,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) machine.getClassifier(state)->getOptimizer().zero_grad(); - machine.getClassifier(state)->setState(state); - - auto prediction = machine.getClassifier(state)->getNN()(data); + auto prediction = machine.getClassifier(state)->getNN()->forward(data, state); if (prediction.dim() == 1) prediction = prediction.unsqueeze(0); @@ -229,7 +225,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance labels /= util::float2longScale; } - auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels); + auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels); float lossAsFloat = 0.0; try { @@ -340,7 +336,6 @@ void Trainer::extractActionSequence(BaseConfig & config) config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); - machine.getClassifier(config.getState())->setState(config.getState()); int curSeq = 0; int curSeqStartIndex = -1; @@ -403,7 +398,6 @@ void Trainer::extractActionSequence(BaseConfig & config) break; config.setState(movement.first); - machine.getClassifier(config.getState())->setState(movement.first); config.moveWordIndexRelaxed(movement.second); }