From 8e01a10044e3f643a1a11b6623ca27df64309c3f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 5 Jun 2020 11:05:17 +0200
Subject: [PATCH] Corrected bug where splittransitions were always no appliable

---
 decoder/src/Beam.cpp               |  5 +++--
 reading_machine/include/Action.hpp |  1 +
 reading_machine/src/Action.cpp     | 24 ++++++++++++++++++++++++
 reading_machine/src/Transition.cpp |  2 +-
 4 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 47afb72..08606ef 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -41,11 +41,12 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     classifier.setState(elements[index].config.getState());
 
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
-    elements[index].config.setAppliableTransitions(appliableTransitions);
     if (machine.hasSplitWordTransitionSet())
       elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
 
+    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
+    elements[index].config.setAppliableTransitions(appliableTransitions);
+
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0);
diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index 6b14f76..b511ab5 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -44,6 +44,7 @@ class Action
   static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
   static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
   static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
+  static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
   static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
   static Action pushWordIndexOnStack();
   static Action popStack(int relIndex);
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 428b417..5ebbda5 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -239,6 +239,30 @@ Action Action::addHypothesisRelative(const std::string & colName, Config::Object
   return {Type::Write, apply, undo, appliable}; 
 }
 
+Action Action::addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis)
+{
+  auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
+  {
+    int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
+
+    return addHypothesis(colName, lineIndex, hypothesis).apply(config, a);
+  };
+
+  auto undo = [colName, object, relativeIndex](Config & config, Action & a)
+  {
+    int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
+
+    return addHypothesis(colName, lineIndex, "").undo(config, a);
+  };
+
+  auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a)
+  {
+    return true;
+  };
+
+  return {Type::Write, apply, undo, appliable}; 
+}
+
 Action Action::pushWordIndexOnStack()
 {
   auto apply = [](Config & config, Action & a)
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 1a2ad12..7589612 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -213,7 +213,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
   sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
   sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
   for (unsigned int i = 0; i < words.size(); i++)
-    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
+    sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, i, words[i]));
   sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
 
   cost = [words](const Config & config)
-- 
GitLab