From a31419f9b423d3fa3c5ec2b67321f22773976d3c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 20 Apr 2020 15:58:55 +0200
Subject: [PATCH] Each classifier can have a different output layer in neural
 networks

---
 decoder/src/Decoder.cpp                 |  2 +
 reading_machine/include/Classifier.hpp  |  6 ++-
 reading_machine/src/Classifier.cpp      | 49 ++++++++++++++++++++-----
 torch_modules/include/ConfigDataset.hpp |  6 +--
 torch_modules/include/LSTMNetwork.hpp   |  3 +-
 torch_modules/include/MLP.hpp           |  4 +-
 torch_modules/include/NeuralNetwork.hpp |  3 ++
 torch_modules/include/RandomNetwork.hpp |  4 +-
 torch_modules/src/ConfigDataset.cpp     | 25 ++++++++-----
 torch_modules/src/LSTMNetwork.cpp       |  9 +++--
 torch_modules/src/MLP.cpp               | 14 ++++---
 torch_modules/src/NeuralNetwork.cpp     | 10 +++++
 torch_modules/src/RandomNetwork.cpp     |  4 +-
 trainer/include/Trainer.hpp             | 13 ++++++-
 trainer/src/Trainer.cpp                 | 45 ++++++++++++++---------
 15 files changed, 141 insertions(+), 56 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 33c4837..352a9e8 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -21,6 +21,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
   try
   {
   config.setState(machine.getStrategy().getInitialState());
+  machine.getClassifier()->setState(machine.getStrategy().getInitialState());
 
   while (true)
   {
@@ -94,6 +95,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       break;
 
     config.setState(movement.first);
+    machine.getClassifier()->setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
   }
   } catch(std::exception & e) {util::myThrow(e.what());}
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 2bb3af9..4951e5d 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -10,14 +10,15 @@ class Classifier
   private :
 
   std::string name;
-  std::unique_ptr<TransitionSet> transitionSet;
+  std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Adam> optimizer;
+  std::string state;
 
   private :
 
   void initNeuralNetwork(const std::vector<std::string> & definition);
-  void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex);
+  void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
 
   public :
 
@@ -29,6 +30,7 @@ class Classifier
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
   torch::optim::Adam & getOptimizer();
+  void setState(const std::string & state);
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index b3e0710..d8418c9 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -8,12 +8,31 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   this->name = name;
   if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
         {
-          std::vector<std::string> tsFiles;
+          auto splited = util::split(sm.str(1), ' ');
 
-          for (auto & tsFilename : util::split(sm.str(1), ' '))
-            tsFiles.emplace_back(path.parent_path() / tsFilename);
+          for (auto & ss : splited)
+          {
+            std::vector<std::string> tsFiles;
+            std::vector<std::string> states;
+            for (auto & elem : util::split(ss, ','))
+              if (std::filesystem::path(elem).extension().empty())
+                states.emplace_back(elem);
+              else
+                tsFiles.emplace_back(path.parent_path() / elem);
+            if (tsFiles.empty())
+              util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
+            if (states.empty())
+              util::myThrow(fmt::format("invalid '{}' no states specified", ss));
+
+            for (auto & stateName : states)
+            {
+              if (transitionSets.count(stateName))
+                util::myThrow(fmt::format("state '{}' already assigned", stateName));
+
+              this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
+            }
+          }
 
-          this->transitionSet.reset(new TransitionSet(tsFiles));
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
 
@@ -32,7 +51,10 @@ int Classifier::getNbParameters() const
 
 TransitionSet & Classifier::getTransitionSet()
 {
-  return *transitionSet;
+  if (!transitionSets.count(state))
+    util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
+
+  return *transitionSets[state];
 }
 
 NeuralNetwork & Classifier::getNN()
@@ -47,6 +69,10 @@ const std::string & Classifier::getName() const
 
 void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
 {
+  std::map<std::string,std::size_t> nbOutputsPerState;
+  for (auto & it : this->transitionSets)
+    nbOutputsPerState[it.first] = it.second->size();
+
   std::size_t curIndex = 1;
 
   std::string networkType;
@@ -58,9 +84,9 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
 
   if (networkType == "Random")
-    this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
+    this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
   else if (networkType == "LSTM")
-    initLSTM(definition, curIndex);
+    initLSTM(definition, curIndex, nbOutputsPerState);
   else
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
 
@@ -83,7 +109,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
-void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
+void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
   int unknownValueThreshold;
   std::vector<int> bufferContext, stackContext;
@@ -299,7 +325,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
+  this->nn.reset(new LSTMNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
 }
 
 void Classifier::loadOptimizer(std::filesystem::path path)
@@ -317,4 +343,9 @@ torch::optim::Adam & Classifier::getOptimizer()
   return *optimizer;
 }
 
+void Classifier::setState(const std::string & state)
+{
+  this->state = state;
+  nn->setState(state);
+}
 
diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index 59289f4..ee8d46c 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -4,12 +4,12 @@
 #include <torch/torch.h>
 #include "Config.hpp"
 
-class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>>
+class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>>
 {
   private :
 
   std::size_t size_{0};
-  std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations;
+  std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations;
   torch::Tensor loadedTensor;
   std::optional<std::size_t> loadedTensorIndex;
   std::size_t nextIndexToGive{0};
@@ -18,7 +18,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
 
   explicit ConfigDataset(std::filesystem::path dir);
   c10::optional<std::size_t> size() const override;
-  c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override;
+  c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize) override;
   void reset() override;
   void load(torch::serialize::InputArchive &) override;
   void save(torch::serialize::OutputArchive &) const override;
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index e742b97..76b9303 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -24,10 +24,11 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   SplitTransLSTM splitTransLSTM{nullptr};
   DepthLayerTreeEmbedding treeEmbedding{nullptr};
   std::vector<FocusedColumnLSTM> focusedLstms;
+  std::map<std::string,torch::nn::Linear> outputLayersPerState;
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
+  LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp
index 71520f2..be272f1 100644
--- a/torch_modules/include/MLP.hpp
+++ b/torch_modules/include/MLP.hpp
@@ -9,11 +9,13 @@ class MLPImpl : public torch::nn::Module
 
   std::vector<torch::nn::Linear> layers;
   std::vector<torch::nn::Dropout> dropouts;
+  std::size_t outSize{0};
 
   public :
 
-  MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params);
+  MLPImpl(int inputSize, std::vector<std::pair<int, float>> params);
   torch::Tensor forward(torch::Tensor input);
