Skip to content
Snippets Groups Projects
Commit e7e29b09 authored by Franck Dary's avatar Franck Dary
Browse files

Added output ses

parent 89fe9c35
Branches
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ class Decoder ...@@ -26,7 +26,7 @@ class Decoder
public : public :
Decoder(ReadingMachine & machine); Decoder(ReadingMachine & machine);
std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer); std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer);
void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted); void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
......
...@@ -11,11 +11,16 @@ class Producer ...@@ -11,11 +11,16 @@ class Producer
static constexpr int maxNb = 100; static constexpr int maxNb = 100;
int curNb = 0; int curNb = 0;
std::filesystem::path input, output;
std::vector<std::string> sequence;
public : public :
Producer(std::filesystem::path path); Producer(std::filesystem::path input, std::filesystem::path output);
bool apply(Config & config); bool apply(Config & config);
void addConfigToSequence(const Config & config);
void writeOutputFile() const;
}; };
#endif #endif
...@@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> ...@@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer>
config.setState(movement.first); config.setState(movement.first);
config.moveWordIndexRelaxed(movement.second); config.moveWordIndexRelaxed(movement.second);
if (producer.has_value())
producer.value().addConfigToSequence(config);
} }
if (debug) if (debug)
......
...@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{ {
} }
std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer) std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer)
{ {
constexpr int printInterval = 50; constexpr int printInterval = 50;
...@@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float ...@@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
{ {
auto eosTransition = Transition("EOS b.0"); auto eosTransition = Transition("EOS b.0");
eosTransition.apply(baseConfig); eosTransition.apply(baseConfig);
baseConfig.addToHistory(eosTransition.getName());
if (producer.has_value())
producer.value().addConfigToSequence(baseConfig);
if (debug) if (debug)
{ {
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
......
...@@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription() ...@@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription()
"Size of the beam during beam search") "Size of the beam during beam search")
("beamThreshold", po::value<float>()->default_value(0.1), ("beamThreshold", po::value<float>()->default_value(0.1),
"Minimal probability an action must have to be considered in the beam search") "Minimal probability an action must have to be considered in the beam search")
("outputSES", po::value<std::string>()->default_value(""),
"Output file for enriched SES")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -80,6 +82,7 @@ int MacaonDecode::main() ...@@ -80,6 +82,7 @@ int MacaonDecode::main()
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : ""; auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : "";
auto outputSES = variables.count("outputSES") ? variables["outputSES"].as<std::string>() : "";
auto mcd = variables["mcd"].as<std::string>(); auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
...@@ -88,6 +91,8 @@ int MacaonDecode::main() ...@@ -88,6 +91,8 @@ int MacaonDecode::main()
auto beamSize = variables["beamSize"].as<int>(); auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>(); auto beamThreshold = variables["beamThreshold"].as<float>();
auto noProducer = std::optional<Producer>();
torch::globalContext().setBenchmarkCuDNN(true); torch::globalContext().setBenchmarkCuDNN(true);
Submodule::setReloadPretrained(reloadPretrained); Submodule::setReloadPretrained(reloadPretrained);
...@@ -136,19 +141,21 @@ int MacaonDecode::main() ...@@ -136,19 +141,21 @@ int MacaonDecode::main()
NeuralNetworkImpl::setDevice(torch::kCPU); NeuralNetworkImpl::setDevice(torch::kCPU);
machine.to(NeuralNetworkImpl::getDevice()); machine.to(NeuralNetworkImpl::getDevice());
std::for_each(std::execution::par, configs.begin(), configs.end(), std::for_each(std::execution::par, configs.begin(), configs.end(),
[&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) [&decoder, debug, printAdvancement, beamSize, beamThreshold, &noProducer](BaseConfig & config)
{ {
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, noProducer);
}); });
} }
else else
{ {
if (not inputSES.empty()) if (not inputSES.empty())
{ {
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES))); auto producer = std::optional<Producer>(Producer(inputSES, outputSES));
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, producer);
producer.value().writeOutputFile();
} }
else else
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>()); decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, noProducer);
} }
for (unsigned int i = 0; i < configs.size(); i++) for (unsigned int i = 0; i < configs.size(); i++)
......
#include "Producer.hpp" #include "Producer.hpp"
Producer::Producer(std::filesystem::path) Producer::Producer(std::filesystem::path input, std::filesystem::path output) : input(input), output(output)
{ {
} }
// Add one or more characters to config's rawInput.
// Returns false if we are finished and true if we have events remaining.
bool Producer::apply(Config & config) bool Producer::apply(Config & config)
{ {
if (util::choiceWithProbability(0.05)) if (util::choiceWithProbability(0.05))
{ {
config.rawInputAdd("."); config.rawInputAdd(".");
config.rawInputAdd(" "); config.rawInputAdd(" ");
sequence.push_back("<addletter \".\">");
sequence.push_back("<addletter \" \">");
} }
else if (util::choiceWithProbability(0.8)) else if (util::choiceWithProbability(0.8))
config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26))); {
auto letter = fmt::format("{}", (char) ('a'+rand()%26));
config.rawInputAdd(letter);
sequence.push_back(fmt::format("<addletter \"{}\">", letter));
}
else else
{
config.rawInputAdd(" "); config.rawInputAdd(" ");
sequence.push_back("<addletter \" \">");
}
curNb++; curNb++;
return curNb < maxNb; return curNb < maxNb;
} }
// Records an "<action ...>" event in the sequence, built from the entry that
// config.getHistory(0) returns (presumably the latest transition applied to
// the config; confirm against Config::getHistory).
void Producer::addConfigToSequence(const Config & config)
{
  const auto & historyEntry = config.getHistory(0);
  sequence.emplace_back(fmt::format("<action \"{}\">", historyEntry));
}
// Writes the entire sequence to the output file.
void Producer::writeOutputFile() const
{
std::FILE * outputFile = output.empty() ? stdout : std::fopen(output.c_str(), "w");
for (auto & event : sequence)
fmt::print(outputFile, "{}\n", event);
}
...@@ -151,6 +151,8 @@ int MacaonTrain::main() ...@@ -151,6 +151,8 @@ int MacaonTrain::main()
std::srand(seed); std::srand(seed);
torch::manual_seed(seed); torch::manual_seed(seed);
auto noProducer = std::optional<Producer>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr); auto trainStrategy = parseTrainStrategy(trainStrategyStr);
torch::globalContext().setBenchmarkCuDNN(true); torch::globalContext().setBenchmarkCuDNN(true);
...@@ -331,16 +333,16 @@ int MacaonTrain::main() ...@@ -331,16 +333,16 @@ int MacaonTrain::main()
NeuralNetworkImpl::setDevice(torch::kCPU); NeuralNetworkImpl::setDevice(torch::kCPU);
machine.to(NeuralNetworkImpl::getDevice()); machine.to(NeuralNetworkImpl::getDevice());
std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(), std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
[&decoder, debug, printAdvancement](BaseConfig & devConfig) [&decoder, debug, printAdvancement, &noProducer](BaseConfig & devConfig)
{ {
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>()); decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, noProducer);
}); });
NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
machine.to(NeuralNetworkImpl::getDevice()); machine.to(NeuralNetworkImpl::getDevice());
} }
else else
{ {
decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>()); decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, noProducer);
} }
std::vector<const Config *> devConfigsPtrs; std::vector<const Config *> devConfigsPtrs;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment