diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 273e9a555d350433e8adbf7e3196ae0505f15acb..94dd16beb55786200195827d18f708e0202a188a 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -5,79 +5,134 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) +void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) { torch::AutoGradMode useGrad(false); machine.trainMode(false); machine.setDictsState(Dict::State::Closed); machine.getStrategy().reset(); - config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); + std::vector<BaseConfig> beam; + std::vector<bool> endFlag; + try { - config.setState(machine.getStrategy().getInitialState()); + + baseConfig.addPredicted(machine.getPredicted()); + baseConfig.setState(machine.getStrategy().getInitialState()); + for (unsigned int i = 0; i < beamSize; i++) + { + beam.emplace_back(baseConfig); + endFlag.emplace_back(false); + } + machine.getClassifier()->setState(machine.getStrategy().getInitialState()); while (true) { - if (debug) - config.printForDebug(stderr); - if (machine.hasSplitWordTransitionSet()) - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); - config.setAppliableTransitions(appliableTransitions); + for (auto & c : beam) + c.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(c, Config::maxNbAppliableSplitTransitions)); - auto context = machine.getClassifier()->getNN()->extractContext(config).back(); - - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); + std::vector<std::vector<int>> appliableTransitions; + for (auto & c : beam) + { + appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c)); + c.setAppliableTransitions(appliableTransitions.back()); + } - int chosenTransition = -1; - float bestScore = std::numeric_limits<float>::min(); + std::vector<torch::Tensor> predictions; + for (auto & c : beam) + { + machine.getClassifier()->setState(c.getState()); + auto context = machine.getClassifier()->getNN()->extractContext(c).back(); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); + predictions.emplace_back(machine.getClassifier()->getNN()(neuralInput).squeeze()); + } if (debug) { - auto softmaxed = torch::softmax(prediction,-1); - std::vector<std::pair<float,std::string>> toPrint; - for (unsigned int i = 0; i < softmaxed.size(0); i++) + fmt::print(stderr, "{:-<{}}\n", "", 80); + fmt::print(stderr, "BEAM SEARCH CONTENT :\n"); + for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) { - float score = softmaxed[i].item<float>(); - std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); - toPrint.emplace_back(std::make_pair(score,nicePrint)); + auto & c = beam[beamIndex]; + c.printForDebug(stderr); + auto softmaxed = torch::softmax(predictions[beamIndex],-1); + std::vector<std::pair<float,std::string>> toPrint; + for (unsigned int i = 0; i < softmaxed.size(0); i++) + { + float score = softmaxed[i].item<float>(); + std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[beamIndex][i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); + toPrint.emplace_back(std::make_pair(score,nicePrint)); + } + std::sort(toPrint.rbegin(), toPrint.rend()); + for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) + fmt::print(stderr, "{}\n", toPrint[i].second); } - std::sort(toPrint.rbegin(), toPrint.rend()); - for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) - fmt::print(stderr, "{}\n", toPrint[i].second); + fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n"); + fmt::print(stderr, "{:-<{}}\n", "", 80); } - try + for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) { - for (unsigned int i = 0; i < prediction.size(0); i++) + if (endFlag[beamIndex]) + continue; + auto & c = beam[beamIndex]; + int chosenTransition = -1; + float bestScore = std::numeric_limits<float>::min(); + + try { - float score = prediction[i].item<float>(); - if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[i]) + for (unsigned int i = 0; i < predictions[beamIndex].size(0); i++) { - chosenTransition = i; - bestScore = score; + float score = predictions[beamIndex][i].item<float>(); + if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[beamIndex][i]) + { + chosenTransition = i; + bestScore = score; + } } + } catch(std::exception & e) {util::myThrow(e.what());} + + if (chosenTransition == -1) + { + c.printForDebug(stderr); + util::myThrow("No transition appliable !"); } - } catch(std::exception & e) {util::myThrow(e.what());} - if (chosenTransition == -1) - { - config.printForDebug(stderr); - util::myThrow("No transition appliable !"); + auto * transition = machine.getTransitionSet().getTransition(chosenTransition); + + transition->apply(c); + c.addToHistory(transition->getName()); + + auto movement = machine.getStrategy().getMovement(c, transition->getName()); + if (debug) + { + //TODO improve this for beam search + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); + } + if (movement == Strategy::endMovement) + { + endFlag[beamIndex] = true; + continue; + } + + c.setState(movement.first); + c.moveWordIndexRelaxed(movement.second); } - auto * transition = machine.getTransitionSet().getTransition(chosenTransition); + bool allBeamAreEnded = true; + for (unsigned int i = 0; i < beam.size(); i++) + if (!endFlag[i]) + allBeamAreEnded = false; - transition->apply(config); - config.addToHistory(transition->getName()); + if (allBeamAreEnded) + break; if (printAdvancement) if (++nbExamplesProcessed >= printInterval) @@ -89,32 +144,29 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool nbExamplesProcessed = 0; } - auto movement = machine.getStrategy().getMovement(config, transition->getName()); - if (debug) - fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); - if (movement == Strategy::endMovement) - break; - - config.setState(movement.first); - machine.getClassifier()->setState(movement.first); - config.moveWordIndexRelaxed(movement.second); } + } catch(std::exception & e) {util::myThrow(e.what());} - // Force EOS when needed - if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) + for (auto & c : beam) { - machine.getTransitionSet().getTransition("EOS b.0")->apply(config); - if (debug) + // Force EOS when needed + if (machine.getTransitionSet().getTransition("EOS b.0") and c.getLastNotEmptyHypConst(Config::EOSColName, c.getWordIndex()) != Config::EOSSymbol1) { - fmt::print(stderr, "Forcing EOS transition\n"); - config.printForDebug(stderr); + machine.getTransitionSet().getTransition("EOS b.0")->apply(c); + if (debug) + { + fmt::print(stderr, "Forcing EOS transition\n"); + c.printForDebug(stderr); + } } + + // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script + try {c.addMissingColumns();} + catch (std::exception & e) {util::myThrow(e.what());} } - // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script - try {config.addMissingColumns();} - catch (std::exception & e) {util::myThrow(e.what());} + baseConfig = beam[0]; } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 8aac97762dac362506ccd3ecc1e4330f5a9e1903..290f47f582b2a1031ce0ef9b6d8a0108ce9d4ee6 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -22,6 +22,8 @@ po::options_description MacaonDecode::getOptionsDescription() opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") + ("beamSize", po::value<int>()->default_value(1), + "Size of the beam during beam search") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -71,6 +73,7 @@ int MacaonDecode::main() auto mcdFile = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; + auto beamSize = variables["beamSize"].as<int>(); torch::globalContext().setBenchmarkCuDNN(true); @@ -86,7 +89,7 @@ int MacaonDecode::main() BaseConfig config(mcdFile, inputTSV, inputTXT); - decoder.decode(config, 1, debug, printAdvancement); + decoder.decode(config, beamSize, debug, printAdvancement); config.print(stdout); } catch(std::exception & e) {util::error(e);} diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 527fd23fefcb25387f9396e7d8be5e89731fef0e..b9de6887e894852d3716d994fa9e63e3e548e71f 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -52,7 +52,7 @@ class Config protected : - const Utf8String & rawInput; + const Utf8String * rawInput; std::size_t wordIndex{0}; std::size_t characterIndex{0}; String state{"NONE"}; diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 320686c1564b25c82176ddfde95dfb41e8495c91..4a218a5daefc7eed4398c67f5ec14429896c1f3c 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -1,7 +1,7 @@ #include "Config.hpp" #include "util.hpp" -Config::Config(const Utf8String & rawInput) : rawInput(rawInput) +Config::Config(const Utf8String & rawInput) : rawInput(&rawInput) { } @@ -208,9 +208,9 @@ void Config::printForDebug(FILE * dest) const if (!stackStr.empty()) stackStr.pop_back(); fmt::print(dest, "{}\n", longLine); - for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++) + for (std::size_t index = characterIndex; index < util::getSize(*rawInput) and index - characterIndex < lettersWindowSize; index++) fmt::print(dest, "{}", getLetter(index)); - if (rawInput.size()) + if (rawInput->size()) fmt::print(dest, "\n{}\n", longLine); fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr); fmt::print(dest, "{}\n", longLine); @@ -366,12 +366,12 @@ void Config::swapStack(int relIndex1, int relIndex2) bool Config::hasCharacter(int letterIndex) const { - return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput); + return letterIndex >= 0 and letterIndex < (int)util::getSize(*rawInput); } util::utf8char Config::getLetter(int letterIndex) const { - return rawInput[letterIndex]; + return (*rawInput)[letterIndex]; } bool Config::isComment(std::size_t lineIndex) const @@ -495,20 +495,20 @@ bool Config::canMoveWordIndex(int relativeMovement) const bool Config::moveCharacterIndex(int relativeMovement) { int oldVal = characterIndex; - characterIndex = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(rawInput))); + characterIndex = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(*rawInput))); return (int)characterIndex == oldVal + relativeMovement; } bool Config::canMoveCharacterIndex(int relativeMovement) const { - int target = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(rawInput))); + int target = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(*rawInput))); return target == (int)characterIndex + relativeMovement; } bool Config::rawInputOnlySeparatorsLeft() const { - for (unsigned int i = characterIndex; i < rawInput.size(); i++) - if (!util::isSeparator(rawInput[i])) + for (unsigned int i = characterIndex; i < rawInput->size(); i++) + if (!util::isSeparator((*rawInput)[i])) return false; return true; @@ -565,7 +565,7 @@ void Config::setState(const std::string state) bool Config::stateIsDone() const { - if (!rawInput.empty()) + if (!rawInput->empty()) return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0); return !has(0, wordIndex+1, 0) and !hasStack(0); diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp index c571b3c95bdbf2efeb4fa1647c168acef99b49a2..161fc635e077584f3a88da43b2f533b530006d04 100644 --- a/reading_machine/src/SubConfig.cpp +++ b/reading_machine/src/SubConfig.cpp @@ -1,6 +1,6 @@ #include "SubConfig.hpp" -SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(model.rawInput), model(model), spanSize(spanSize) +SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(*model.rawInput), model(model), spanSize(spanSize) { wordIndex = model.wordIndex; characterIndex = model.characterIndex;