diff --git a/common/include/util.hpp b/common/include/util.hpp index ac4ffdc5396e45061f60b64a0e1e6b4afa857c9c..daf0560c9f359f565b929499afa99efbda347f7d 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -9,6 +9,7 @@ #include <filesystem> #include <experimental/source_location> #include <boost/flyweight.hpp> +#include <boost/circular_buffer.hpp> #include "fmt/core.h" #include "utf8.hpp" #include "utf8string.hpp" @@ -90,6 +91,17 @@ std::string join(const std::string & delim, const std::vector<T> elems) return result; } +template <typename T> +std::string join(const std::string & delim, const boost::circular_buffer<T> elems) +{ + std::string result; + + for (unsigned int i = 0; i < elems.size(); i++) + result = fmt::format("{}{}{}", result, elems[i], i == elems.size()-1 ? "" : delim); + + return result; +} + }; template <> diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp index a9547b39bb3767659389b2572d6028237c2a1cec..41534601fd1f85b3145dbe5acff3783925c09203 100644 --- a/decoder/include/Beam.hpp +++ b/decoder/include/Beam.hpp @@ -15,14 +15,16 @@ class Beam public : BaseConfig config; - int nextTransition; - float totalProbability; - std::string name; + int nextTransition{-1}; + boost::circular_buffer<double> probabilities{500}; + boost::circular_buffer<std::string> name{20}; + float meanProbability{0.0}; bool ended{false}; public : - Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name); + Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name); + Element(const BaseConfig & model); }; private : diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index b6f0eb3a83990ce479e59615306550434ce8ff22..ab0153ea1511eebdd267759e1c9b978520e3acc9 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -25,7 +25,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - void decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement); + void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index a0dbab59fe66c13ba97e6196cf23f61e3fc147fa..3af7ab07ebcedded3228ea730b5a549cdef237dc 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -5,10 +5,14 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading model.setStrategy(machine.getStrategyDefinition()); model.addPredicted(machine.getPredicted()); model.setState(model.getStrategy().getInitialState()); - elements.emplace_back(model, -1, 0.0, "0"); + elements.emplace_back(model); } -Beam::Element::Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name) : config(model), nextTransition(nextTransition), totalProbability(totalProbability), name(name) +Beam::Element::Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name) : config(model), nextTransition(nextTransition), probabilities(probabilities), name(name) +{ +} + +Beam::Element::Element(const BaseConfig & model) : config(model) { } @@ -23,7 +27,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto currentNbElements = elements.size(); if (debug) - fmt::print(stderr, "{:-<{}}\nBEAM SEARCH CONTENT :\n", "", 80); + fmt::print(stderr, "{:*<{}}BEAM START{:*<{}}\n", "", 37, "", 36); for (unsigned int index = 0; index < currentNbElements; index++) { @@ -49,7 +53,7 @@ void Beam::update(ReadingMachine & machine, bool debug) for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); - if (appliableTransitions[i] and score >= threshold) + if (appliableTransitions[i]) scoresOfTransitions.emplace_back(std::make_pair(score, i)); } @@ -61,18 +65,32 @@ void Beam::update(ReadingMachine & machine, bool debug) std::sort(scoresOfTransitions.rbegin(), scoresOfTransitions.rend()); + while (!scoresOfTransitions.empty() and scoresOfTransitions.back().first < threshold) + scoresOfTransitions.pop_back(); + if (width > 1) for (unsigned int i = 1; i < scoresOfTransitions.size(); i++) { - elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].totalProbability + scoresOfTransitions[i].first, elements[index].name + ":" + std::to_string(scoresOfTransitions[i].second)); + elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name); + elements.back().name.push_back(std::to_string(i)); + elements.back().probabilities.push_back(scoresOfTransitions[i].first); + elements.back().meanProbability = 0.0; + for (auto & p : elements.back().probabilities) + elements.back().meanProbability += p; + elements.back().meanProbability /= elements.back().probabilities.size(); } elements[index].nextTransition = scoresOfTransitions[0].second; - elements[index].totalProbability += scoresOfTransitions[0].first; - elements[index].name += ":" + std::to_string(elements[index].nextTransition); + elements[index].probabilities.push_back(scoresOfTransitions[0].first); + elements[index].name.push_back("0"); + elements[index].meanProbability = 0.0; + for (auto & p : elements[index].probabilities) + elements[index].meanProbability += p; + elements[index].meanProbability /= elements[index].probabilities.size(); if (debug) { + fmt::print(stderr, "Element {:<3} Probability={} Name={}\n", index, elements[index].meanProbability, util::join(":", elements[index].name)); elements[index].config.printForDebug(stderr); std::vector<std::pair<float,std::string>> toPrint; for (unsigned int i = 0; i < prediction.size(0); i++) @@ -89,7 +107,7 @@ void Beam::update(ReadingMachine & machine, bool debug) std::sort(elements.begin(), elements.end(), [](const Element & a, const Element & b) { - return a.totalProbability > b.totalProbability; + return a.meanProbability > b.meanProbability; }); while (elements.size() > width) @@ -122,7 +140,7 @@ void Beam::update(ReadingMachine & machine, bool debug) } if (debug) - fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n{:-<{}}\n", "", 80); + fmt::print(stderr, "{:*<{}}BEAM END{:*<{}}\n", "", 37, "", 38); } bool Beam::isEnded() const diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 4bd1c80061418e41f95c7129a179cf74b36a2b39..9ec248046f1edc3dee3c9e5e287bb71a23e74f5d 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) +void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement) { constexpr int printInterval = 50; @@ -17,7 +17,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); - Beam beam(beamSize, 0.1, baseConfig, machine); + Beam beam(beamSize, beamThreshold, baseConfig, machine); try { diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 290f47f582b2a1031ce0ef9b6d8a0108ce9d4ee6..7d62241cb06f84f57c4f8df2bf5551e484024b00 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -24,6 +24,8 @@ po::options_description MacaonDecode::getOptionsDescription() ("silent", "Don't print speed and progress") ("beamSize", po::value<int>()->default_value(1), "Size of the beam during beam search") + ("beamThreshold", po::value<float>()->default_value(0.1), + "Minimal probability an action must have to be considered in the beam search") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -74,6 +76,7 @@ int MacaonDecode::main() bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; auto beamSize = variables["beamSize"].as<int>(); + auto beamThreshold = variables["beamThreshold"].as<float>(); torch::globalContext().setBenchmarkCuDNN(true); @@ -89,7 +92,7 @@ int MacaonDecode::main() BaseConfig config(mcdFile, inputTSV, inputTXT); - decoder.decode(config, beamSize, debug, printAdvancement); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); config.print(stdout); } catch(std::exception & e) {util::error(e);} diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index cfc5acb2d41a4ca2d85f39fa4da1bf53add58636..0b009cdbbe83da767b1dd57038f1c7cfdf2a3905 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -17,8 +17,6 @@ class BaseConfig : public Config std::vector<std::string> colIndex2Name; std::unordered_map<std::string, int> colName2Index; - Utf8String rawInputUtf8; - private : void readMCD(std::string_view mcdFilename); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index f420c222eff5acc0c17391cc92518bdd76178a1a..690f937df7076699c4a0ff6b579743b021b4824d 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -44,29 +44,29 @@ class Config private : std::vector<String> lines; - std::set<std::string> predicted; - int lastPoppedStack{-1}; - int lastAttached{-1}; - int currentWordId{0}; - std::vector<Transition *> appliableSplitTransitions; - std::vector<int> appliableTransitions; - std::shared_ptr<Strategy> strategy; protected : - const Utf8String * rawInput; + Utf8String rawInput; std::size_t wordIndex{0}; std::size_t characterIndex{0}; String state{"NONE"}; boost::circular_buffer<String> history{10}; boost::circular_buffer<std::size_t> stack{50}; std::vector<std::string> extraColumns{isMultiColName, childsColName, sentIdColName, EOSColName}; + std::set<std::string> predicted; + int lastPoppedStack{-1}; + int lastAttached{-1}; + int currentWordId{0}; + std::vector<Transition *> appliableSplitTransitions; + std::vector<int> appliableTransitions; + std::shared_ptr<Strategy> strategy; protected : - Config(const Utf8String & rawInput); - Config(const Utf8String & rawInput, const Config & other); - Config(const Config & other) = delete; + Config() = default; + Config & operator=(const Config & other) = default; + Config(const Config & other); virtual ~Config() = default; public : diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 8aa4322d992d791489d46f09a8b4708562417fc5..296a58f683b62852d626e113d469765de08e4b62 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -43,9 +43,9 @@ void BaseConfig::readRawInput(std::string_view rawFilename) std::fclose(file); - rawInputUtf8 = util::splitAsUtf8(rawInputTemp); - rawInputUtf8.replace(util::utf8char("\n"), util::utf8char(" ")); - rawInputUtf8.replace(util::utf8char("\t"), util::utf8char(" ")); + rawInput = util::splitAsUtf8(rawInputTemp); + rawInput.replace(util::utf8char("\n"), util::utf8char(" ")); + rawInput.replace(util::utf8char("\t"), util::utf8char(" ")); } void BaseConfig::readTSVInput(std::string_view tsvFilename) @@ -154,11 +154,11 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) std::fclose(file); } -BaseConfig::BaseConfig(const BaseConfig & other) : Config(rawInputUtf8, other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index), rawInputUtf8(other.rawInputUtf8) +BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index) { } -BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) : Config(rawInputUtf8) +BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) { if (tsvFilename.empty() and rawFilename.empty()) util::myThrow("tsvFilename and rawFilenames can't be both empty"); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index a884c7d6ac938d825140cc7dda0413231f86cfee..81fbc58b66c6916bae6b7b0bce7542fb84102e24 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -1,11 +1,7 @@ #include "Config.hpp" #include "util.hpp" -Config::Config(const Utf8String & rawInput) : rawInput(&rawInput) -{ -} - -Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&rawInput) +Config::Config(const Config & other) { this->lines = other.lines; this->predicted = other.predicted; @@ -17,6 +13,7 @@ Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&ra this->strategy.reset(new Strategy(*other.strategy)); + this->rawInput = other.rawInput; this->wordIndex = other.wordIndex; this->characterIndex = other.characterIndex; this->state = other.state; @@ -151,8 +148,6 @@ void Config::printForDebug(FILE * dest) const static constexpr int lettersWindowSize = 40; static constexpr int maxWordLength = 7; - fmt::print(dest, "\n"); - int firstLineToPrint = wordIndex; int lastLineToPrint = wordIndex; while (wordIndex-firstLineToPrint < windowSize and has(0, firstLineToPrint-1, 0)) @@ -228,9 +223,9 @@ void Config::printForDebug(FILE * dest) const if (!stackStr.empty()) stackStr.pop_back(); fmt::print(dest, "{}\n", longLine); - for (std::size_t index = characterIndex; index < util::getSize(*rawInput) and index - characterIndex < lettersWindowSize; index++) + for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++) fmt::print(dest, "{}", getLetter(index)); - if (rawInput->size()) + if (!util::isEmpty(rawInput)) fmt::print(dest, "\n{}\n", longLine); fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr); fmt::print(dest, "{}\n", longLine); @@ -386,12 +381,12 @@ void Config::swapStack(int relIndex1, int relIndex2) bool Config::hasCharacter(int letterIndex) const { - return letterIndex >= 0 and letterIndex < (int)util::getSize(*rawInput); + return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput); } util::utf8char Config::getLetter(int letterIndex) const { - return (*rawInput)[letterIndex]; + return rawInput[letterIndex]; } bool Config::isComment(std::size_t lineIndex) const @@ -515,20 +510,20 @@ bool Config::canMoveWordIndex(int relativeMovement) const bool Config::moveCharacterIndex(int relativeMovement) { int oldVal = characterIndex; - characterIndex = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(*rawInput))); + characterIndex = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(rawInput))); return (int)characterIndex == oldVal + relativeMovement; } bool Config::canMoveCharacterIndex(int relativeMovement) const { - int target = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(*rawInput))); + int target = std::max(0, (int)std::min(characterIndex+relativeMovement, util::getSize(rawInput))); return target == (int)characterIndex + relativeMovement; } bool Config::rawInputOnlySeparatorsLeft() const { - for (unsigned int i = characterIndex; i < rawInput->size(); i++) - if (!util::isSeparator((*rawInput)[i])) + for (unsigned int i = characterIndex; i < util::getSize(rawInput); i++) + if (!util::isSeparator(rawInput[i])) return false; return true; @@ -585,7 +580,7 @@ void Config::setState(const std::string state) bool Config::stateIsDone() const { - if (!rawInput->empty()) + if (!util::isEmpty(rawInput)) return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0); return !has(0, wordIndex+1, 0) and !hasStack(0); diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp index 161fc635e077584f3a88da43b2f533b530006d04..eef8ae778a0fa38b3f89b48e29cfe12c7fa35583 100644 --- a/reading_machine/src/SubConfig.cpp +++ b/reading_machine/src/SubConfig.cpp @@ -1,11 +1,22 @@ #include "SubConfig.hpp" -SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(*model.rawInput), model(model), spanSize(spanSize) +SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : model(model), spanSize(spanSize) { + rawInput = model.rawInput; wordIndex = model.wordIndex; characterIndex = model.characterIndex; state = model.state; history = model.history; + stack = model.stack; + extraColumns = model.extraColumns; + predicted = model.predicted; + lastPoppedStack = model.lastPoppedStack; + lastAttached = model.lastAttached; + currentWordId = model.currentWordId; + appliableSplitTransitions = model.appliableSplitTransitions; + appliableTransitions = model.appliableTransitions; + if (model.strategy.get() != nullptr) + strategy.reset(new Strategy(model.getStrategy())); update(); } diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index b84b0fb2b149e60b30eba09621a341bf25488378..f7d0a3979c731386331361e1caf42caead0f8d9c 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -178,7 +178,7 @@ int MacaonTrain::main() if (computeDevScore) { auto devConfig = devGoldConfig; - decoder.decode(devConfig, 1, debug, printAdvancement); + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); decoder.evaluate(devConfig, modelPath, devTsvFile); devScores = decoder.getF1Scores(machine.getPredicted()); } @@ -217,7 +217,7 @@ int MacaonTrain::main() machine.getClassifier()->saveOptimizer(optimizerCheckpoint); if (printAdvancement) fmt::print(stderr, "\r{:80}\r", ""); - std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); + std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : ""); fmt::print(stderr, "{}\n", iterStr); std::FILE * f = std::fopen(trainInfos.c_str(), "a"); fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);