From 048c959c91b3ca37086a50ff3361b570a747f68f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 22 Jun 2020 22:58:31 +0200
Subject: [PATCH] Added custom hinge loss and added explorationThreshold
 program argument

---
 torch_modules/include/CustomHingeLoss.hpp | 13 ++++++++
 torch_modules/src/CustomHingeLoss.cpp     | 17 +++++++++++
 trainer/include/Trainer.hpp               |  7 +++--
 trainer/src/MacaonTrain.cpp               |  9 ++++--
 trainer/src/Trainer.cpp                   | 36 ++++++++++++++++-------
 5 files changed, 65 insertions(+), 17 deletions(-)
 create mode 100644 torch_modules/include/CustomHingeLoss.hpp
 create mode 100644 torch_modules/src/CustomHingeLoss.cpp

diff --git a/torch_modules/include/CustomHingeLoss.hpp b/torch_modules/include/CustomHingeLoss.hpp
new file mode 100644
index 0000000..e30e2b8
--- /dev/null
+++ b/torch_modules/include/CustomHingeLoss.hpp
@@ -0,0 +1,13 @@
+#ifndef CUSTOMHINGELOSS__H
+#define CUSTOMHINGELOSS__H
+
+#include <torch/torch.h>
+
+class CustomHingeLoss
+{
+  public :
+  // Hinge loss of prediction against gold, averaged over prediction.size(0); returned as a one-element tensor.
+  torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold);
+};
+
+#endif
diff --git a/torch_modules/src/CustomHingeLoss.cpp b/torch_modules/src/CustomHingeLoss.cpp
new file mode 100644
index 0000000..bc3fc85
--- /dev/null
+++ b/torch_modules/src/CustomHingeLoss.cpp
@@ -0,0 +1,17 @@
+#include "CustomHingeLoss.hpp"
+
+// Mean over rows i of relu(1 - max(gold[i]*prediction[i]) + max((1-gold[i])*prediction[i])).
+torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
+{
+  // Allocate the accumulator with prediction's device and dtype: a default
+  // (CPU) tensor cannot be updated in place with operands living on another device.
+  torch::Tensor loss = torch::zeros({1}, prediction.options());
+  const int64_t nbRows = prediction.size(0); // signed, like the value returned by size()
+
+  for (int64_t i = 0; i < nbRows; i++)
+    loss += torch::relu(1 - torch::max(gold[i]*prediction[i])
+                          + torch::max((1-gold[i])*prediction[i]));
+
+  return loss / nbRows;
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index ad14ef6..25d48f4 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -4,12 +4,13 @@
 #include "ReadingMachine.hpp"
 #include "ConfigDataset.hpp"
 #include "SubConfig.hpp"
+#include "CustomHingeLoss.hpp"
 
 class LossFunction
 {
   private :
 
-  std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct;
+  std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct;
 
   public :
 
@@ -69,13 +70,13 @@ class Trainer
 
   private :
 
-  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
+  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
 
   public :
 
   Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName);
-  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
+  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
   void makeDataLoader(std::filesystem::path dir);
   void makeDevDataLoader(std::filesystem::path dir);
   float epoch(bool printAdvancement);
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 9db5c5a..b999e2c 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -33,12 +33,14 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Number of training epochs")
     ("batchSize", po::value<int>()->default_value(64),
       "Number of examples per batch")
+    ("explorationThreshold", po::value<float>()->default_value(0.1),
+      "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
     ("machine", po::value<std::string>()->default_value(""),
       "Reading machine file content")
     ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
       "Description of what should happen during training")
     ("loss", po::value<std::string>()->default_value("CrossEntropy"),
-      "Loss function to use during training : CrossEntropy | bce | mse")
+      "Loss function to use during training : CrossEntropy | bce | mse | hinge")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -128,6 +130,7 @@ int MacaonTrain::main()
   auto machineContent = variables["machine"].as<std::string>();
   auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
   auto lossFunction = variables["loss"].as<std::string>();
+  auto explorationThreshold = variables["explorationThreshold"].as<float>();
 
   auto trainStrategy = parseTrainStrategy(trainStrategyStr);
 
@@ -211,11 +214,11 @@ int MacaonTrain::main()
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
     {
       machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
-      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
       if (!computeDevScore)
       {
         machine.setDictsState(Dict::State::Closed);
-        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
       }
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 90b65b7..74d6d13 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -9,6 +9,8 @@ LossFunction::LossFunction(std::string name)
     fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
   else if (util::lower(name) == "mse")
     fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
+  else if (util::lower(name) == "hinge")
+    fct = CustomHingeLoss();
   else
     util::myThrow(fmt::format("unknown loss function name '{}'", name));
 }
@@ -23,6 +25,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
     return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
   if (index == 2)
     return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
+  if (index == 3)
+    return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
 
   util::myThrow("loss is not defined");
   return torch::Tensor();
@@ -38,7 +42,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
     gold[0] = goldIndexes.at(0);
     return gold;
   }
-  if (index == 1 or index == 2)
+  if (index == 1 or index == 2 or index == 3)
   {
     auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
     for (auto goldIndex : goldIndexes)
@@ -66,18 +70,18 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir)
   devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 }
 
-void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
+void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
 
-  extractExamples(config, debug, dir, epoch, dynamicOracle);
+  extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold);
 
   machine.saveDicts();
 }
 
-void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
+void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
 {
   torch::AutoGradMode useGrad(false);
 
@@ -129,22 +133,32 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
     {
       auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
-      auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
+      auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
   
-      int chosenTransition = -1;
       float bestScore = std::numeric_limits<float>::min();
+      std::vector<int> candidates;
 
       for (unsigned int i = 0; i < prediction.size(0); i++)
       {
         float score = prediction[i].item<float>();
-        if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
-        {
-          chosenTransition = i;
+        if (score > bestScore and appliableTransitions[i])
           bestScore = score;
-        }
       }
 
-      transition = machine.getTransitionSet().getTransition(chosenTransition);
+      for (unsigned int i = 0; i < prediction.size(0); i++)
+      {
+        float score = prediction[i].item<float>();
+        if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
+          candidates.emplace_back(i);
+      }
+
+      if (candidates.size() != 1)
+      {
+        fmt::print(stderr, "nbCand = {}\n", candidates.size());
+        std::cerr << prediction << std::endl;
+      }
+
+      transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
     }
     else
     {
-- 
GitLab