diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index 1d1bc7543de401c86afabe1c4f9c6db4caaadb6c..936fa88b0e715a7d2b731748101ad784f8ff10ca 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -21,7 +21,7 @@ class TransitionSet
   TransitionSet(const std::vector<std::string> & filenames);
   TransitionSet(const std::string & filename);
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c, bool dynamic = false);
-  Transition * getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false);
+  std::vector<Transition *> getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false);
   std::vector<Transition *> getNAppliableTransitions(const Config & c, int n);
   std::vector<int> getAppliableTransitions(const Config & c);
   std::size_t getTransitionIndex(const Transition * transition) const;
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index 8c5c1a88712b155d21e0ee4615394c70e639c25a..8f146e7f04e39e1bc15495dbc1e35107fcd43799 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -80,28 +80,31 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
   return result;
 }
 
-Transition * TransitionSet::getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
+std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
 {
-  Transition * result = nullptr;
   int bestCost = std::numeric_limits<int>::max();
+  std::vector<Transition *> result;
+  std::vector<int> costs(transitions.size());
 
   for (unsigned int i = 0; i < transitions.size(); i++)
   {
     if (!appliableTransitions[i])
+    {
+      costs[i] = std::numeric_limits<int>::max();
       continue;
+    }
 
     int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c);
 
-    if (cost == 0)
-      return &transitions[i];
-
+    costs[i] = cost;
     if (cost < bestCost)
-    {
-      result = &transitions[i];
       bestCost = cost;
-    }
   }
 
+  for (unsigned int i = 0; i < transitions.size(); i++)
+    if (costs[i] == bestCost and bestCost != std::numeric_limits<int>::max())
+      result.emplace_back(&transitions[i]);
+
   return result;
 }
 
@@ -115,7 +118,10 @@ std::size_t TransitionSet::getTransitionIndex(const Transition * transition) con
   if (!transition)
     util::myThrow("transition is null");
 
-  return transition - &transitions[0];
+  int index = transition - &transitions[0];
+  if (index < 0 or index >= (int)transitions.size())
+    util::myThrow(fmt::format("transition index '{}' out of bounds [0;{}[", index, transitions.size()));
+  return index;
 }
 
 Transition * TransitionSet::getTransition(std::size_t index)
diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index 7ea2335f6b647d39c4df166191255e8bdd078346..4090a808d2d8464fbd16d0b90d9c7dce2bfe6cab 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -17,8 +17,9 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
     int nextIndexToGive{0};
     std::size_t size_{0};
     std::size_t nbGiven{0};
+    int nbClasses;
 
-    Holder(std::string state);
+    Holder(std::string state, int nbClasses);
     void addFile(std::string filename, int filesize);
     void reset();
     std::size_t size() const;
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index d15064e007be26faa109ffa14a841d72915f3baf..2b42eefe624646cb17a7c869d881c5289801ad25 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -10,13 +10,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
       if (stem == "extracted")
         continue;
       auto underSplit = util::split(stem, '_');
-      auto state = util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1));
+      auto stateAndNbClasses = util::split(util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)), '-');
+      auto state = stateAndNbClasses[0];
+      auto nbClasses = std::stoi(stateAndNbClasses.at(1));
       auto splited = util::split(underSplit.back(), '-');
       int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
       size_ += fileSize;
       if (!holders.count(state))
       {
-        holders.emplace(state, Holder(state));
+        holders.emplace(state, Holder(state, nbClasses));
         order.emplace_back(state);
       }
       holders.at(state).addFile(entry.path().string(), fileSize);
@@ -111,10 +113,10 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
   nbGiven += nbElementsToGive;
   auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
   nextIndexToGive += nbElementsToGive;
-  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state);
+  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-nbClasses), batch.narrow(1, batch.size(1)-nbClasses, nbClasses), state);
 }
 
