From 9811325d5b17f8e36f48a7daaf62481af0c1a8f7 Mon Sep 17 00:00:00 2001
From: Hartbook <franck.dary@etu.univ-amu.fr>
Date: Sun, 1 Jul 2018 22:36:28 +0200
Subject: [PATCH] First version of a working training loop

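MLP::predict() now learns when trainMode is set: it backpropagates the
pickneglogsoftmax loss of the gold class and calls trainer.update().
MLP gains a getLoss() helper and a printParameters() debugging aid.
Classifier::weightActions() now sorts its result by descending score,
so the best-scoring action comes first. TapeMachine::name is made
public so that Trainer can report it, and Trainer::train() runs 20
iterations, printing per-classifier accuracy after each one.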
---
 MLP/include/MLP.hpp                  |  2 ++
 MLP/src/MLP.cpp                      | 50 ++++++++++++++++++----------
 tape_machine/include/TapeMachine.hpp |  3 ++-
 tape_machine/src/Classifier.cpp      |  6 +++-
 trainer/src/Trainer.cpp              | 28 ++++++++++++++--
 5 files changed, 67 insertions(+), 22 deletions(-)

diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp
index bb42ccc..1c205fc 100644
--- a/MLP/include/MLP.hpp
+++ b/MLP/include/MLP.hpp
@@ -57,6 +57,8 @@ class MLP
   dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
   dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
   inline dynet::Expression activate(dynet::Expression h, Activation f);
+  dynet::Expression getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label);
+  void printParameters(FILE * output);
 
   public :
 
diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp
index 7f2a84f..e6bef2a 100644
--- a/MLP/src/MLP.cpp
+++ b/MLP/src/MLP.cpp
@@ -120,25 +120,16 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
   }
 
   dynet::Expression input = dynet::concatenate(expressions);
-  dynet::Expression output = run(cg, input);
-
-  /*
-  int nbInputs = layers[0].input_dim;
-
-  dynet::Expression x = reshape(concatenate_cols(cur_batch),
-                              dynet::Dim({nb_inputs}, cur_batch_size));
 
-  dynet::Expression loss_expr = get_loss(x_batch, cur_labels);
-
-  loss += as_scalar(computation_graph.forward(loss_expr));
-  nb_samples += cur_batch_size;
-  computation_graph.backward(loss_expr);
-  trainer.update();
-  */
+  dynet::Expression output = run(cg, input);
 
-  std::vector<float> res = as_vector(cg.forward(output));
+  if(trainMode)
+  {
+    cg.backward(pickneglogsoftmax(output, goldClass));
+    trainer.update();
+  }
 
-  return res;
+  return as_vector(cg.forward(output));
 }
 
 dynet::DynetParams & MLP::getDefaultParams()
@@ -158,9 +149,11 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & f
   if(it != ptr2parameter.end())
     return it->second;
 
-  ptr2parameter[fv.vec] = model.add_parameters({1,fv.vec->size()}, dynet::ParameterInitFromVector(*fv.vec));
+  //ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
+  ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1});
   it = ptr2parameter.find(fv.vec);
-  it->second.set_updated(fv.policy == FeatureModel::Policy::Final ? false : true);
+
+//  it->second.values()->v = fv.vec->data();
 
   return it->second;
 }
@@ -230,3 +223,24 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
   return h;
 }
 
+dynet::Expression MLP::getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label)
+{
+  dynet::Expression y = run(cg, x);
+  return pickneglogsoftmax(y, label);
+}
+
+void MLP::printParameters(FILE * output)
+{
+  for(auto & it : ptr2parameter)
+  {
+    auto & param = it.second;
+    dynet::Tensor * tensor = param.values();
+    float * value = tensor->v;
+    int dim = tensor->d.size();
+    fprintf(output, "Param : ");
+    for(int i = 0; i < dim; i++)
+      fprintf(output, "%.2f ", value[i]);
+    fprintf(output, "\n");
+  }
+}
+
diff --git a/tape_machine/include/TapeMachine.hpp b/tape_machine/include/TapeMachine.hpp
index 8e00b07..ac562f7 100644
--- a/tape_machine/include/TapeMachine.hpp
+++ b/tape_machine/include/TapeMachine.hpp
@@ -29,13 +29,14 @@ class TapeMachine
 
   private :
 
-  std::string name;
   std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
   std::map< std::string, std::unique_ptr<State> > str2state;
   State * currentState;
 
   public :
 
+  std::string name;
+
   TapeMachine(const std::string & filename);
   State * getCurrentState();
   Transition * getTransition(const std::string & action);
diff --git a/tape_machine/src/Classifier.cpp b/tape_machine/src/Classifier.cpp
index 18dd133..f728c14 100644
--- a/tape_machine/src/Classifier.cpp
+++ b/tape_machine/src/Classifier.cpp
@@ -72,7 +72,11 @@ Classifier::WeightedActions Classifier::weightActions(Config & config, const std
   for (unsigned int i = 0; i < scores.size(); i++)
     result.emplace_back(scores[i], as->actions[i].name);
 
-  std::sort(result.begin(), result.end());
+  std::sort(result.begin(), result.end(),
+    [](const std::pair<float, std::string> & a, const std::pair<float, std::string> & b)
+    {
+      return a.first > b.first;
+    });
 
   return result;
 }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a85780d..3f181df 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,8 +7,14 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
 
 void Trainer::train()
 {
-  for (int i = 0; i < 2; i++)
+  int nbIter = 20;
+
+  fprintf(stderr, "Training of '%s':\n", tm.name.c_str());
+
+  for (int i = 0; i < nbIter; i++)
   {
+    std::map< std::string, std::pair<int, int> > nbExamples;
+
     while (!config.isFinal())
     {
       TapeMachine::State * currentState = tm.getCurrentState();
@@ -20,7 +26,12 @@ void Trainer::train()
 
       std::string neededActionName = classifier->oracle->getAction(config);
       auto weightedActions = classifier->weightActions(config, neededActionName);
-      printWeightedActions(stderr, weightedActions);
+      //printWeightedActions(stderr, weightedActions);
+      std::string & predictedAction = weightedActions[0].second;
+
+      nbExamples[classifier->name].first++;
+      if(predictedAction == neededActionName)
+        nbExamples[classifier->name].second++;
 
       //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
 
@@ -29,13 +40,26 @@ void Trainer::train()
       config.moveHead(transition->headMvt);
     }
 
+    fprintf(stderr, "Iteration %d/%d:\n", i+1, nbIter);
+    for(auto & it : nbExamples)
+      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
+
     config.reset();
   }
 }
 
 void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
 {
+  int nbCols = 80;
+  char symbol = '-';
+
+  for(int i = 0; i < nbCols; i++)
+    fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
+
   for (auto it : wa)
     fprintf(output, "%.2f\t%s\n", it.first, it.second.c_str());
+
+  for(int i = 0; i < nbCols; i++)
+    fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
 }
 
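getLoss() is not called anywhere yet. A minimal sketch of the intended
kind of caller, assuming a future MLP member function (trainOnExample,
features, nbInputs and goldClass are made-up names; getLoss() is
declared private, so the caller has to be a member of MLP):

  // Hypothetical MLP member, not part of this patch.
  void MLP::trainOnExample(const std::vector<float> & features,
                           unsigned int goldClass)
  {
    dynet::ComputationGraph cg;
    // layers[0].input_dim, as in the batching draft removed above
    unsigned int nbInputs = layers[0].input_dim;
    dynet::Expression x = dynet::input(cg, dynet::Dim({nbInputs}), features);
    dynet::Expression loss = getLoss(cg, x, goldClass);
    cg.forward(loss);   // evaluate the scalar loss
    cg.backward(loss);  // accumulate gradients
    trainer.update();   // apply one gradient step
  }
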
-- 
GitLab