From b44bc247feb9a81282cd65b474f85609832c8763 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 30 Jul 2019 14:59:19 +0200
Subject: [PATCH] Added functions to pass a continuous gold vector when
 updating the neural network

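The new update/getLoss overloads accept a per-output gold vector instead of a
single gold class index, and train against it with a squared-distance loss.
As a rough illustration of the kind of target this enables (a hypothetical
helper, not part of this patch), a soft target can be built from a gold class
index like this:

  #include <vector>

  // Hypothetical helper: turn a gold class index into a smoothed one-hot
  // vector usable with the new vector-gold overloads.
  std::vector<float> makeSoftTarget(unsigned int gold, unsigned int nbClasses,
                                    float smoothing = 0.1f)
  {
    std::vector<float> target(nbClasses, smoothing / nbClasses);
    target[gold] += 1.0f - smoothing;
    return target;
  }

Any real-valued vector of the output dimension works, since the loss used for
these overloads is a plain squared distance rather than a cross-entropy.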
---
 maca_common/src/programOptionsTemplates.cpp |   2 +-
 neural_network/include/GeneticAlgorithm.hpp |  17 ++
 neural_network/include/MLP.hpp              |  15 ++
 neural_network/include/MLPBase.hpp          |  36 ++--
 neural_network/include/NeuralNetwork.hpp    |  18 ++
 neural_network/src/GeneticAlgorithm.cpp     |  18 ++
 neural_network/src/MLP.cpp                  |  39 +++-
 neural_network/src/MLPBase.cpp              | 188 ++++++++++++--------
 trainer/include/Trainer.hpp                 |   3 +
 trainer/src/TrainInfos.cpp                  |   6 +-
 trainer/src/Trainer.cpp                     |  12 ++
 transition_machine/include/Classifier.hpp   |  23 +++
 transition_machine/src/Classifier.cpp       |  31 ++++
 13 files changed, 321 insertions(+), 87 deletions(-)

diff --git a/maca_common/src/programOptionsTemplates.cpp b/maca_common/src/programOptionsTemplates.cpp
index 0e97371..8a67aaa 100644
--- a/maca_common/src/programOptionsTemplates.cpp
+++ b/maca_common/src/programOptionsTemplates.cpp
@@ -54,7 +54,7 @@ po::options_description getTrainOptionsDescription()
     ("optimizer", po::value<std::string>()->default_value("amsgrad"),
       "The learning algorithm to use : amsgrad | adam | sgd")
     ("loss", po::value<std::string>()->default_value("neglogsoftmax"),
-      "The loss function to use : neglogsoftmax | weighted")
+      "The loss function to use : neglogsoftmax")
     ("dev", po::value<std::string>()->default_value(""),
       "Development corpus formated according to the MCD")
     ("lang", po::value<std::string>()->default_value("fr"),
diff --git a/neural_network/include/GeneticAlgorithm.hpp b/neural_network/include/GeneticAlgorithm.hpp
index 83ce2d1..b5ec4c3 100644
--- a/neural_network/include/GeneticAlgorithm.hpp
+++ b/neural_network/include/GeneticAlgorithm.hpp
@@ -105,6 +105,14 @@ class GeneticAlgorithm : public NeuralNetwork
   /// @return The loss.
   float update(FeatureModel::FeatureDescription & fd, int gold) override;
 
+  /// @brief Update the parameters according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold vector for this input.
+  ///
+  /// @return The loss.
+  float update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) override;
+
   /// @brief Get the loss according to the given gold class.
   ///
   /// @param fd The input to use.
@@ -113,6 +121,14 @@ class GeneticAlgorithm : public NeuralNetwork
   /// @return The loss.
   float getLoss(FeatureModel::FeatureDescription & fd, int gold) override;
 
+  /// @brief Get the loss according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold vector for this input.
+  ///
+  /// @return The loss.
+  float getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) override;
+
   /// @brief Save the GeneticAlgorithm to a file.
   /// 
   /// @param filename The file to write the GeneticAlgorithm to.
@@ -122,6 +138,7 @@ class GeneticAlgorithm : public NeuralNetwork
   ///
   /// @param output Where the topology will be printed.
   void printTopology(FILE * output) override;
+  void endOfIteration();
 };
 
 #endif
diff --git a/neural_network/include/MLP.hpp b/neural_network/include/MLP.hpp
index 47bc80e..37d707d 100644
--- a/neural_network/include/MLP.hpp
+++ b/neural_network/include/MLP.hpp
@@ -60,7 +60,21 @@ class MLP : public NeuralNetwork
   /// @param gold The gold class of this input.
   ///
   /// @return The loss.
+  float update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) override;
+  /// @brief Get the loss according to the given gold class.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold class of this input.
+  ///
+  /// @return The loss.
   float getLoss(FeatureModel::FeatureDescription & fd, int gold) override;
+  /// @brief Get the loss according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold vector for this input.
+  ///
+  /// @return The loss.
+  float getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) override;
   /// @brief Save the MLP to a file.
   /// 
   /// @param filename The file to write the MLP to.
@@ -73,6 +87,7 @@ class MLP : public NeuralNetwork
   ///
   /// @return A pointer to the newly allocated trainer.
   dynet::Trainer * createTrainer();
+  void endOfIteration();
 };
 
 #endif
diff --git a/neural_network/include/MLPBase.hpp b/neural_network/include/MLPBase.hpp
index 06c02ad..826a7d8 100644
--- a/neural_network/include/MLPBase.hpp
+++ b/neural_network/include/MLPBase.hpp
@@ -30,10 +30,14 @@ class MLPBase
   /// @brief Must the Layer dropout rate be taken into account during the computations ? Usually it is only during the training step.
   bool dropoutActive;
 
-  /// @brief The current minibatch.
-  std::vector<FeatureModel::FeatureDescription> fds;
+  /// @brief The current minibatch for one-hot golds.
+  std::vector<FeatureModel::FeatureDescription> fdsOneHot;
   /// @brief gold classes of the current minibatch.
-  std::vector<unsigned int> golds;
+  std::vector<unsigned int> goldsOneHot;
+  /// @brief The current minibatch for continuous golds.
+  std::vector<FeatureModel::FeatureDescription> fdsContinuous;
+  /// @brief gold outputs of the current minibatch.
+  std::vector< std::vector<float> > goldsContinuous;
 
   private :
 
@@ -86,16 +90,6 @@ class MLPBase
   /// @param model The dynet model that will contain the loaded parameters.
   /// @param filename The file from which the parameters will be read.
   void loadParameters(dynet::ParameterCollection & model, const std::string & filename);
-  /// @brief Get the loss expression 
-  ///
-  /// @param output Output from the neural network
-  /// @param oneHotGolds Indexes of gold classes (batched form)
-  ///
-  /// @return The loss expression
-  dynet::Expression weightedLoss(dynet::Expression & output, std::vector<unsigned int> & oneHotGolds);
-
-  dynet::Expression errorCorrectionLoss(dynet::ComputationGraph & cg, dynet::Expression & output, std::vector<unsigned int> & oneHotGolds);
-
   /// @brief initialize a new untrained MLP from a desired topology.
   ///
   /// topology example for 2 hidden layers : (150,RELU,0.3)(50,ELU,0.2)\n
@@ -123,6 +117,13 @@ class MLPBase
   ///
   /// @return The loss.
   float update(FeatureModel::FeatureDescription & fd, int gold);
+  /// @brief Update the parameters according to the given gold output.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold output for this input.
+  ///
+  /// @return The loss.
+  float update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold);
   /// @brief Get the loss according to the given gold class.
   ///
   /// @param fd The input to use.
@@ -130,10 +131,19 @@ class MLPBase
   ///
   /// @return The loss.
   float getLoss(FeatureModel::FeatureDescription & fd, int gold);
+  /// @brief Get the loss according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold output for this input.
+  ///
+  /// @return The loss.
+  float getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold);
   /// @brief Print the topology (Layers) of the MLP.
   ///
   /// @param output Where the topology will be printed.
   void printTopology(FILE * output);
+  /// @brief Clear the current batch.
+  void endOfIteration();
 };
 
 #endif
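MLPBase now keeps two independent minibatches, one for class-index golds and
one for continuous golds; each example is buffered until the batch is full,
and endOfIteration() drops whatever is left so the next epoch starts from an
empty batch. A minimal sketch of this accumulate-until-full pattern, with
simplified stand-in types (the real code builds a dynet computation graph once
the batch is complete):

  #include <cstddef>
  #include <vector>

  struct BatchNotFull {};  // thrown while the minibatch is still filling up

  struct ContinuousBatch
  {
    std::size_t batchSize;
    std::vector< std::vector<float> > golds;

    // Buffer one gold vector; only a full batch produces a real loss value.
    float update(const std::vector<float> & gold)
    {
      golds.emplace_back(gold);
      if (golds.size() < batchSize)
        throw BatchNotFull();
      float loss = 0.0f;  // placeholder for the batched loss computed here
      golds.clear();
      return loss;
    }

    void endOfIteration() { golds.clear(); }
  };

Callers such as MLP::update catch BatchNotFull and report a loss of 0.0 for
the buffered example, as in the MLP.cpp hunks below.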
diff --git a/neural_network/include/NeuralNetwork.hpp b/neural_network/include/NeuralNetwork.hpp
index 4b88540..7473548 100644
--- a/neural_network/include/NeuralNetwork.hpp
+++ b/neural_network/include/NeuralNetwork.hpp
@@ -145,6 +145,14 @@ class NeuralNetwork
   /// @return The loss.
   virtual float update(FeatureModel::FeatureDescription & fd, int gold) = 0;
 
+  /// @brief Update the parameters according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold vector for this input.
+  ///
+  /// @return The loss.
+  virtual float update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) = 0;
+
   /// @brief Get the loss according to the given gold class.
   ///
   /// @param fd The input to use.
@@ -153,6 +161,14 @@ class NeuralNetwork
   /// @return The loss.
   virtual float getLoss(FeatureModel::FeatureDescription & fd, int gold) = 0;
 
+  /// @brief Get the loss according to the given gold vector.
+  ///
+  /// @param fd The input to use.
+  /// @param gold The gold vector for this input.
+  ///
+  /// @return The loss.
+  virtual float getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold) = 0;
+
   /// @brief Save the NeuralNetwork to a file.
   /// 
   /// @param filename The file to write the NeuralNetwork to.
@@ -163,6 +179,8 @@ class NeuralNetwork
   /// @param output Where the topology will be printed.
   virtual void printTopology(FILE * output) = 0;
 
+  virtual void endOfIteration() = 0;
+
   /// @brief Return the model.
   ///
   /// @return The model of this NeuralNetwork.
diff --git a/neural_network/src/GeneticAlgorithm.cpp b/neural_network/src/GeneticAlgorithm.cpp
index 6d206e9..c57577c 100644
--- a/neural_network/src/GeneticAlgorithm.cpp
+++ b/neural_network/src/GeneticAlgorithm.cpp
@@ -71,6 +71,18 @@ float GeneticAlgorithm::getLoss(FeatureModel::FeatureDescription &, int)
   return loss;
 }
 
+float GeneticAlgorithm::getLoss(FeatureModel::FeatureDescription &, const std::vector<float> &)
+{
+  fprintf(stderr, "ERROR (%s) : not implemented. Aborting.\n", ERRINFO);
+  exit(1);
+}
+
+float GeneticAlgorithm::update(FeatureModel::FeatureDescription & , const std::vector<float> & )
+{
+  fprintf(stderr, "ERROR (%s) : not implemented. Aborting.\n", ERRINFO);
+  exit(1);
+}
+
 float GeneticAlgorithm::update(FeatureModel::FeatureDescription & fd, int gold)
 {
   bool haveBeenUpdated = false;
@@ -300,3 +312,9 @@ void GeneticAlgorithm::Individual::mutate(float probability)
     }
 }
 
+void GeneticAlgorithm::endOfIteration()
+{
+  for (auto & it : generation)
+    it->mlp.endOfIteration();
+}
+
diff --git a/neural_network/src/MLP.cpp b/neural_network/src/MLP.cpp
index 15b5742..1b65be2 100644
--- a/neural_network/src/MLP.cpp
+++ b/neural_network/src/MLP.cpp
@@ -68,9 +68,41 @@ float MLP::update(FeatureModel::FeatureDescription & fd, int gold)
   }
 }
 
+float MLP::update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold)
+{
+  try
+  {
+    float loss = mlp.update(fd, gold);
+    trainer->update();
+    return loss;
+  } catch (BatchNotFull &)
+  {
+    return 0.0;
+  }
+}
+
 float MLP::getLoss(FeatureModel::FeatureDescription & fd, int gold)
 {
-  return mlp.getLoss(fd, gold);
+  try
+  {
+    float loss = mlp.getLoss(fd, gold);
+    return loss;
+  } catch (BatchNotFull &)
+  {
+    return 0.0;
+  }
+}
+
+float MLP::getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold)
+{
+  try
+  {
+    float loss = mlp.getLoss(fd, gold);
+    return loss;
+  } catch (BatchNotFull &)
+  {
+    return 0.0;
+  }
 }
 
 void MLP::save(const std::string & filename)
@@ -86,3 +118,8 @@ void MLP::printTopology(FILE * output)
   mlp.printTopology(output);
 }
 
+void MLP::endOfIteration()
+{
+  mlp.endOfIteration();
+}
+
diff --git a/neural_network/src/MLPBase.cpp b/neural_network/src/MLPBase.cpp
index b88147c..00ce27c 100644
--- a/neural_network/src/MLPBase.cpp
+++ b/neural_network/src/MLPBase.cpp
@@ -91,16 +91,16 @@ std::vector<float> MLPBase::predict(FeatureModel::FeatureDescription & fd)
 
 float MLPBase::update(FeatureModel::FeatureDescription & fd, int gold)
 {
-  fds.emplace_back(fd);
-  golds.emplace_back(gold);
+  fdsOneHot.emplace_back(fd);
+  goldsOneHot.emplace_back(gold);
 
-  if ((int)fds.size() < ProgramParameters::batchSize)
+  if ((int)fdsOneHot.size() < ProgramParameters::batchSize)
     throw BatchNotFull();
 
   std::vector<dynet::Expression> inputs;
   dynet::ComputationGraph cg;
 
-  for (auto & example : fds)
+  for (auto & example : fdsOneHot)
   {
     std::vector<dynet::Expression> expressions;
 
@@ -117,15 +117,7 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, int gold)
  
   if (ProgramParameters::loss == "neglogsoftmax")
   {
-    batchedLoss = dynet::sum_batches(pickneglogsoftmax(output, golds));
-  }
-  else if (ProgramParameters::loss == "weighted")
-  {
-    batchedLoss = weightedLoss(output, golds);
-  }
-  else if (ProgramParameters::loss == "errorCorrection")
-  {
-    batchedLoss = errorCorrectionLoss(cg, output, golds);
+    batchedLoss = dynet::sum_batches(pickneglogsoftmax(output, goldsOneHot));
   }
   else
   {
@@ -137,42 +129,83 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, int gold)
 
   checkGradients();
 
-  fds.clear();
-  golds.clear();
+  fdsOneHot.clear();
+  goldsOneHot.clear();
 
   return as_scalar(batchedLoss.value());
 }
 
-float MLPBase::getLoss(FeatureModel::FeatureDescription & fd, int gold)
+float MLPBase::update(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold)
 {
+  fdsContinuous.emplace_back(fd);
+  goldsContinuous.emplace_back(gold);
+
+  if ((int)fdsContinuous.size() < ProgramParameters::batchSize)
+    throw BatchNotFull();
+
   std::vector<dynet::Expression> inputs;
-  std::vector<unsigned int> goldss;
-  goldss.emplace_back(gold);
   dynet::ComputationGraph cg;
 
-  std::vector<dynet::Expression> expressions;
+  for (auto & example : fdsContinuous)
+  {
+    std::vector<dynet::Expression> expressions;
 
-  for (auto & featValue : fd.values)
-    expressions.emplace_back(NeuralNetwork::featValue2Expression(cg, featValue));
+    for (auto & featValue : example.values)
+      expressions.emplace_back(NeuralNetwork::featValue2Expression(cg, featValue));
 
-  dynet::Expression input = dynet::concatenate(expressions);
-  inputs.emplace_back(input);
+    dynet::Expression input = dynet::concatenate(expressions);
+    inputs.emplace_back(input);
+  }
 
   dynet::Expression batchedInput = dynet::concatenate_to_batch(inputs);
   dynet::Expression output = run(cg, batchedInput);
   dynet::Expression batchedLoss;
+  std::vector<dynet::Expression> goldExpressions;
+  for (auto & gold : goldsContinuous)
+    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({1,(unsigned int)gold.size()}), gold));
  
-  if (ProgramParameters::loss == "neglogsoftmax")
-  {
-    batchedLoss = dynet::sum_batches(pickneglogsoftmax(output, goldss));
-  }
-  else if (ProgramParameters::loss == "weighted")
+  dynet::Expression batchedGold = dynet::concatenate_to_batch(goldExpressions);
+  batchedLoss = dynet::sum_batches(dynet::squared_distance(output, batchedGold));
+
+  cg.backward(batchedLoss);
+
+  checkGradients();
+
+  fdsContinuous.clear();
+  goldsContinuous.clear();
+
+  return as_scalar(batchedLoss.value());
+}
+
+float MLPBase::getLoss(FeatureModel::FeatureDescription & fd, int gold)
+{
+  fdsOneHot.emplace_back(fd);
+  goldsOneHot.emplace_back(gold);
+
+  if ((int)fdsOneHot.size() < ProgramParameters::batchSize)
+    throw BatchNotFull();
+
+  std::vector<dynet::Expression> inputs;
+  dynet::ComputationGraph cg;
+
+  for (auto & example : fdsOneHot)
   {
-    batchedLoss = weightedLoss(output, goldss);
+    std::vector<dynet::Expression> expressions;
+
+    for (auto & featValue : example.values)
+      expressions.emplace_back(NeuralNetwork::featValue2Expression(cg, featValue));
+
+    dynet::Expression input = dynet::concatenate(expressions);
+    inputs.emplace_back(input);
   }
-  else if (ProgramParameters::loss == "errorCorrection")
+
+  dynet::Expression batchedInput = dynet::concatenate_to_batch(inputs);
+  dynet::Expression output = run(cg, batchedInput);
+  dynet::Expression batchedLoss;
+ 
+  if (ProgramParameters::loss == "neglogsoftmax")
   {
-    batchedLoss = errorCorrectionLoss(cg, output, goldss);
+    batchedLoss = dynet::sum_batches(pickneglogsoftmax(output, goldsOneHot));
   }
   else
   {
@@ -180,6 +213,51 @@ float MLPBase::getLoss(FeatureModel::FeatureDescription & fd, int gold)
     exit(1);
   }
 
+  checkGradients();
+
+  fdsOneHot.clear();
+  goldsOneHot.clear();
+
+  return as_scalar(batchedLoss.value());
+}
+
+float MLPBase::getLoss(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold)
+{
+  fdsContinuous.emplace_back(fd);
+  goldsContinuous.emplace_back(gold);
+
+  if ((int)fdsContinuous.size() < ProgramParameters::batchSize)
+    throw BatchNotFull();
+
+  std::vector<dynet::Expression> inputs;
+  dynet::ComputationGraph cg;
+
+  for (auto & example : fdsContinuous)
+  {
+    std::vector<dynet::Expression> expressions;
+
+    for (auto & featValue : example.values)
+      expressions.emplace_back(NeuralNetwork::featValue2Expression(cg, featValue));
+
+    dynet::Expression input = dynet::concatenate(expressions);
+    inputs.emplace_back(input);
+  }
+
+  dynet::Expression batchedInput = dynet::concatenate_to_batch(inputs);
+  dynet::Expression output = run(cg, batchedInput);
+  dynet::Expression batchedLoss;
+  std::vector<dynet::Expression> goldExpressions;
+  for (auto & gold : goldsContinuous)
+    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({1,(unsigned int)gold.size()}), gold));
+ 
+  dynet::Expression batchedGold = dynet::concatenate_to_batch(goldExpressions);
+  batchedLoss = dynet::sum_batches(dynet::squared_distance(output, batchedGold));
+
+  checkGradients();
+
+  fdsContinuous.clear();
+  goldsContinuous.clear();
+
   return as_scalar(batchedLoss.value());
 }
 
@@ -206,46 +284,6 @@ void MLPBase::checkGradients()
   }
 }
 
-dynet::Expression MLPBase::weightedLoss(dynet::Expression & output, std::vector<unsigned int> & oneHotGolds)
-{
-  std::vector<dynet::Expression> lossExpr;
-  for (unsigned int i = 0; i < output.dim().batch_elems(); i++)
-  {
-    lossExpr.emplace_back(dynet::pickneglogsoftmax(dynet::pick_batch_elem(output, i), oneHotGolds[i]));
-    auto outputVect = dynet::as_vector(dynet::pick_batch_elem(output,i).value());
-    int prediction = 0;
-    for (unsigned int j = 1; j < outputVect.size(); j++)
-      if(outputVect[j] > outputVect[prediction])
-        prediction = j;
-    int gold = oneHotGolds[i];
-    if (prediction == 1 && gold == 0)
-    {
-      lossExpr.back() = lossExpr.back() * 100.0;
-    }
-  }
-
-  return dynet::sum(lossExpr);
-}
-
-dynet::Expression MLPBase::errorCorrectionLoss(dynet::ComputationGraph & cg, dynet::Expression & output, std::vector<unsigned int> & oneHotGolds)
-{
-  std::vector<dynet::Expression> lossExpr;
-  for (unsigned int i = 0; i < output.dim().batch_elems(); i++)
-  {
-    unsigned int u = 0;
-    dynet::Expression c = dynet::pick(dynet::one_hot(cg, layers.back().output_dim, oneHotGolds[i]),u);
-    dynet::Expression a = dynet::pick(dynet::softmax(dynet::pick_batch_elem(output,i)),u);
-    lossExpr.emplace_back(dynet::pickneglogsoftmax(dynet::pick_batch_elem(output, i),oneHotGolds[i])+2-c-a*c+(dynet::acos(a-1)*(c-1)));
-    if (ProgramParameters::debug)
-    {
-      cg.forward(lossExpr.back());
-      fprintf(stderr, "a=%.2f c=%.2f loss=%.2f\n", dynet::as_scalar(a.value()),dynet::as_scalar(c.value()),dynet::as_scalar(lossExpr.back().value()));
-    }
-  }
-
-  return dynet::sum(lossExpr);
-}
-
 dynet::Expression MLPBase::run(dynet::ComputationGraph & cg, dynet::Expression x)
 {
   static std::vector< std::pair<std::string,dynet::Expression> > exprForDebug;
@@ -436,3 +474,11 @@ void MLPBase::printTopology(FILE * output)
   fprintf(output, ")\n");
 }
 
+void MLPBase::endOfIteration()
+{
+  fdsOneHot.clear();
+  fdsContinuous.clear();
+  goldsOneHot.clear();
+  goldsContinuous.clear();
+}
+
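For continuous golds the batched loss is the sum, over the minibatch, of the
squared L2 distance between the network output and the gold vector. A
self-contained sketch (plain C++, no dynet) of the value that
dynet::sum_batches(dynet::squared_distance(output, batchedGold)) reduces to:

  #include <cstddef>
  #include <vector>

  // Sum over the batch of ||output_i - gold_i||^2.
  float batchedSquaredDistance(const std::vector< std::vector<float> > & outputs,
                               const std::vector< std::vector<float> > & golds)
  {
    float total = 0.0f;
    for (std::size_t i = 0; i < outputs.size(); ++i)
      for (std::size_t j = 0; j < outputs[i].size(); ++j)
      {
        float diff = outputs[i][j] - golds[i][j];
        total += diff * diff;
      }
    return total;
  }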
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 1f1171c..9cacbe4 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -113,6 +113,9 @@ class Trainer
   /// on the current epoch is its all time best.\n
   /// When a Classifier is saved that way, all the Dict involved are also saved.
   void train();
+
+  /// @brief Prepare Classifiers for next iteration.
+  void endOfIteration();
 };
 
 #endif
diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp
index c1f8137..daaa46b 100644
--- a/trainer/src/TrainInfos.cpp
+++ b/trainer/src/TrainInfos.cpp
@@ -130,7 +130,7 @@ void TrainInfos::addTrainLoss(const std::string & classifier, float loss)
 
 void TrainInfos::addDevLoss(const std::string & classifier, float loss)
 {
-  devLossesPerClassifierPerEpoch[classifier].emplace_back(loss);
+  devLossesPerClassifierPerEpoch[classifier].back() += loss;
 }
 
 void TrainInfos::addTrainScore(const std::string & classifier, float score)
@@ -222,7 +222,11 @@ void TrainInfos::nextEpoch()
   lastEpoch++;
   saveToFilename();
   for (auto & it : topologyPrinted)
+  {
     trainLossesPerClassifierPerEpoch[it.first].emplace_back(0.0);
+    if (ProgramParameters::devLoss)
+      devLossesPerClassifierPerEpoch[it.first].emplace_back(0.0);
+  }
 }
 
 void TrainInfos::computeMustSaves()
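addDevLoss now accumulates into the current epoch's slot instead of appending
a new entry on every call, which is why nextEpoch() creates a fresh 0.0 slot
for the dev loss when ProgramParameters::devLoss is set. A tiny sketch of that
per-epoch accumulation pattern (simplified, assuming a slot exists before the
first loss of an epoch is added):

  #include <vector>

  struct EpochLosses
  {
    std::vector<float> perEpoch;

    void nextEpoch()      { perEpoch.emplace_back(0.0f); }  // new epoch slot
    void addLoss(float l) { perEpoch.back() += l; }         // accumulate
  };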
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 76bf1bd..61032a3 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -461,12 +461,24 @@ void Trainer::doStepTrain()
   trainConfig.addToEntropyHistory(entropy);
 }
 
+void Trainer::endOfIteration()
+{
+  auto classifiers = tm.getClassifiers();
+  for (auto * cla : classifiers)
+    if (cla->needsTrain())
+      cla->endOfIteration();
+}
+
 void Trainer::prepareNextEpoch()
 {
+  endOfIteration();
+
   printScoresAndSave(stderr);
   nbSteps = 0;
   TI.nextEpoch();
 
+  endOfIteration();
+
   if (TI.getEpoch() > ProgramParameters::nbIter)
     throw EndOfTraining();
 }
diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp
index a432fef..001869a 100644
--- a/transition_machine/include/Classifier.hpp
+++ b/transition_machine/include/Classifier.hpp
@@ -158,11 +158,25 @@ class Classifier
   float trainOnExample(Config & config, int gold);
   /// @brief Train the classifier on a training example.
   ///
+  /// @param config The Config to work with.
+  /// @param gold The gold vector for this Config.
+  ///
+  /// @return The loss.
+  float trainOnExample(Config & config, const std::vector<float> & gold);
+  /// @brief Train the classifier on a training example.
+  ///
   /// @param fd The FeatureDescription to work with.
   /// @param gold The gold class of the FeatureDescription.
   ///
   /// @return The loss.
   float trainOnExample(FeatureModel::FeatureDescription & fd, int gold);
+  /// @brief Train the classifier on a training example.
+  ///
+  /// @param fd The FeatureDescription to work with.
+  /// @param gold The gold vector for this FeatureDescription.
+  ///
+  /// @return The loss.
+  float trainOnExample(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold);
   /// @brief Get the loss of the classifier on a training example.
   ///
   /// @param config The Config to work with.
@@ -170,6 +184,13 @@ class Classifier
   ///
   /// @return The loss.
   float getLoss(Config & config, int gold);
+  /// @brief Get the loss of the classifier on a training example.
+  ///
+  /// @param config The Config to work with.
+  /// @param gold The gold vector for this Config.
+  ///
+  /// @return The loss.
+  float getLoss(Config & config, const std::vector<float> & gold);
   /// @brief Get the name of an Action from its index.
   ///
   /// The index of an Action can be seen as the index of the corresponding output neuron in the underlying neural network.
@@ -213,6 +234,8 @@ class Classifier
   unsigned int getNbActions();
   /// @brief Get a pointer to the FeatureModel.
   FeatureModel * getFeatureModel();
+  /// @brief Prepare Classifier for next iteration.
+  void endOfIteration();
 };
 
 #endif
diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp
index 89f9a93..d28ed2c 100644
--- a/transition_machine/src/Classifier.cpp
+++ b/transition_machine/src/Classifier.cpp
@@ -286,6 +286,15 @@ float Classifier::trainOnExample(Config & config, int gold)
   return nn->update(fd, gold);
 }
 
+float Classifier::trainOnExample(Config & config, const std::vector<float> & gold)
+{
+  if (ProgramParameters::noNeuralNetwork)
+    return 0.0;
+
+  auto & fd = fm->getFeatureDescription(config);
+  return nn->update(fd, gold);
+}
+
 float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold)
 {
   if (ProgramParameters::noNeuralNetwork)
@@ -294,6 +303,14 @@ float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold
   return nn->update(fd, gold);
 }
 
+float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, const std::vector<float> & gold)
+{
+  if (ProgramParameters::noNeuralNetwork)
+    return 0.0;
+
+  return nn->update(fd, gold);
+}
+
 float Classifier::getLoss(Config & config, int gold)
 {
   if (ProgramParameters::noNeuralNetwork)
@@ -303,6 +320,15 @@ float Classifier::getLoss(Config & config, int gold)
   return nn->getLoss(fd, gold);
 }
 
+float Classifier::getLoss(Config & config, const std::vector<float> & gold)
+{
+  if (ProgramParameters::noNeuralNetwork)
+    return 0.0;
+
+  auto & fd = fm->getFeatureDescription(config);
+  return nn->getLoss(fd, gold);
+}
+
 void Classifier::explainCostOfActions(FILE * output, Config & config)
 {
   for (Action & a : as->actions)
@@ -362,3 +388,8 @@ FeatureModel * Classifier::getFeatureModel()
   return fm.get();
 }
 
+void Classifier::endOfIteration()
+{
+  nn->endOfIteration();
+}
+
-- 
GitLab