-ConfigDataset::Holder::Holder(std::string state) : state(state)
+ConfigDataset::Holder::Holder(std::string state, int nbClasses) : state(state), nbClasses(nbClasses)
 {
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index d5667471e0c3e1a08fd4947a1abc3dc5433d8b2b..ad14ef658b5d151b01c3b553a33c7eaddd5eccd7 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -5,6 +5,19 @@
 #include "ConfigDataset.hpp"
 #include "SubConfig.hpp"
 
+class LossFunction
+{
+  private :
+
+  std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct;
+
+  public :
+
+  LossFunction(std::string name);
+  torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold);
+  torch::Tensor getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const;
+};
+
 class Trainer
 {
   public :
@@ -35,7 +48,7 @@ class Trainer
 
     void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
     void addContext(std::vector<std::vector<long>> & context);
-    void addClass(int goldIndex);
+    void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes);
   };
 
   private :
@@ -52,16 +65,16 @@ class Trainer
   DataLoader devDataLoader{nullptr};
   std::size_t epochNumber{0};
   int batchSize;
+  LossFunction lossFct;
 
   private :
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
-  void fillDicts(SubConfig & config, bool debug);
 
   public :
 
-  Trainer(ReadingMachine & machine, int batchSize);
+  Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName);
   void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
   void makeDataLoader(std::filesystem::path dir);
   void makeDevDataLoader(std::filesystem::path dir);
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 8e43cd34a07dd683c92153526a50e92db466e0e1..9db5c5a9d0a2d98632b663ab9eac9f67265c3b9a 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -37,6 +37,8 @@ po::options_description MacaonTrain::getOptionsDescription()
       "Reading machine file content")
     ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
       "Description of what should happen during training")
+    ("loss", po::value<std::string>()->default_value("CrossEntropy"),
+      "Loss function to use during training : crossentropy | bce | mse (case-insensitive)")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -125,6 +127,7 @@ int MacaonTrain::main()
   bool computeDevScore = variables.count("devScore") == 0 ? false : true;
   auto machineContent = variables["machine"].as<std::string>();
   auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
+  auto lossFunction = variables["loss"].as<std::string>();
 
   auto trainStrategy = parseTrainStrategy(trainStrategyStr);
 
@@ -149,7 +152,7 @@ int MacaonTrain::main()
   BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
 
-  Trainer trainer(machine, batchSize);
+  Trainer trainer(machine, batchSize, lossFunction);
   Decoder decoder(machine);
 
   if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index af1ef2e90dfc3b0e60d6a7bab116f5b46ef5f4d7..90b65b70a1f13776cae2d1a05d12bd4e203370d9 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -1,7 +1,56 @@
 #include "Trainer.hpp"
 #include "SubConfig.hpp"
 
-Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
+LossFunction::LossFunction(std::string name)
+{
+  if (util::lower(name) == "crossentropy")
+    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
+  else if (util::lower(name) == "bce")
+    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
+  else if (util::lower(name) == "mse")
+    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
+  else
+    util::myThrow(fmt::format("unknown loss function name '{}'", name));
+}
+
+torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
+{
+  auto index = fct.index();
+
+  if (index == 0)
+    return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0)));
+  if (index == 1)
+    return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
+  if (index == 2)
+    return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
+
+  util::myThrow("loss is not defined");
+  return torch::Tensor();
+}
+
+torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const
+{
+  auto index = fct.index();
+
+  if (index == 0)
+  {
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
+    gold[0] = goldIndexes.at(0);
+    return gold;
+  }
+  if (index == 1 or index == 2)
+  {
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
+    for (auto goldIndex : goldIndexes)
+      gold[goldIndex] = 1;
+    return gold;
+  }
+
+  util::myThrow("loss is not defined");
+  return torch::Tensor();
+}
+
+Trainer::Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName) : machine(machine), batchSize(batchSize), lossFct(lossFunctionName)
 {
 }
 
@@ -72,9 +121,10 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     }
 
     Transition * transition = nullptr;
-    Transition * goldTransition = nullptr;
 
