From 8e01a10044e3f643a1a11b6623ca27df64309c3f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 5 Jun 2020 11:05:17 +0200 Subject: [PATCH] Corrected bug where split transitions were never appliable --- decoder/src/Beam.cpp | 5 +++-- reading_machine/include/Action.hpp | 1 + reading_machine/src/Action.cpp | 24 ++++++++++++++++++++++++ reading_machine/src/Transition.cpp | 2 +- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 47afb72..08606ef 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -41,11 +41,12 @@ void Beam::update(ReadingMachine & machine, bool debug) classifier.setState(elements[index].config.getState()); - auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config); - elements[index].config.setAppliableTransitions(appliableTransitions); if (machine.hasSplitWordTransitionSet()) elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions)); + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config); + elements[index].config.setAppliableTransitions(appliableTransitions); + auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0); diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 6b14f76..b511ab5 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -44,6 +44,7 @@ class Action static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & 
colName, std::size_t lineIndex, const std::string & addition); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); + static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); static Action pushWordIndexOnStack(); static Action popStack(int relIndex); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 428b417..5ebbda5 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -239,6 +239,30 @@ Action Action::addHypothesisRelative(const std::string & colName, Config::Object return {Type::Write, apply, undo, appliable}; } +Action Action::addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis) +{ + auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, hypothesis).apply(config, a); + }; + + auto undo = [colName, object, relativeIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, "").undo(config, a); + }; + + auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a) + { + return true; + }; + + return {Type::Write, apply, undo, appliable}; +} + Action Action::pushWordIndexOnStack() { auto apply = [](Config & config, Action & a) diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 1a2ad12..7589612 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -213,7 +213,7 @@ void 
Transition::initSplitWord(std::vector<std::string> words) sequence.emplace_back(Action::addLinesIfNeeded(words.size())); sequence.emplace_back(Action::consumeCharacterIndex(consumedWord)); for (unsigned int i = 0; i < words.size(); i++) - sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i])); + sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, i, words[i])); sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); cost = [words](const Config & config) -- GitLab