From 0c86cb53315b168fc479310660cce1b8d7072b5b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 3 Mar 2021 15:03:27 +0100
Subject: [PATCH] Removed stored state from NeuralNetwork and Classifier; state is now passed explicitly to forward()

---
 decoder/src/Beam.cpp                     |  7 +------
 decoder/src/Decoder.cpp                  |  1 -
 reading_machine/include/Classifier.hpp   |  6 ++----
 reading_machine/src/Classifier.cpp       | 10 ++--------
 reading_machine/src/ReadingMachine.cpp   |  2 +-
 torch_modules/include/ModularNetwork.hpp |  3 +--
 torch_modules/include/NeuralNetwork.hpp  |  9 ++-------
 torch_modules/include/RandomNetwork.hpp  |  2 +-
 torch_modules/include/Submodule.hpp      |  3 +--
 torch_modules/src/ModularNetwork.cpp     | 11 ++---------
 torch_modules/src/RandomNetwork.cpp      |  4 ++--
 trainer/src/Trainer.cpp                  | 12 +++---------
 12 files changed, 18 insertions(+), 52 deletions(-)

diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index e39593c..9aad4f0 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -39,8 +39,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     auto & classifier = *machine.getClassifier(elements[index].config.getState());
 
-    classifier.setState(elements[index].config.getState());
-
     if (machine.hasSplitWordTransitionSet())
       elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
 
@@ -50,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
 
-    auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
+    auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
     std::vector<std::pair<float, int>> scoresOfTransitions;
     for (unsigned int i = 0; i < prediction.size(0); i++)
@@ -123,9 +121,6 @@ void Beam::update(ReadingMachine & machine, bool debug)
       continue;
 
     auto & config = element.config;
-    auto & classifier = *machine.getClassifier(config.getState());
-
-    classifier.setState(config.getState());
 
     auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
 
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 474db07..ad2e6a6 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -39,7 +39,6 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   } catch(std::exception & e) {util::myThrow(e.what());}
 
   baseConfig = beam[0].config;
-  machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
 
   if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
   {
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 17bfc84..e4b2208 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -22,7 +22,6 @@ class Classifier
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Optimizer> optimizer;
   std::string optimizerType, optimizerParameters;
-  std::string state;
   std::vector<std::string> states;
   std::filesystem::path path;
   bool regression{false};
@@ -39,7 +38,7 @@ class Classifier
   public :
 
   Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
-  TransitionSet & getTransitionSet();
+  TransitionSet & getTransitionSet(const std::string & state);
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
@@ -47,8 +46,7 @@ class Classifier
   void loadOptimizer();
   void saveOptimizer();
   torch::optim::Optimizer & getOptimizer();
-  void setState(const std::string & state);
-  float getLossMultiplier();
+  float getLossMultiplier(const std::string & state);
   const std::vector<std::string> & getStates() const;
   void saveDicts();
   void saveBest();
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index b5e929b..68e89b7 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -110,7 +110,7 @@ int Classifier::getNbParameters() const
   return nbParameters;
 }
 
-TransitionSet & Classifier::getTransitionSet()
+TransitionSet & Classifier::getTransitionSet(const std::string & state)
 {
   if (!transitionSets.count(state))
     util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
@@ -196,12 +196,6 @@ torch::optim::Optimizer & Classifier::getOptimizer()
   return *optimizer;
 }
 
-void Classifier::setState(const std::string & state)
-{
-  this->state = state;
-  nn->setState(state);
-}
-
 void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
@@ -244,7 +238,7 @@ void Classifier::resetOptimizer()
     util::myThrow(expected);
 }
 
-float Classifier::getLossMultiplier()
+float Classifier::getLossMultiplier(const std::string & state)
 {
   return lossMultipliers.at(state);
 }
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 4336d3a..582acff 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -95,7 +95,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
 
 TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
 {
-  return classifiers[state2classifier.at(state)]->getTransitionSet();
+  return classifiers[state2classifier.at(state)]->getTransitionSet(state);
 }
 
 bool ReadingMachine::hasSplitWordTransitionSet() const
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 9b7efae..ed73c30 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   public :
 
   ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
