diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 01567574a93a47a506efcb8b66defb60689ab3fc..4518c72e5f89c9034ff192f74434de6d4cdab2d8 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -26,7 +26,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer); + std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp index ec954024e0983b7fd8300e9ed98c2dd629e8f3ee..a5b331a88765668828e8e91f0f46761d077ad15f 100644 --- a/decoder/include/Producer.hpp +++ b/decoder/include/Producer.hpp @@ -11,11 +11,16 @@ class Producer static constexpr int maxNb = 100; int curNb = 0; + std::filesystem::path input, output; + std::vector<std::string> sequence; + public : - Producer(std::filesystem::path path); + Producer(std::filesystem::path input, std::filesystem::path output); bool apply(Config & config); + void addConfigToSequence(const Config & config); + void writeOutputFile() const; }; #endif diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 1a5e151e76071ac7f55da63a07fc097e1ce0e362..90500dfd21f0017040551db14c5f18572695bc14 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> config.setState(movement.first); config.moveWordIndexRelaxed(movement.second); + + + if (producer.has_value()) + producer.value().addConfigToSequence(config); } if (debug) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 7b1e7b028c91d9b86d7fff73ea42c8f4c9c8acc8..93493bf9f719ac13f52054370c847b53402d480c 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer) +std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer) { constexpr int printInterval = 50; @@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float { auto eosTransition = Transition("EOS b.0"); eosTransition.apply(baseConfig); + baseConfig.addToHistory(eosTransition.getName()); + + if (producer.has_value()) + producer.value().addConfigToSequence(baseConfig); + if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index f96406150fe5a988518885dcfb38e888be073c3a..47622578c117ebd59885c48dd95a1b63e9e8d443 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription() "Size of the beam during beam search") ("beamThreshold", po::value<float>()->default_value(0.1), "Minimal probability an action must have to be considered in the beam search") + ("outputSES", po::value<std::string>()->default_value(""), + "Output file for enriched SES") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -80,6 +82,7 @@ int MacaonDecode::main() auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : ""; + auto outputSES = variables.count("outputSES") ? variables["outputSES"].as<std::string>() : ""; auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; @@ -88,6 +91,8 @@ int MacaonDecode::main() auto beamSize = variables["beamSize"].as<int>(); auto beamThreshold = variables["beamThreshold"].as<float>(); + auto noProducer = std::optional<Producer>(); + torch::globalContext().setBenchmarkCuDNN(true); Submodule::setReloadPretrained(reloadPretrained); @@ -136,19 +141,21 @@ int MacaonDecode::main() NeuralNetworkImpl::setDevice(torch::kCPU); machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, configs.begin(), configs.end(), - [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) + [&decoder, debug, printAdvancement, beamSize, beamThreshold, &noProducer](BaseConfig & config) { - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, noProducer); }); } else { if (not inputSES.empty()) { - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES))); + auto producer = std::optional<Producer>(Producer(inputSES, outputSES)); + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, producer); + producer.value().writeOutputFile(); } else - decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, noProducer); } for (unsigned int i = 0; i < configs.size(); i++) diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp index 8d564dd531f063a354008fef3b4cd9e2e1a3b2b7..9906095352d9a5acddb0b2eb231fed2b3415d323 100644 --- a/decoder/src/Producer.cpp +++ b/decoder/src/Producer.cpp @@ -1,22 +1,48 @@ #include "Producer.hpp" -Producer::Producer(std::filesystem::path) +Producer::Producer(std::filesystem::path input, std::filesystem::path output) : input(input), output(output) { } +// Add one or more characters to config's rawInput. +// Returns false if we are finished and true if we have events remaining. bool Producer::apply(Config & config) { if (util::choiceWithProbability(0.05)) { config.rawInputAdd("."); config.rawInputAdd(" "); + sequence.push_back("<addletter \".\">"); + sequence.push_back("<addletter \" \">"); } else if (util::choiceWithProbability(0.8)) - config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26))); + { + auto letter = fmt::format("{}", (char) ('a'+rand()%26)); + config.rawInputAdd(letter); + sequence.push_back(fmt::format("<addletter \"{}\">", letter)); + } else + { config.rawInputAdd(" "); + sequence.push_back("<addletter \" \">"); + } curNb++; return curNb < maxNb; } +// Adds an event in the sequence that represent the current config state. +void Producer::addConfigToSequence(const Config & config) +{ + sequence.push_back(fmt::format("<action \"{}\">", config.getHistory(0))); +} + +// Writes the entire sequence to the output file. +void Producer::writeOutputFile() const +{ + std::FILE * outputFile = output.empty() ? stdout : std::fopen(output.c_str(), "w"); + + for (auto & event : sequence) + fmt::print(outputFile, "{}\n", event); +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 08832ba56542f473a6a437115472369907e28e95..7229ee4b94c62ee41aa0b51fb4505031f4e6575c 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -151,6 +151,8 @@ int MacaonTrain::main() std::srand(seed); torch::manual_seed(seed); + auto noProducer = std::optional<Producer>(); + auto trainStrategy = parseTrainStrategy(trainStrategyStr); torch::globalContext().setBenchmarkCuDNN(true); @@ -331,16 +333,16 @@ int MacaonTrain::main() NeuralNetworkImpl::setDevice(torch::kCPU); machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), - [&decoder, debug, printAdvancement](BaseConfig & devConfig) + [&decoder, debug, printAdvancement, &noProducer](BaseConfig & devConfig) { - decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, noProducer); }); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); machine.to(NeuralNetworkImpl::getDevice()); } else { - decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>()); + decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, noProducer); } std::vector<const Config *> devConfigsPtrs;