+  std::size_t outputSize() const;
 };
 TORCH_MODULE(MLP);
 
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 1237f09..3db8651 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -14,6 +14,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   private :
 
   bool splitUnknown{false};
+  std::string state;
 
   protected : 
 
@@ -25,6 +26,8 @@ class NeuralNetworkImpl : public torch::nn::Module
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
   bool mustSplitUnknown() const;
   void setSplitUnknown(bool splitUnknown);
+  void setState(const std::string & state);
+  const std::string & getState() const;
 };
 TORCH_MODULE(NeuralNetwork);
 
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index 8f58d7b..f40c1b0 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -7,11 +7,11 @@ class RandomNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  long outputSize;
+  std::map<std::string,std::size_t> nbOutputsPerState;
 
   public :
 
-  RandomNetworkImpl(long outputSize);
+  RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config &, Dict &) const override;
 };
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index e73e88f..3a4a2c5 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -6,12 +6,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
   for (auto & entry : std::filesystem::directory_iterator(dir))
     if (entry.is_regular_file())
     {
-      auto splited = util::split(entry.path().stem().string(), '-');
-      if (splited.size() != 2)
+      auto stem = entry.path().stem().string();
+      if (stem == "extracted")
         continue;
-      exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path()));
+      auto state = util::split(stem, '_')[0];
+      auto splited = util::split(util::split(stem, '_')[1], '-');
+      exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state));
       size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
     }
+
 }
 
 c10::optional<std::size_t> ConfigDataset::size() const
@@ -19,7 +22,7 @@ c10::optional<std::size_t> ConfigDataset::size() const
   return size_;
 }
 
-c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize)
+c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
 {
   if (!loadedTensorIndex.has_value())
   {
@@ -33,25 +36,27 @@ c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(s
     loadedTensorIndex = loadedTensorIndex.value() + 1;
 
     if (loadedTensorIndex >= exampleLocations.size())
-      return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
+      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
 
     torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
   }
 
-  std::pair<torch::Tensor, torch::Tensor> batch;
+  std::tuple<torch::Tensor, torch::Tensor, std::string> batch;
   if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0))
   {
-    batch.first = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
-    batch.second = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
+    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
+    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
     nextIndexToGive += batchSize;
   }
   else
   {
-    batch.first = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
-    batch.second = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
+    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
+    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
     nextIndexToGive = loadedTensor.size(0);
   }
 
+  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);
+
   return batch;
 }
 
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index a4f5863..476e24d 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
+LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -50,7 +50,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
     embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
   inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
 
-  mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
+  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
+
+  for (auto & it : nbOutputsPerState)
+    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
 }
 
 torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
@@ -81,7 +84,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
 
   auto totalInput = inputDropout(torch::cat(outputs, 1));
 
-  return mlp(totalInput);
+  return outputLayersPerState.at(getState())(mlp(totalInput));
 }
 
 std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp
index b886245..03880ec 100644
--- a/torch_modules/src/MLP.cpp
+++ b/torch_modules/src/MLP.cpp
@@ -1,7 +1,7 @@
 #include "MLP.hpp"
 #include "fmt/core.h"
 
-MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params)
+MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params)
 {
   int inSize = inputSize;
 
@@ -10,17 +10,21 @@ MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float
     layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
     dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
     inSize = param.first;
+    outSize = inSize;
   }
-
-  layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, outputSize)));
 }
 
 torch::Tensor MLPImpl::forward(torch::Tensor input)
 {
   torch::Tensor output = input;
-  for (unsigned int i = 0; i < layers.size()-1; i++)
+  for (unsigned int i = 0; i < layers.size(); i++)
     output = dropouts[i](torch::relu(layers[i](output)));
 
-  return layers.back()(output);
+  return output;
+}
+
+std::size_t MLPImpl::outputSize() const
+{
+  return outSize;
 }
 
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 235c677..987cfcb 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -12,3 +12,13 @@ void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
   this->splitUnknown = splitUnknown;
 }
 
+void NeuralNetworkImpl::setState(const std::string & state)
+{
+  this->state = state;
+}
+
+const std::string & NeuralNetworkImpl::getState() const
+{
+  return state;
+}
+
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 7f0137d..8dfafc2 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -1,6 +1,6 @@
 #include "RandomNetwork.hpp"
 
-RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize)
+RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(nbOutputsPerState)
 {
 }
 
@@ -9,7 +9,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true));
+  return torch::randn({input.size(0), (long)nbOutputsPerState[getState()]}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
 std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 7994a2f..91aedbe 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -9,6 +9,17 @@ class Trainer
 {
   private :
 
+  struct Examples
+  {
+    std::vector<torch::Tensor> contexts;
+    std::vector<torch::Tensor> classes;
+
+    int currentExampleIndex{0};
+    int lastSavedIndex{0};
+  };
+
+  private :
+
   using Dataset = ConfigDataset;
   using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>;
 
@@ -26,7 +37,7 @@ class Trainer
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
-  void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);
+void saveExamples(std::string state, Examples & examples, std::filesystem::path dir);
 
   public :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 95e98eb..f76f220 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -33,30 +33,28 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 }
 
-void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir)
+void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir)
 {
-  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
-  auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1);
+  auto tensorToSave = torch::cat({torch::stack(examples.contexts), torch::stack(examples.classes)}, 1);
+  auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1);
   torch::save(tensorToSave, dir/filename);
-  lastSavedIndex = currentExampleIndex;
-  contexts.clear();
-  classes.clear();
+  examples.lastSavedIndex = examples.currentExampleIndex;
+  examples.contexts.clear();
+  examples.classes.clear();
 }
 
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
 
-  int maxNbExamplesPerFile = 250000;
-  int currentExampleIndex = 0;
-  int lastSavedIndex = 0;
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
+  int maxNbExamplesPerFile = 50000;
+  std::map<std::string, Examples> examplesPerState;
 
   std::filesystem::create_directories(dir);
 
   config.addPredicted(machine.getPredicted());
   config.setState(machine.getStrategy().getInitialState());
+  machine.getClassifier()->setState(machine.getStrategy().getInitialState());
   machine.getStrategy().reset();
 
   auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
@@ -75,6 +73,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     if (entry.is_regular_file())
       std::filesystem::remove(entry.path());
 
+  int totalNbExamples = 0;
+
   while (true)
   {
     if (debug)
@@ -85,6 +85,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     std::vector<std::vector<long>> context;
 
+    auto & contexts = examplesPerState[config.getState()].contexts;
+    auto & classes = examplesPerState[config.getState()].classes;
+    auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex;
+    auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex;
+
     try
     {
       context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
@@ -133,10 +138,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     gold[0] = goldIndex;
 
     currentExampleIndex += context.size();
+    totalNbExamples += context.size();
     classes.insert(classes.end(), context.size(), gold);
 
     if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
-      saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
+      saveExamples(config.getState(), examplesPerState[config.getState()], dir);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -148,14 +154,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
       break;
 
     config.setState(movement.first);
+    machine.getClassifier()->setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
 
     if (config.needsUpdate())
       config.update();
   }
 
-  if (!contexts.empty())
-    saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
+  for (auto & it : examplesPerState)
+    if (!it.second.contexts.empty())
+      saveExamples(it.first, it.second, dir);
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
@@ -164,7 +172,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
   machine.saveDicts();
 
-  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
+  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
 }
 
 float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
@@ -188,8 +196,11 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
       machine.getClassifier()->getOptimizer().zero_grad();
 
-    auto data = batch.first;
-    auto labels = batch.second;
+    auto data = std::get<0>(batch);
+    auto labels = std::get<1>(batch);
+    auto state = std::get<2>(batch);
+
+    machine.getClassifier()->setState(state);
 
     auto prediction = machine.getClassifier()->getNN()(data);
     if (prediction.dim() == 1)
-- 
GitLab