diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp index 1dd40184acbe7f0adb75a24201939959ca12083d..e38d38f9b24b8fa6ed10d79cd888a96f7ff54811 100644 --- a/decoder/include/Beam.hpp +++ b/decoder/include/Beam.hpp @@ -5,6 +5,7 @@ #include <string> #include "BaseConfig.hpp" #include "ReadingMachine.hpp" +#include "Producer.hpp" class Beam { @@ -40,7 +41,7 @@ class Beam Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine); Element & operator[](std::size_t index); - void update(ReadingMachine & machine, bool debug); + void update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer); bool isEnded() const; }; diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index fe8c870c21dcc8bdf2f677d94d8e6dd6c2d98aef..01567574a93a47a506efcb8b66defb60689ab3fc 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -4,6 +4,7 @@ #include <filesystem> #include "ReadingMachine.hpp" #include "SubConfig.hpp" +#include "Producer.hpp" class Decoder { @@ -25,7 +26,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); + std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec954024e0983b7fd8300e9ed98c2dd629e8f3ee --- /dev/null +++ 
b/decoder/include/Producer.hpp @@ -0,0 +1,21 @@ +#ifndef PRODUCER__H +#define PRODUCER__H + +#include <filesystem> +#include "Config.hpp" + +class Producer /* Adds raw input to a Config one apply() call at a time and counts those calls in curNb. */ +{ + private : + + static constexpr int maxNb = 100; /* apply() returns false once curNb reaches this value */ + int curNb = 0; /* incremented by every apply() call */ + + public : + + Producer(std::filesystem::path path); /* the definition in Producer.cpp leaves this parameter unnamed and has an empty body */ + + bool apply(Config & config); /* calls config.rawInputAdd() and returns curNb < maxNb after incrementing curNb */ +}; + +#endif diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index c9327339675b254fa33cfe9d6773598fdcfdda79..1a5e151e76071ac7f55da63a07fc097e1ce0e362 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -22,7 +22,7 @@ Beam::Element & Beam::operator[](std::size_t index) return elements[index]; } -void Beam::update(ReadingMachine & machine, bool debug) +void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer) /* producer may be empty; when set, its apply() is called below for elements whose state is tokenizer */ { ended = true; auto currentNbElements = elements.size(); @@ -37,6 +37,11 @@ void Beam::update(ReadingMachine & machine, bool debug) ended = false; + if (producer.has_value() and elements[index].config.getState() == "tokenizer") /* the raw input is marked complete when apply() returns false */ + elements[index].config.setRawInputStatus(not producer.value().apply(elements[index].config)); + if (not producer.has_value()) /* no producer: the raw input is always marked complete */ + elements[index].config.setRawInputStatus(true); + auto & classifier = *machine.getClassifier(elements[index].config.getState()); if (machine.hasSplitWordTransitionSet()) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 5394280fb1bc9b05145df876a6118d26d2f21043..7b1e7b028c91d9b86d7fff73ea42c8f4c9c8acc8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement) +std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer) /* producer is received by value and handed to Beam::update */ { constexpr int printInterval = 50; @@ -20,7 +20,7 @@ std::size_t 
Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float { while (!beam.isEnded()) { - beam.update(machine, debug); + beam.update(machine, debug, producer); ++totalNbExamplesProcessed; if (printAdvancement) diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 22e715f505d66ed9d41c32894d10d3d8395f405f..f96406150fe5a988518885dcfb38e888be073c3a 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -4,6 +4,7 @@ #include "util.hpp" #include "Decoder.hpp" #include "Submodule.hpp" +#include "Producer.hpp" po::options_description MacaonDecode::getOptionsDescription() { @@ -16,7 +17,9 @@ po::options_description MacaonDecode::getOptionsDescription() ("inputTSV", po::value<std::string>(), "File containing the text to decode, TSV file") ("inputTXT", po::value<std::string>(), - "File containing the text to decode, raw text file"); + "File containing the text to decode, raw text file") + ("inputSES", po::value<std::string>(), + "File containing a list of actions that will fill the input tape"); po::options_description opt("Optional"); opt.add_options() @@ -55,7 +58,7 @@ po::variables_map MacaonDecode::checkOptions(po::options_description & od) try {po::notify(vm);} catch(std::exception& e) {util::myThrow(e.what());} - if (vm.count("inputTSV") + vm.count("inputTXT") != 1) + if (vm.count("inputTSV") + vm.count("inputTXT") + vm.count("inputSES") != 1) { std::stringstream ss; ss << od; @@ -76,6 +79,7 @@ int MacaonDecode::main() auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, "")); auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; + auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : ""; auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? 
false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; @@ -117,7 +121,9 @@ int MacaonDecode::main() } else { - if (rawInputs.size()) + if (inputSES.size()) /* inputSES was given: the config is built from noTsv and a default-constructed util::utf8string */ + configs.emplace_back(mcd, noTsv, util::utf8string(), std::vector<int>()); + else if (rawInputs.size()) configs.emplace_back(mcd, noTsv, rawInputs[0], std::vector<int>()); else configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>()); @@ -132,11 +138,18 @@ int MacaonDecode::main() std::for_each(std::execution::par, configs.begin(), configs.end(), [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) { - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); /* the parallel path always passes an empty optional */ }); } else - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement); + { + if (not inputSES.empty()) /* only this single-config path constructs a Producer, from inputSES */ + { + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES))); + } + else + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + } for (unsigned int i = 0; i < configs.size(); i++) configs[i].print(stdout, i == 0); diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d564dd531f063a354008fef3b4cd9e2e1a3b2b7 --- /dev/null +++ b/decoder/src/Producer.cpp @@ -0,0 +1,22 @@ +#include "Producer.hpp" + +Producer::Producer(std::filesystem::path) /* the parameter is unnamed and unused; the body is empty */ +{ +} + +bool Producer::apply(Config & config) /* Adds characters to config through rawInputAdd, then increments curNb and returns curNb < maxNb. */ +{ + if (util::choiceWithProbability(0.05)) /* this branch adds a period followed by a space; presumably taken with probability 0.05 - confirm against util */ + { + config.rawInputAdd("."); + config.rawInputAdd(" "); + } + else if (util::choiceWithProbability(0.8)) /* adds one lowercase letter from a to z chosen with rand(); NOTE(review): no srand call is visible in this patch */ + config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26))); + else /* otherwise a single space is added */ + config.rawInputAdd(" "); + + curNb++; + return curNb < maxNb; +} + diff --git a/reading_machine/include/Config.hpp 
b/reading_machine/include/Config.hpp index ed0a29dfc3da615a1829ed5ad06119d75df41ea7..db8eb10730aae55c4e9d181d7829509f062771b3 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -49,6 +49,7 @@ class Config protected : Utf8String rawInput; + bool rawInputIsComplete{true}; /* returned by getRawInputStatus(); initialized to true so it is never read uninitialized and Action::endWord keeps its previous condition until setRawInputStatus() is called */ std::size_t wordIndex{0}; std::size_t characterIndex{0}; std::size_t currentSentenceStartRawInput{0}; @@ -120,6 +121,10 @@ class Config util::String & getFirstEmpty(const std::string & colName, int lineIndex); bool hasCharacter(int letterIndex) const; const util::utf8char & getLetter(int letterIndex) const; + bool getRawInputStatus() const; /* returns rawInputIsComplete */ + void setRawInputStatus(bool status); /* stores status in rawInputIsComplete */ + void rawInputPop(); /* removes the last element of rawInput */ + void rawInputAdd(util::utf8char letter); /* appends letter to rawInput */ void addToHistory(const std::string & transition); void addToStack(std::size_t index); void popStack(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index d978112494d37d1224826901858e3102cc4edb5d..72e33ecf1a5f8c802f4fca2c0f1b4ccc9cec249b 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -444,7 +444,7 @@ Action Action::endWord() config.setCurrentWordId(config.getCurrentWordId()+1); addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a); - if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0)) + if (!(config.rawInputOnlySeparatorsLeft() and config.getRawInputStatus()) and !config.has(0,config.getWordIndex()+1,0)) /* a line is now also added when only separators are left but the raw input is not complete yet */ config.addLines(1); }; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 61488ebdcf3016eef811e8f4a9f824755848c2d2..086f7f09096b2ea69a204ffd3343676e981bcc4f 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -156,9 +156,6 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( BaseConfig::BaseConfig(std::string mcd, const 
std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes) { - if (sentences.empty() and rawInput.empty()) - util::myThrow("sentences and rawInput can't be both empty"); - createColumns(mcd); if (not rawInput.empty()) diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 69ae4005217da12a021091ec935497596cd467b6..22ca0451ead9e256f878de67d6537811ae5add0f 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -14,6 +14,7 @@ Config::Config(const Config & other) this->strategy.reset(other.strategy ? new Strategy(*other.strategy) : nullptr); this->rawInput = other.rawInput; + this->rawInputIsComplete = other.rawInputIsComplete; /* the completeness flag is copied along with rawInput */ this->wordIndex = other.wordIndex; this->characterIndex = other.characterIndex; this->state = other.state; @@ -396,6 +397,26 @@ const util::utf8char & Config::getLetter(int letterIndex) const return rawInput[letterIndex]; } +bool Config::getRawInputStatus() const /* Returns the rawInputIsComplete flag. */ +{ + return rawInputIsComplete; +} + +void Config::setRawInputStatus(bool status) /* Stores status in rawInputIsComplete. */ +{ + rawInputIsComplete = status; +} + +void Config::rawInputPop() /* Removes the last element of rawInput. NOTE(review): there is no emptiness check before pop_back - confirm callers never pop an empty rawInput. */ +{ + rawInput.pop_back(); +} + +void Config::rawInputAdd(util::utf8char letter) /* Appends letter to the end of rawInput. */ +{ + rawInput.push_back(letter); +} + bool Config::isMultiword(std::size_t lineIndex) const { return hasColIndex(idColName) && std::string(getConst(idColName, lineIndex, 0)).find('-') != std::string::npos; diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index b22b9e62cc19ef3420f59e4cc05cdbd59260bb45..d23440e243c755b88d1a7a39582e6b009a851663 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -116,7 +116,9 @@ bool Strategy::Block::isFinished(const Config & c, const Movement & movement) if (condition == EndCondition::CannotMove) { if (c.canMoveWordIndex(movement.second)) + { return false; + } } return true; diff --git a/trainer/src/MacaonTrain.cpp 
b/trainer/src/MacaonTrain.cpp index d760d1f9e92cf46d2731ea57d47d7f04d76a2c6c..08832ba56542f473a6a437115472369907e28e95 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -4,6 +4,7 @@ #include "util.hpp" #include "NeuralNetwork.hpp" #include "WordEmbeddings.hpp" +#include "Producer.hpp" namespace po = boost::program_options; @@ -332,14 +333,14 @@ int MacaonTrain::main() std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), [&decoder, debug, printAdvancement](BaseConfig & devConfig) { - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>()); }); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); machine.to(NeuralNetworkImpl::getDevice()); } else { - decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement); + decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>()); } std::vector<const Config *> devConfigsPtrs;