From f0df050838d46b19bace7a1b3d04039df2200e27 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 15 Apr 2021 23:55:18 +0200
Subject: [PATCH] Speed up oracle

---
 reading_machine/include/Transition.hpp    |   8 +-
 reading_machine/include/TransitionSet.hpp |   1 +
 reading_machine/src/BaseConfig.cpp        |   8 ++
 reading_machine/src/Transition.cpp        | 115 ++++++++++------------
 reading_machine/src/TransitionSet.cpp     |  60 ++++++++++-
 trainer/src/Trainer.cpp                   |   4 +
 6 files changed, 127 insertions(+), 69 deletions(-)

diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index 75f9c40..f0bb97f 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -12,8 +12,8 @@ class Transition
   std::string name;
   std::string state;
   std::vector<Action> sequence;
-  std::function<int(const Config & config)> costDynamic;
-  std::function<int(const Config & config)> costStatic;
+  std::function<int(const Config & config, const std::map<std::string, int> & links)> costDynamic;
+  std::function<int(const Config & config, const std::map<std::string, int> & links)> costStatic;
   std::function<bool(const Config & config)> precondition{[](const Config&){return true;}};
 
   private :
@@ -64,8 +64,8 @@ class Transition
   void apply(Config & config, float entropy);
   void apply(Config & config);
   bool appliable(const Config & config) const;
-  int getCostDynamic(const Config & config) const;
-  int getCostStatic(const Config & config) const;
+  int getCostDynamic(const Config & config, const std::map<std::string, int> & links) const;
+  int getCostStatic(const Config & config, const std::map<std::string, int> & links) const;
   const std::string & getName() const;
 };
 
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index 936fa88..21782b0 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -28,6 +28,7 @@ class TransitionSet
   Transition * getTransition(std::size_t index);
   Transition * getTransition(const std::string & name);
   std::size_t size() const;
+  std::map<std::string, int> computeLinks(const Config & c);
 };
 
 #endif
diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index 0d373c6..61488eb 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -112,6 +112,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
     try
     {
       std::map<std::string, int> id2index;
+      std::map<int, std::vector<std::string>> childs;
       int firstIndexOfSequence = getNbLines()-1;
       for (int i = (int)getNbLines()-1; has(0, i, 0); --i)
       {
@@ -125,6 +126,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
         id2index[getConst(idColName, i, 0)] = i;
       }
       if (hasColIndex(headColName))
+      {
         for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i)
         {
           if (!isToken(i))
@@ -133,8 +135,14 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
           if (head == "0")
             head = "-1";
           else
+          {
+            childs[id2index[head]].emplace_back(fmt::format("{}",i));
             head = std::to_string(id2index[head]);
+          }
         }
+        for (auto it : childs)
+          get(Config::childsColName, it.first, 0) = util::join("|", it.second);
+      }
 
       get(EOSColName, getNbLines()-1, 0) = EOSSymbol1;
     } catch(std::exception & e) {util::myThrow(e.what());}
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 858a044..aeee6b4 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -175,17 +175,17 @@ bool Transition::appliable(const Config & config) const
   return true;
 }
 
-int Transition::getCostDynamic(const Config & config) const
+int Transition::getCostDynamic(const Config & config, const std::map<std::string, int> & links) const
 {
-  try {return costDynamic(config);}
+  try {return costDynamic(config, links);}
   catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
 
   return 0;
 }
 
-int Transition::getCostStatic(const Config & config) const
+int Transition::getCostStatic(const Config & config, const std::map<std::string, int> & links) const
 {
-  try {return costStatic(config);}
+  try {return costStatic(config, links);}
   catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
 
   return 0;
@@ -203,7 +203,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string
 
   sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
 
-  costDynamic = [colName, objectValue, indexValue, value](const Config & config)
+  costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
 
@@ -223,7 +223,7 @@ void Transition::initWriteScore(std::string colName, std::string object, std::st
 
   sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue));
 
-  costDynamic = [](const Config &)
+  costDynamic = [](const Config &, const std::map<std::string, int> &)
   {
     return 0;
   };
@@ -238,7 +238,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
 
   sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
 
