diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a9547b39bb3767659389b2572d6028237c2a1cec --- /dev/null +++ b/decoder/include/Beam.hpp @@ -0,0 +1,43 @@ +#ifndef BEAM__H +#define BEAM__H + +#include <vector> +#include <string> +#include "BaseConfig.hpp" +#include "ReadingMachine.hpp" + +class Beam +{ + public : + + class Element + { + public : + + BaseConfig config; + int nextTransition; + float totalProbability; + std::string name; + bool ended{false}; + + public : + + Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name); + }; + + private : + + std::size_t width; + float threshold; + std::vector<Element> elements; + bool ended{false}; + + public : + + Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine); + Element & operator[](std::size_t index); + void update(ReadingMachine & machine, bool debug); + bool isEnded() const; +}; + +#endif diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0dbab59fe66c13ba97e6196cf23f61e3fc147fa --- /dev/null +++ b/decoder/src/Beam.cpp @@ -0,0 +1,132 @@ +#include "Beam.hpp" + +Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine) : width(width), threshold(threshold) +{ + model.setStrategy(machine.getStrategyDefinition()); + model.addPredicted(machine.getPredicted()); + model.setState(model.getStrategy().getInitialState()); + elements.emplace_back(model, -1, 0.0, "0"); +} + +Beam::Element::Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name) : config(model), nextTransition(nextTransition), totalProbability(totalProbability), name(name) +{ +} + +Beam::Element & Beam::operator[](std::size_t index) +{ + return elements[index]; +} + +void Beam::update(ReadingMachine & machine, bool debug) +{ + ended = true; + auto currentNbElements = elements.size(); + + if (debug) + fmt::print(stderr, "{:-<{}}\nBEAM SEARCH CONTENT :\n", "", 80); + + for (unsigned int index = 0; index < currentNbElements; index++) + { + if (elements[index].ended) + continue; + + ended = false; + + auto & classifier = *machine.getClassifier(); + + classifier.setState(elements[index].config.getState()); + + auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config); + elements[index].config.setAppliableTransitions(appliableTransitions); + if (machine.hasSplitWordTransitionSet()) + elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions)); + + auto context = classifier.getNN()->extractContext(elements[index].config).back(); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); + auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0); + + std::vector<std::pair<float, int>> scoresOfTransitions; + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (appliableTransitions[i] and score >= threshold) + scoresOfTransitions.emplace_back(std::make_pair(score, i)); + } + + if (scoresOfTransitions.empty()) + { + elements[index].config.printForDebug(stderr); + util::myThrow("No suitable transition found !"); + } + + std::sort(scoresOfTransitions.rbegin(), scoresOfTransitions.rend()); + + if (width > 1) + for (unsigned int i = 1; i < scoresOfTransitions.size(); i++) + { + elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].totalProbability + scoresOfTransitions[i].first, elements[index].name + ":" + std::to_string(scoresOfTransitions[i].second)); + } + + elements[index].nextTransition = scoresOfTransitions[0].second; + elements[index].totalProbability += scoresOfTransitions[0].first; + elements[index].name += ":" + std::to_string(elements[index].nextTransition); + + if (debug) + { + elements[index].config.printForDebug(stderr); + std::vector<std::pair<float,std::string>> toPrint; + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); + toPrint.emplace_back(std::make_pair(score,nicePrint)); + } + std::sort(toPrint.rbegin(), toPrint.rend()); + for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) + fmt::print(stderr, "{}\n", toPrint[i].second); + } + } + + std::sort(elements.begin(), elements.end(), [](const Element & a, const Element & b) + { + return a.totalProbability > b.totalProbability; + }); + + while (elements.size() > width) + elements.pop_back(); + + for (auto & element : elements) + { + if (element.ended) + continue; + + auto & config = element.config; + auto & classifier = *machine.getClassifier(); + + classifier.setState(config.getState()); + + auto * transition = machine.getTransitionSet().getTransition(element.nextTransition); + + transition->apply(config); + config.addToHistory(transition->getName()); + + auto movement = config.getStrategy().getMovement(config, transition->getName()); + if (movement == Strategy::endMovement) + { + element.ended = true; + continue; + } + + config.setState(movement.first); + config.moveWordIndexRelaxed(movement.second); + } + + if (debug) + fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n{:-<{}}\n", "", 80); +} + +bool Beam::isEnded() const +{ + return ended; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 1c68a0d8db414b9bd8a9fbd02692d2fe8ea5fce9..4bd1c80061418e41f95c7129a179cf74b36a2b39 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -1,5 +1,6 @@ #include "Decoder.hpp" #include "SubConfig.hpp" +#include "Beam.hpp" Decoder::Decoder(ReadingMachine & machine) : machine(machine) { @@ -7,7 +8,6 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) { - constexpr float beamThreshold = 0.1; constexpr int printInterval = 50; torch::AutoGradMode useGrad(false); @@ -17,164 +17,43 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); - std::vector<BaseConfig> beam; - std::vector<bool> endFlag; + Beam beam(beamSize, 0.1, baseConfig, machine); try { - - for (unsigned int i = 0; i < beamSize; i++) - { - beam.emplace_back(baseConfig); - beam.back().setStrategy(machine.getStrategyDefinition()); - beam.back().addPredicted(machine.getPredicted()); - beam.back().setState(beam.back().getStrategy().getInitialState()); - endFlag.emplace_back(false); - } - - while (true) - { - if (machine.hasSplitWordTransitionSet()) - for (auto & c : beam) - c.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(c, Config::maxNbAppliableSplitTransitions)); - - std::vector<std::vector<int>> appliableTransitions; - for (auto & c : beam) + while (!beam.isEnded()) { - machine.getClassifier()->setState(c.getState()); - appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c)); - c.setAppliableTransitions(appliableTransitions.back()); - } - - std::vector<torch::Tensor> predictions; - for (auto & c : beam) - { - machine.getClassifier()->setState(c.getState()); - auto context = machine.getClassifier()->getNN()->extractContext(c).back(); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - predictions.emplace_back(machine.getClassifier()->getNN()(neuralInput).squeeze()); - } + beam.update(machine, debug); - if (debug) - { - fmt::print(stderr, "{:-<{}}\n", "", 80); - fmt::print(stderr, "BEAM SEARCH CONTENT :\n"); - for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) - { - auto & c = beam[beamIndex]; - machine.getClassifier()->setState(c.getState()); - c.printForDebug(stderr); - auto softmaxed = torch::softmax(predictions[beamIndex],-1); - std::vector<std::pair<float,std::string>> toPrint; - for (unsigned int i = 0; i < softmaxed.size(0); i++) + if (printAdvancement) + if (++nbExamplesProcessed >= printInterval) { - float score = softmaxed[i].item<float>(); - std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[beamIndex][i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); - toPrint.emplace_back(std::make_pair(score,nicePrint)); + auto actualTime = std::chrono::high_resolution_clock::now(); + double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0; + pastTime = actualTime; + fmt::print(stderr, "\r{:80}\rdecoding... speed={:<6}ex/s\r", "", (int)(nbExamplesProcessed/seconds)); + nbExamplesProcessed = 0; } - std::sort(toPrint.rbegin(), toPrint.rend()); - for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) - fmt::print(stderr, "{}\n", toPrint[i].second); - } - fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n"); - fmt::print(stderr, "{:-<{}}\n", "", 80); - } - for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) - { - if (endFlag[beamIndex]) - continue; - auto & c = beam[beamIndex]; - machine.getClassifier()->setState(c.getState()); - int chosenTransition = -1; - float bestScore = std::numeric_limits<float>::min(); - auto softmaxed = torch::softmax(predictions[beamIndex], 0); - std::vector<int> consideredTransitions; - - try - { - for (unsigned int i = 0; i < predictions[beamIndex].size(0); i++) - { - float score = predictions[beamIndex][i].item<float>(); - if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[beamIndex][i]) - { - chosenTransition = i; - bestScore = score; - } - if (softmaxed[i].item<float>() >= beamThreshold) - consideredTransitions.emplace_back(i); - } - - } catch(std::exception & e) {util::myThrow(e.what());} - - if (chosenTransition == -1) - { - c.printForDebug(stderr); - util::myThrow("No transition appliable !"); - } - - auto * transition = machine.getTransitionSet().getTransition(chosenTransition); - - transition->apply(c); - c.addToHistory(transition->getName()); - - auto movement = c.getStrategy().getMovement(c, transition->getName()); - if (debug) - { - //TODO improve this for beam search - fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); - } - if (movement == Strategy::endMovement) - { - endFlag[beamIndex] = true; - continue; - } - - c.setState(movement.first); - c.moveWordIndexRelaxed(movement.second); } - - bool allBeamAreEnded = true; - for (unsigned int i = 0; i < beam.size(); i++) - if (!endFlag[i]) - allBeamAreEnded = false; - - if (allBeamAreEnded) - break; - - if (printAdvancement) - if (++nbExamplesProcessed >= printInterval) - { - auto actualTime = std::chrono::high_resolution_clock::now(); - double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0; - pastTime = actualTime; - fmt::print(stderr, "\r{:80}\rdecoding... speed={:<6}ex/s\r", "", (int)(nbExamplesProcessed/seconds)); - nbExamplesProcessed = 0; - } - - } - } catch(std::exception & e) {util::myThrow(e.what());} - for (auto & c : beam) + baseConfig = beam[0].config; + machine.getClassifier()->setState(baseConfig.getState()); + + if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { - // Force EOS when needed - if (machine.getTransitionSet().getTransition("EOS b.0") and c.getLastNotEmptyHypConst(Config::EOSColName, c.getWordIndex()) != Config::EOSSymbol1) + machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig); + if (debug) { - machine.getTransitionSet().getTransition("EOS b.0")->apply(c); - if (debug) - { - fmt::print(stderr, "Forcing EOS transition\n"); - c.printForDebug(stderr); - } + fmt::print(stderr, "Forcing EOS transition\n"); + baseConfig.printForDebug(stderr); } - - // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script - try {c.addMissingColumns();} - catch (std::exception & e) {util::myThrow(e.what());} } - baseConfig = beam[0]; + // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script + try {baseConfig.addMissingColumns();} + catch (std::exception & e) {util::myThrow(e.what());} } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index 101487921f2c0d0eeff743cf8e426d928e11159e..cfc5acb2d41a4ca2d85f39fa4da1bf53add58636 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -28,6 +28,8 @@ class BaseConfig : public Config public : BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename); + BaseConfig(const BaseConfig & other); + BaseConfig & operator=(const BaseConfig & other) = default; std::size_t getNbColumns() const override; std::size_t getFirstLineIndex() const override; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 4a003176a8c31bdca21f4097831a865e830ea4c6..f420c222eff5acc0c17391cc92518bdd76178a1a 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -65,6 +65,9 @@ class Config protected : Config(const Utf8String & rawInput); + Config(const Utf8String & rawInput, const Config & other); + Config(const Config & other) = delete; + virtual ~Config() = default; public : @@ -97,7 +100,6 @@ class Config public : - virtual ~Config() {} void print(FILE * dest) const; void printForDebug(FILE * dest) const; bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const; @@ -158,7 +160,7 @@ class Config const std::vector<Transition *> & getAppliableSplitTransitions() const; const std::vector<int> & getAppliableTransitions() const; bool isExtraColumn(const std::string & colName) const; - void setStrategy(std::vector<std::string> & strategyDefinition); + void setStrategy(const std::vector<std::string> & strategyDefinition); Strategy & getStrategy(); }; diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index ab23fdcac68f0720f3d47b27b2c3b1df87de9337..d4b419aedc115d38ad8a1268d4698b0c533155ce 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -35,7 +35,7 @@ class ReadingMachine TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); bool hasSplitWordTransitionSet() const; - std::vector<std::string> & getStrategyDefinition(); + const std::vector<std::string> & getStrategyDefinition() const; Classifier * getClassifier(); bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 551c13ebe6d3c980dc7437b38df72296c7c90d51..8aa4322d992d791489d46f09a8b4708562417fc5 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -154,6 +154,10 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) std::fclose(file); } +BaseConfig::BaseConfig(const BaseConfig & other) : Config(rawInputUtf8, other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index), rawInputUtf8(other.rawInputUtf8) +{ +} + BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) : Config(rawInputUtf8) { if (tsvFilename.empty() and rawFilename.empty()) diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index edf1a067a498321f0ba858437acfd1410c52dd34..a884c7d6ac938d825140cc7dda0413231f86cfee 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -5,6 +5,26 @@ Config::Config(const Utf8String & rawInput) : rawInput(&rawInput) { } +Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&rawInput) +{ + this->lines = other.lines; + this->predicted = other.predicted; + this->lastPoppedStack = other.lastPoppedStack; + this->lastAttached = other.lastAttached; + this->currentWordId = other.currentWordId; + this->appliableSplitTransitions = other.appliableSplitTransitions; + this->appliableTransitions = other.appliableTransitions; + + this->strategy.reset(new Strategy(*other.strategy)); + + this->wordIndex = other.wordIndex; + this->characterIndex = other.characterIndex; + this->state = other.state; + this->history = other.history; + this->stack = other.stack; + this->extraColumns = this->extraColumns; +} + std::size_t Config::getIndexOfLine(int lineIndex) const { return lineIndex * getNbColumns() * (nbHypothesesMax+1); @@ -727,7 +747,7 @@ std::size_t Config::getStackSize() const return stack.size(); } -void Config::setStrategy(std::vector<std::string> & strategyDefinition) +void Config::setStrategy(const std::vector<std::string> & strategyDefinition) { strategy.reset(new Strategy(strategyDefinition)); } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index edecf08befdfa60d32aa734e9c27e540441c1f65..e48d507ca0ba99648f2d11b2607ba4f1fd217273 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -121,7 +121,7 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet() return *splitWordTransitionSet; } -std::vector<std::string> & ReadingMachine::getStrategyDefinition() +const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const { return strategyDefinition; }