-  torch::Tensor forward(torch::Tensor input) override;
+  torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   std::vector<std::vector<long>> extractContext(Config & config) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
@@ -37,7 +37,6 @@ class ModularNetworkImpl : public NeuralNetworkImpl
   void setDictsState(Dict::State state) override;
   void setCountOcc(bool countOcc) override;
   void removeRareDictElements(float rarityThreshold) override;
-  void setState(const std::string & state);
 };
 
 #endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 6058ceb..8215ad2 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -5,21 +5,16 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "NameHolder.hpp"
-#include "StateHolder.hpp"
 
-class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
+class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
   public :
 
   static torch::Device device;
 
-  private :
-
-  std::string state;
-
   public :
 
-  virtual torch::Tensor forward(torch::Tensor input) = 0;
+  virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
   virtual void registerEmbeddings() = 0;
   virtual void saveDicts(std::filesystem::path path) = 0;
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index b20a779..3c559e9 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -12,7 +12,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
   public :
 
   RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
-  torch::Tensor forward(torch::Tensor input) override;
+  torch::Tensor forward(torch::Tensor input, const std::string & state) override;
   std::vector<std::vector<long>> extractContext(Config &) override;
   void registerEmbeddings() override;
   void saveDicts(std::filesystem::path path) override;
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 553da4f..1dbbdc7 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -5,9 +5,8 @@
 #include <filesystem>
 #include "Config.hpp"
 #include "DictHolder.hpp"
-#include "StateHolder.hpp"
 
-class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
+class Submodule : public torch::nn::Module, public DictHolder
 {
   private :
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index e2e225c..c936f85 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -80,7 +80,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
 }
 
-torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
+torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
 {
   if (input.dim() == 1)
     input = input.unsqueeze(0);
@@ -92,7 +92,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
 
   auto totalInput = inputDropout(torch::cat(outputs, 1));
 
-  return outputLayersPerState.at(getState())(mlp(totalInput));
+  return outputLayersPerState.at(state)(mlp(totalInput));
 }
 
 std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
@@ -149,10 +149,3 @@ void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
   }
 }
 
-void ModularNetworkImpl::setState(const std::string & state)
-{
-  NeuralNetworkImpl::setState(state);
-  for (auto & mod : modules)
-    mod->setState(state);
-}
-
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 7a6491b..87a6046 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -5,12 +5,12 @@ RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std:
   setName(name);
 }
 
-torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
+torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state)
 {
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true));
+  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
 std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 7f1aaec..56ecc44 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -53,7 +53,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
     config.addPredicted(machine.getPredicted());
     config.setStrategy(machine.getStrategyDefinition());
     config.setState(config.getStrategy().getInitialState());
-    machine.getClassifier(config.getState())->setState(config.getState());
 
     while (true)
     {
@@ -94,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       {
         auto & classifier = *machine.getClassifier(config.getState());
         auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
-        auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
+        auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
         std::vector<int> candidates;
@@ -176,7 +175,6 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
         break;
 
       config.setState(movement.first);
-      machine.getClassifier(config.getState())->setState(movement.first);
       config.moveWordIndexRelaxed(movement.second);
 
       if (config.needsUpdate())
@@ -217,9 +215,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
       machine.getClassifier(state)->getOptimizer().zero_grad();
 
-    machine.getClassifier(state)->setState(state);
-
-    auto prediction = machine.getClassifier(state)->getNN()(data);
+    auto prediction = machine.getClassifier(state)->getNN()->forward(data, state);
     if (prediction.dim() == 1)
       prediction = prediction.unsqueeze(0);
 
@@ -229,7 +225,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       labels /= util::float2longScale;
     }
 
-    auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels);
+    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
     float lossAsFloat = 0.0;
     try
     {
@@ -340,7 +336,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
   config.addPredicted(machine.getPredicted());
   config.setStrategy(machine.getStrategyDefinition());
   config.setState(config.getStrategy().getInitialState());
-  machine.getClassifier(config.getState())->setState(config.getState());
 
   int curSeq = 0;
   int curSeqStartIndex = -1;
@@ -403,7 +398,6 @@ void Trainer::extractActionSequence(BaseConfig & config)
       break;
 
     config.setState(movement.first);
-    machine.getClassifier(config.getState())->setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
   }
 
-- 
GitLab