From 89fe9c355ce91cffd2b2b61e57a252f4128fd5cf Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 21 Jul 2021 16:33:16 +0200 Subject: [PATCH] Added and integred producer --- decoder/include/Beam.hpp | 3 ++- decoder/include/Decoder.hpp | 3 ++- decoder/include/Producer.hpp | 21 +++++++++++++++++++++ decoder/src/Beam.cpp | 7 ++++++- decoder/src/Decoder.cpp | 4 ++-- decoder/src/MacaonDecode.cpp | 23 ++++++++++++++++++----- decoder/src/Producer.cpp | 22 ++++++++++++++++++++++ reading_machine/include/Config.hpp | 5 +++++ reading_machine/src/Action.cpp | 2 +- reading_machine/src/BaseConfig.cpp | 3 --- reading_machine/src/Config.cpp | 21 +++++++++++++++++++++ reading_machine/src/Strategy.cpp | 2 ++ trainer/src/MacaonTrain.cpp | 5 +++-- 13 files changed, 105 insertions(+), 16 deletions(-) create mode 100644 decoder/include/Producer.hpp create mode 100644 decoder/src/Producer.cpp diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp index 1dd4018..e38d38f 100644 --- a/decoder/include/Beam.hpp +++ b/decoder/include/Beam.hpp @@ -5,6 +5,7 @@ #include <string> #include "BaseConfig.hpp" #include "ReadingMachine.hpp" +#include "Producer.hpp" class Beam { @@ -40,7 +41,7 @@ class Beam Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine); Element & operator[](std::size_t index); - void update(ReadingMachine & machine, bool debug); + void update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer); bool isEnded() const; }; diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index fe8c870..0156757 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -4,6 +4,7 @@ #include <filesystem> #include "ReadingMachine.hpp" #include "SubConfig.hpp" +#include "Producer.hpp" class Decoder { @@ -25,7 +26,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement); + std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp new file mode 100644 index 0000000..ec95402 --- /dev/null +++ b/decoder/include/Producer.hpp @@ -0,0 +1,21 @@ +#ifndef PRODUCER__H +#define PRODUCER__H + +#include <filesystem> +#include "Config.hpp" + +class Producer +{ + private : + + static constexpr int maxNb = 100; + int curNb = 0; + + public : + + Producer(std::filesystem::path path); + + bool apply(Config & config); +}; + +#endif diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index c932733..1a5e151 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -22,7 +22,7 @@ Beam::Element & Beam::operator[](std::size_t index) return elements[index]; } -void Beam::update(ReadingMachine & machine, bool debug) +void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer) { ended = true; auto currentNbElements = elements.size(); @@ -37,6 +37,11 @@ void Beam::update(ReadingMachine & machine, bool debug) ended = false; + if (producer.has_value() and elements[index].config.getState() == "tokenizer") + elements[index].config.setRawInputStatus(not producer.value().apply(elements[index].config)); + if (not producer.has_value()) + elements[index].config.setRawInputStatus(true); + auto & classifier = *machine.getClassifier(elements[index].config.getState()); if (machine.hasSplitWordTransitionSet()) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 5394280..7b1e7b0 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement) +std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer) { constexpr int printInterval = 50; @@ -20,7 +20,7 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float { while (!beam.isEnded()) { - beam.update(machine, debug); + beam.update(machine, debug, producer); ++totalNbExamplesProcessed; if (printAdvancement) diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 22e715f..f964061 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -4,6 +4,7 @@ #include "util.hpp" #include "Decoder.hpp" #include "Submodule.hpp" +#include "Producer.hpp" po::options_description MacaonDecode::getOptionsDescription() { @@ -16,7 +17,9 @@ po::options_description MacaonDecode::getOptionsDescription() ("inputTSV", po::value<std::string>(), "File containing the text to decode, TSV file") ("inputTXT", po::value<std::string>(), - "File containing the text to decode, raw text file"); + "File containing the text to decode, raw text file") + ("inputSES", po::value<std::string>(), + "File containing a list of actions that will fill the input tape"); po::options_description opt("Optional"); opt.add_options() @@ -55,7 +58,7 @@ po::variables_map MacaonDecode::checkOptions(po::options_description & od) try {po::notify(vm);} catch(std::exception& e) {util::myThrow(e.what());} - if (vm.count("inputTSV") + vm.count("inputTXT") != 1) + if (vm.count("inputTSV") + vm.count("inputTXT") + vm.count("inputSES") != 1) { std::stringstream ss; ss << od; @@ -76,6 +79,7 @@ int MacaonDecode::main() auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, "")); auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; + auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : ""; auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; @@ -117,7 +121,9 @@ int MacaonDecode::main() } else { - if (rawInputs.size()) + if (inputSES.size()) + configs.emplace_back(mcd, noTsv, util::utf8string(), std::vector<int>()); + else if (rawInputs.size()) configs.emplace_back(mcd, noTsv, rawInputs[0], std::vector<int>()); else configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>()); @@ -132,11 +138,18 @@ int MacaonDecode::main() std::for_each(std::execution::par, configs.begin(), configs.end(), [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) { - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); }); } else - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement); + { + if (not inputSES.empty()) + { + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES))); + } + else + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + } for (unsigned int i = 0; i < configs.size(); i++) configs[i].print(stdout, i == 0); diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp new file mode 100644 index 0000000..8d564dd --- /dev/null +++ b/decoder/src/Producer.cpp @@ -0,0 +1,22 @@ +#include "Producer.hpp" + +Producer::Producer(std::filesystem::path) +{ +} + +bool Producer::apply(Config & config) +{ + if (util::choiceWithProbability(0.05)) + { + config.rawInputAdd("."); + config.rawInputAdd(" "); + } + else if (util::choiceWithProbability(0.8)) + config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26))); + else + config.rawInputAdd(" "); + + curNb++; + return curNb < maxNb; +} + diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index ed0a29d..db8eb10 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -49,6 +49,7 @@ class Config protected : Utf8String rawInput; + bool rawInputIsComplete; std::size_t wordIndex{0}; std::size_t characterIndex{0}; std::size_t currentSentenceStartRawInput{0}; @@ -120,6 +121,10 @@ class Config util::String & getFirstEmpty(const std::string & colName, int lineIndex); bool hasCharacter(int letterIndex) const; const util::utf8char & getLetter(int letterIndex) const; + bool getRawInputStatus() const; + void setRawInputStatus(bool status); + void rawInputPop(); + void rawInputAdd(util::utf8char letter); void addToHistory(const std::string & transition); void addToStack(std::size_t index); void popStack(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index d978112..72e33ec 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -444,7 +444,7 @@ Action Action::endWord() config.setCurrentWordId(config.getCurrentWordId()+1); addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a); - if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0)) + if (!(config.rawInputOnlySeparatorsLeft() and config.getRawInputStatus()) and !config.has(0,config.getWordIndex()+1,0)) config.addLines(1); }; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 61488eb..086f7f0 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -156,9 +156,6 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes) { - if (sentences.empty() and rawInput.empty()) - util::myThrow("sentences and rawInput can't be both empty"); - createColumns(mcd); if (not rawInput.empty()) diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 69ae400..22ca045 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -14,6 +14,7 @@ Config::Config(const Config & other) this->strategy.reset(other.strategy ? new Strategy(*other.strategy) : nullptr); this->rawInput = other.rawInput; + this->rawInputIsComplete = other.rawInputIsComplete; this->wordIndex = other.wordIndex; this->characterIndex = other.characterIndex; this->state = other.state; @@ -396,6 +397,26 @@ const util::utf8char & Config::getLetter(int letterIndex) const return rawInput[letterIndex]; } +bool Config::getRawInputStatus() const +{ + return rawInputIsComplete; +} + +void Config::setRawInputStatus(bool status) +{ + rawInputIsComplete = status; +} + +void Config::rawInputPop() +{ + rawInput.pop_back(); +} + +void Config::rawInputAdd(util::utf8char letter) +{ + rawInput.push_back(letter); +} + bool Config::isMultiword(std::size_t lineIndex) const { return hasColIndex(idColName) && std::string(getConst(idColName, lineIndex, 0)).find('-') != std::string::npos; diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index b22b9e6..d23440e 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -116,7 +116,9 @@ bool Strategy::Block::isFinished(const Config & c, const Movement & movement) if (condition == EndCondition::CannotMove) { if (c.canMoveWordIndex(movement.second)) + { return false; + } } return true; diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index d760d1f..08832ba 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -4,6 +4,7 @@ #include "util.hpp" #include "NeuralNetwork.hpp" #include "WordEmbeddings.hpp" +#include "Producer.hpp" namespace po = boost::program_options; @@ -332,14 +333,14 @@ int MacaonTrain::main() std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), [&decoder, debug, printAdvancement](BaseConfig & devConfig) { - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>()); }); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); machine.to(NeuralNetworkImpl::getDevice()); } else { - decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement); + decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>()); } std::vector<const Config *> devConfigsPtrs; -- GitLab