-  costDynamic = [colName, objectValue, indexValue, value](const Config & config)
+  costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
 
@@ -256,7 +256,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
 
 void Transition::initNothing()
 {
-  costDynamic = [](const Config &)
+  costDynamic = [](const Config &, const std::map<std::string, int> &)
   {
     return 0;
   };
@@ -268,7 +268,7 @@ void Transition::initIgnoreChar()
 {
   sequence.emplace_back(Action::ignoreCurrentCharacter());
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
     auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0)));
@@ -286,7 +286,7 @@ void Transition::initEndWord()
 {
   sequence.emplace_back(Action::endWord());
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
       return 0;
@@ -304,7 +304,7 @@ void Transition::initAddCharToWord(int n)
   sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0));
   sequence.emplace_back(Action::moveCharacterIndex(n));
 
-  costDynamic = [n](const Config & config)
+  costDynamic = [n](const Config & config, const std::map<std::string, int> &)
   {
     if (!config.hasCharacter(config.getCharacterIndex()+n-1))
       return std::numeric_limits<int>::max();
@@ -345,7 +345,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
   }
   sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
 
-  costDynamic = [words](const Config & config)
+  costDynamic = [words](const Config & config, const std::map<std::string, int> &)
   {
     if (!config.isMultiword(config.getWordIndex()))
       return std::numeric_limits<int>::max();
@@ -367,14 +367,14 @@ void Transition::initSplit(int index)
 {
   sequence.emplace_back(Action::split(index));
 
-  costDynamic = [index](const Config & config)
+  costDynamic = [index](const Config & config, const std::map<std::string, int> & links)
   {
     auto & transitions = config.getAppliableSplitTransitions();
 
     if (index < 0 or index >= (int)transitions.size())
       return std::numeric_limits<int>::max();
 
-    return transitions[index]->getCostDynamic(config);
+    return transitions[index]->getCostDynamic(config, links);
   };
 
   costStatic = costDynamic;
@@ -384,25 +384,22 @@ void Transition::initEagerShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> & links)
   {
     if (!config.isToken(config.getWordIndex()))
       return 0;
 
-    return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
+    return links.at("BufferStack");
   };
 
-  costStatic = [](const Config &)
-  {
-    return 0;
-  };
+  costStatic = costDynamic;
 }
 
 void Transition::initGoldEagerShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     if (!config.isToken(config.getWordIndex()))
       return 0;
@@ -410,7 +407,7 @@ void Transition::initGoldEagerShift()
     return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
   };
 
-  costStatic = [](const Config &)
+  costStatic = [](const Config &, const std::map<std::string, int> &)
   {
     return 0;
   };
@@ -428,7 +425,7 @@ void Transition::initStandardShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [](const Config &)
+  costDynamic = [](const Config &, const std::map<std::string, int> &)
   {
     return 0;
   };
@@ -442,23 +439,23 @@ void Transition::initEagerLeft_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
   {
     auto depIndex = config.getStack(0);
     auto govIndex = config.getWordIndex();
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
-    int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
+    int cost = 0;
 
     if (label != config.getConst(Config::deprelColName, depIndex, 0))
       ++cost;
     if (depGovIndex != std::to_string(govIndex))
-      ++cost;
+      cost += links.at("StackRight");
 
     return cost;
   };
 
-  costStatic = [label](const Config & config)
+  costStatic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(0);
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
@@ -477,7 +474,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(0);
     auto govIndex = config.getWordIndex();
@@ -490,7 +487,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
     return cost;
   };
 
-  costStatic = [label](const Config & config)
+  costStatic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(0);
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
@@ -522,7 +519,7 @@ void Transition::initStandardLeft_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
   sequence.emplace_back(Action::popStack(1));
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(1);
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
@@ -548,7 +545,7 @@ void Transition::initEagerLeft()
   sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(0);
     auto govIndex = config.getWordIndex();
@@ -562,7 +559,7 @@ void Transition::initEagerLeft()
     return cost;
   };
 
