From e7e29b090188941bb94c0f1251420857df504279 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 22 Jul 2021 10:53:57 +0200 Subject: [PATCH] Added output ses --- decoder/include/Decoder.hpp | 2 +- decoder/include/Producer.hpp | 7 ++++++- decoder/src/Beam.cpp | 4 ++++ decoder/src/Decoder.cpp | 7 ++++++- decoder/src/MacaonDecode.cpp | 15 +++++++++++---- decoder/src/Producer.cpp | 30 ++++++++++++++++++++++++++++-- trainer/src/MacaonTrain.cpp | 8 +++++--- 7 files changed, 61 insertions(+), 12 deletions(-) diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 0156757..4518c72 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -26,7 +26,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer); + std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp index ec95402..a5b331a 100644 --- a/decoder/include/Producer.hpp +++ b/decoder/include/Producer.hpp @@ -11,11 +11,16 @@ class Producer static constexpr int maxNb = 100; int curNb = 0; + std::filesystem::path input, output; + std::vector<std::string> sequence; + public : - Producer(std::filesystem::path path); + Producer(std::filesystem::path input, std::filesystem::path output); bool apply(Config & config); + void addConfigToSequence(const Config & config); + void writeOutputFile() const; }; #endif diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 1a5e151..90500df 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> config.setState(movement.first); config.moveWordIndexRelaxed(movement.second); + + + if (producer.has_value()) + producer.value().addConfigToSequence(config); } if (debug) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 7b1e7b0..93493bf 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer) +std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer) { constexpr int printInterval = 50; @@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float { auto eosTransition = Transition("EOS b.0"); eosTransition.apply(baseConfig); + baseConfig.addToHistory(eosTransition.getName()); + + if (producer.has_value()) + producer.value().addConfigToSequence(baseConfig); + if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index f964061..4762257 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription() "Size of the beam during beam search") ("beamThreshold", po::value<float>()->default_value(0.1), "Minimal probability an action must have to be considered in the beam search") + ("outputSES", po::value<std::string>()->default_value(""), + "Output file for enriched SES") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -80,6 +82,7 @@ int MacaonDecode::main() auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : ""; + auto outputSES = variables.count("outputSES") ? variables["outputSES"].as<std::string>() : ""; auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; @@ -88,6 +91,8 @@ int MacaonDecode::main() auto beamSize = variables["beamSize"].as<int>(); auto beamThreshold = variables["beamThreshold"].as<float>(); + auto noProducer = std::optional<Producer>(); + torch::globalContext().setBenchmarkCuDNN(true); Submodule::setReloadPretrained(reloadPretrained); @@ -136,19 +141,21 @@ int MacaonDecode::main() NeuralNetworkImpl::setDevice(torch::kCPU); machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, configs.begin(), configs.end(), - [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) + [&decoder, debug, printAdvancement, beamSize, beamThreshold, &noProducer](BaseConfig & config) { - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, noProducer); }); } else { if (not inputSES.empty()) { - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES))); + auto producer = std::optional<Producer>(Producer(inputSES, outputSES)); + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, producer); + producer.value().writeOutputFile(); } else - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, noProducer); } for (unsigned int i = 0; i < configs.size(); i++) diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp index 8d564dd..9906095 100644 --- a/decoder/src/Producer.cpp +++ b/decoder/src/Producer.cpp @@ -1,22 +1,48 @@ #include "Producer.hpp" -Producer::Producer(std::filesystem::path) +Producer::Producer(std::filesystem::path input, std::filesystem::path output) : input(input), output(output) { } +// Add one or more characters to config's rawInput. +// Returns false if we are finished and true if we have events remaining. bool Producer::apply(Config & config) { if (util::choiceWithProbability(0.05)) { config.rawInputAdd("."); config.rawInputAdd(" "); + sequence.push_back("<addletter \".\">"); + sequence.push_back("<addletter \" \">"); } else if (util::choiceWithProbability(0.8)) - config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26))); + { + auto letter = fmt::format("{}", (char) ('a'+rand()%26)); + config.rawInputAdd(letter); + sequence.push_back(fmt::format("<addletter \"{}\">", letter)); + } else + { config.rawInputAdd(" "); + sequence.push_back("<addletter \" \">"); + } curNb++; return curNb < maxNb; } +// Adds an event in the sequence that represent the current config state. +void Producer::addConfigToSequence(const Config & config) +{ + sequence.push_back(fmt::format("<action \"{}\">", config.getHistory(0))); +} + +// Writes the entire sequence to the output file. +void Producer::writeOutputFile() const +{ + std::FILE * outputFile = output.empty() ? stdout : std::fopen(output.c_str(), "w"); + + for (auto & event : sequence) + fmt::print(outputFile, "{}\n", event); +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 08832ba..7229ee4 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -151,6 +151,8 @@ int MacaonTrain::main() std::srand(seed); torch::manual_seed(seed); + auto noProducer = std::optional<Producer>(); + auto trainStrategy = parseTrainStrategy(trainStrategyStr); torch::globalContext().setBenchmarkCuDNN(true); @@ -331,16 +333,16 @@ int MacaonTrain::main() NeuralNetworkImpl::setDevice(torch::kCPU); machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), - [&decoder, debug, printAdvancement](BaseConfig & devConfig) + [&decoder, debug, printAdvancement, &noProducer](BaseConfig & devConfig) { - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, noProducer); }); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); machine.to(NeuralNetworkImpl::getDevice()); } else { - decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, noProducer); } std::vector<const Config *> devConfigsPtrs; -- GitLab