-    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, true or dynamicOracle);
+    auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
+    Transition * goldTransition = goldTransitions.empty() ? nullptr : goldTransitions[std::rand()%goldTransitions.size()];
+    int nbClasses = machine.getTransitionSet().size();
       
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
     {
@@ -107,14 +157,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
       util::myThrow("No transition appliable !");
     }
 
-    int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
-
     totalNbExamples += context.size();
     if (totalNbExamples >= (int)safetyNbExamplesMax)
       util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
 
+    std::vector<int> goldIndexes;
+    for (auto & t : goldTransitions)
+      goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t));
+
     examplesPerState[config.getState()].addContext(context);
-    examplesPerState[config.getState()].addClass(goldIndex);
+    examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes);
     examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
 
     transition->apply(config);
@@ -156,8 +208,6 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
 
-  auto lossFct = torch::nn::CrossEntropyLoss();
-
   auto pastTime = std::chrono::high_resolution_clock::now();
 
   for (auto & batch : *loader)
@@ -175,26 +225,27 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (prediction.dim() == 1)
       prediction = prediction.unsqueeze(0);
 
-    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
-
     auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
+    float lossAsFloat = 0.0;
     try
     {
-      totalLoss += loss.item<float>();
-      lossSoFar += loss.item<float>();
+      lossAsFloat = loss.item<float>();
     } catch(std::exception & e) {util::myThrow(e.what());}
 
+    totalLoss += lossAsFloat;
+    lossSoFar += lossAsFloat;
+
     if (train)
     {
       loss.backward();
       machine.getClassifier()->getOptimizer().step();
     }
 
-    totalNbExamplesProcessed += torch::numel(labels);
+    totalNbExamplesProcessed += labels.size(0);
 
     if (printAdvancement)
     {
-      nbExamplesProcessed += torch::numel(labels);
+      nbExamplesProcessed += labels.size(0);
 
       if (nbExamplesProcessed >= printInterval)
       {
@@ -234,8 +285,10 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
   if (contexts.empty())
     return;
 
+  int nbClasses = classes[0].size(0);
+
   auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
-  auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
+  auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
   torch::save(tensorToSave, dir/filename);
   lastSavedIndex = currentExampleIndex;
   contexts.clear();
@@ -250,67 +303,12 @@ void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
   currentExampleIndex += context.size();
 }
 
-void Trainer::Examples::addClass(int goldIndex)
+void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes)
 {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
-    gold[0] = goldIndex;
-
-    while (classes.size() < contexts.size())
-      classes.emplace_back(gold);
-}
-
-void Trainer::fillDicts(SubConfig & config, bool debug)
-{
-  torch::AutoGradMode useGrad(false);
-
-  config.addPredicted(machine.getPredicted());
-  config.setStrategy(machine.getStrategyDefinition());
-  config.setState(config.getStrategy().getInitialState());
-  machine.getClassifier()->setState(config.getState());
+  auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);
 
-  while (true)
-  {
-    if (debug)
-      config.printForDebug(stderr);
-
-    if (machine.hasSplitWordTransitionSet())
-      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
-    config.setAppliableTransitions(appliableTransitions);
-
-    try
-    {
-      machine.getClassifier()->getNN()->extractContext(config);
-    } catch(std::exception & e)
-    {
-      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
-    }
-
-    Transition * goldTransition = nullptr;
-    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);
-      
-    if (!goldTransition)
-    {
-      config.printForDebug(stderr);
-      util::myThrow("No transition appliable !");
-    }
-
-    goldTransition->apply(config);
-    config.addToHistory(goldTransition->getName());
-
-    auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
-    if (debug)
-      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
-    if (movement == Strategy::endMovement)
-      break;
-
-    config.setState(movement.first);
-    machine.getClassifier()->setState(movement.first);
-    config.moveWordIndexRelaxed(movement.second);
-
-    if (config.needsUpdate())
-      config.update();
-  }
+  while (classes.size() < contexts.size())
+    classes.emplace_back(gold);
 }
 
 Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)