-  costStatic = [](const Config & config)
+  costStatic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getStack(0);
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
@@ -581,24 +578,23 @@ void Transition::initEagerRight_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
   {
     auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
-    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
-    cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
+    int cost = 0;
 
     if (label != config.getConst(Config::deprelColName, depIndex, 0))
       ++cost;
-    if (depGovIndex == std::to_string(govIndex))
-      ++cost;
+    if (depGovIndex != std::to_string(govIndex))
+      cost += links.at("BufferStack") + links.at("BufferRightHead");
 
     return cost;
   };
 
-  costStatic = [label](const Config & config)
+  costStatic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
@@ -617,7 +613,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getWordIndex();
 
@@ -630,7 +626,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
     return cost;
   };
 
-  costStatic = [label](const Config & config)
+  costStatic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
@@ -662,7 +658,7 @@ void Transition::initStandardRight_rel(std::string label)
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> &)
   {
     auto govIndex = config.getStack(1);
     auto depIndex = config.getStack(0);
@@ -688,7 +684,7 @@ void Transition::initEagerRight()
   sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
   sequence.emplace_back(Action::pushWordIndexOnStack());
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto depIndex = config.getWordIndex();
     auto govIndex = config.getStack(0);
@@ -703,7 +699,7 @@ void Transition::initEagerRight()
     return cost;
   };
 
-  costStatic = [](const Config & config)
+  costStatic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
@@ -722,17 +718,12 @@ void Transition::initReduce_strict()
   sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> & links)
   {
-    auto stackIndex = config.getStack(0);
-    auto wordIndex = config.getWordIndex();
-
-    if (!config.isToken(stackIndex))
+    if (!config.isToken(config.getStack(0)))
       return 0;
 
-    int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
-
-    return cost;
+    return links.at("StackRight");
   };
 
   costStatic = costDynamic;
@@ -743,7 +734,7 @@ void Transition::initGoldReduce_strict()
   sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto stackIndex = config.getStack(0);
     auto wordIndex = config.getWordIndex();
@@ -776,7 +767,7 @@ void Transition::initReduce_relaxed()
 {
   sequence.emplace_back(Action::popStack(0));
 
-  costDynamic = [](const Config & config)
+  costDynamic = [](const Config & config, const std::map<std::string, int> &)
   {
     auto stackIndex = config.getStack(0);
     auto wordIndex = config.getWordIndex();
@@ -799,7 +790,7 @@ void Transition::initEOS(int bufferIndex)
   sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
   sequence.emplace_back(Action::emptyStack());
 
-  costDynamic = [bufferIndex](const Config & config)
+  costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
     if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
@@ -813,7 +804,7 @@ void Transition::initEOS(int bufferIndex)
 
 void Transition::initNotEOS(int bufferIndex)
 {
-  costDynamic = [bufferIndex](const Config & config)
+  costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
     if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
@@ -829,7 +820,7 @@ void Transition::initDeprel(std::string label)
 {
   sequence.emplace_back(Action::deprel(label));
 
-  costDynamic = [label](const Config & config)
+  costDynamic = [label](const Config & config, const std::map<std::string, int> &)
   {
     return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
   };
@@ -857,7 +848,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s
   toAddUtf8 = util::splitAsUtf8(toAdd);
   sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8));
 
-  costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config)
+  costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config, const std::map<std::string, int> &)
   {
     int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue);
     int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue);
@@ -883,7 +874,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind
 
   sequence.emplace_back(Action::uppercase(col, objectValue, indexValue));
 
-  costDynamic = [col, objectValue, indexValue](const Config & config)
+  costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
     auto & expectedValue = config.getConst(col, lineIndex, 0);
@@ -908,7 +899,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
 
   sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue));
 
-  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
+  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
     auto & expectedValue = config.getConst(col, lineIndex, 0);
@@ -932,7 +923,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index
   auto objectValue = Config::str2object(obj);
   int indexValue = std::stoi(index);
 
-  costDynamic = [col, objectValue, indexValue](const Config & config)
+  costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
     auto & expectedValue = config.getConst(col, lineIndex, 0);
@@ -953,7 +944,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind
 
   sequence.emplace_back(Action::lowercase(col, objectValue, indexValue));
 
