diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 47afb72787e5757c290dfbb3de02feef89be4275..08606efd0a6db74a796a5310a6b1e9f16428faaa 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -41,11 +41,12 @@ void Beam::update(ReadingMachine & machine, bool debug) classifier.setState(elements[index].config.getState()); - auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config); - elements[index].config.setAppliableTransitions(appliableTransitions); if (machine.hasSplitWordTransitionSet()) elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions)); + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config); + elements[index].config.setAppliableTransitions(appliableTransitions); + auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0); diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 6b14f7606c5d1729a4eb0c3d319b0b78112eb080..b511ab5046c8de52582ff0c59b599161d4be75c6 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -44,6 +44,7 @@ class Action static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); + static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); static Action pushWordIndexOnStack(); static Action popStack(int relIndex); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 428b417f817b5e364a84abe17005e693c8db2eec..5ebbda5d57926f0f3ec8b15fada4bb1a78f86a20 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -239,6 +239,30 @@ Action Action::addHypothesisRelative(const std::string & colName, Config::Object return {Type::Write, apply, undo, appliable}; } +Action Action::addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis) +{ + auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, hypothesis).apply(config, a); + }; + + auto undo = [colName, object, relativeIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, "").undo(config, a); + }; + + auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a) + { + return true; + }; + + return {Type::Write, apply, undo, appliable}; +} + Action Action::pushWordIndexOnStack() { auto apply = [](Config & config, Action & a) diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 1a2ad12bdad2d496fd4449e25914bc0b54212013..7589612dfa4bb80ae68521ed36765f165c8abea2 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -213,7 +213,7 @@ void Transition::initSplitWord(std::vector<std::string> words) sequence.emplace_back(Action::addLinesIfNeeded(words.size())); sequence.emplace_back(Action::consumeCharacterIndex(consumedWord)); for (unsigned int i = 0; i < words.size(); i++) - sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i])); + sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, i, words[i])); sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); cost = [words](const Config & config)