-  costDynamic = [col, objectValue, indexValue](const Config & config)
+  costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
     auto & expectedValue = config.getConst(col, lineIndex, 0);
@@ -978,7 +969,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin
 
   sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue));
 
-  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
+  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
   {
     int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
     auto & expectedValue = config.getConst(col, lineIndex, 0);
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index 8f146e7..63590ba 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -40,9 +40,11 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC
   using Pair = std::pair<Transition*, int>;
   std::vector<Pair> appliableTransitions;
 
+  auto links = computeLinks(c);
+
   for (unsigned int i = 0; i < transitions.size(); i++)
     if (transitions[i].appliable(c))
-      appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c));
+      appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links));
 
   std::sort(appliableTransitions.begin(), appliableTransitions.end(), 
   [](const Pair & a, const Pair & b)
@@ -80,12 +82,64 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
   return result;
 }
 
+std::map<std::string, int> TransitionSet::computeLinks(const Config & c)
+{
+  std::map<std::string, int> links{{"StackRight", 0}, {"BufferRight", 0}, {"BufferRightHead", 0}, {"BufferStack", 0}};
+
+  if (c.has(Config::headColName,0,0))
+  {
+    int nbLinksStackRight = 0;
+    int nbLinksBufferRight = 0;
+    int nbLinksBufferRightHead = 0;
+    int nbLinksBufferStack = 0;
+    if (c.hasStack(0))
+    {
+      if ((std::size_t)std::stoi(c.getConst(Config::headColName, c.getStack(0), 0)) >= c.getWordIndex())
+        nbLinksStackRight++;
+      auto childs = util::split(c.getConst(Config::childsColName, c.getStack(0), 0), '|');
+      for (auto & child : childs)
+      {
+        if ((std::size_t)std::stoi(child) >= c.getWordIndex())
+          nbLinksStackRight++;
+      }
+    }
+
+    auto head = c.getConst(Config::headColName, c.getWordIndex(), 0);
+    if (head != "_" and (std::size_t)std::stoi(c.getConst(Config::headColName, c.getWordIndex(), 0)) > c.getWordIndex())
+    {
+      nbLinksBufferRight++;
+      nbLinksBufferRightHead++;
+    }
+    auto childs = util::split(c.getConst(Config::childsColName, c.getWordIndex(), 0), '|');
+    for (auto & child : childs)
+      if ((std::size_t)std::stoi(child) > c.getWordIndex())
+        nbLinksBufferRight++;
+    auto bufferHead = c.getConst(Config::headColName, c.getWordIndex(), 0);
+    for (unsigned int i = 0; i < c.getStackSize(); i++)
+    {
+      auto stackHead = c.getConst(Config::headColName, c.getStack(i), 0);
+      if (bufferHead != "_" and stackHead != "_")
+        if ((std::size_t)std::stoi(bufferHead) == c.getStack(i) or (std::size_t)std::stoi(stackHead) == c.getWordIndex())
+          nbLinksBufferStack++;
+    }
+
+    links.at("StackRight") = nbLinksStackRight;
+    links.at("BufferRight") = nbLinksBufferRight;
+    links.at("BufferRightHead") = nbLinksBufferRightHead;
+    links.at("BufferStack") = nbLinksBufferStack;
+  }
+
+  return links;
+}
+
 std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
 {
   int bestCost = std::numeric_limits<int>::max();
   std::vector<Transition *> result;
   std::vector<int> costs(transitions.size());
 
+  auto links = computeLinks(c);
+
   for (unsigned int i = 0; i < transitions.size(); i++)
   {
     if (!appliableTransitions[i])
@@ -94,7 +148,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
       continue;
     }
 
-    int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c);
+    int cost = dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links);
 
     costs[i] = cost;
     if (cost < bestCost)
@@ -104,7 +158,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
   for (unsigned int i = 0; i < transitions.size(); i++)
     if (costs[i] == bestCost)
       result.emplace_back(&transitions[i]);
-
+    
   return result;
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a0d0172..c65b72a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -117,6 +117,10 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
           }
 
           transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
+
+          for (auto & trans : goldTransitions)
+            if (trans == transition)
+              goldTransition = trans;
         }
         else
         {
-- 
GitLab