Skip to content
Snippets Groups Projects
Commit e7e29b09 authored by Franck Dary's avatar Franck Dary
Browse files

Added output ses

parent 89fe9c35
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ class Decoder
public :
Decoder(ReadingMachine & machine);
std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer);
std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer);
void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
......
......@@ -11,11 +11,16 @@ class Producer
static constexpr int maxNb = 100;
int curNb = 0;
std::filesystem::path input, output;
std::vector<std::string> sequence;
public :
Producer(std::filesystem::path path);
Producer(std::filesystem::path input, std::filesystem::path output);
bool apply(Config & config);
void addConfigToSequence(const Config & config);
void writeOutputFile() const;
};
#endif
......@@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer>
config.setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
if (producer.has_value())
producer.value().addConfigToSequence(config);
}
if (debug)
......
......@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer)
std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer)
{
constexpr int printInterval = 50;
......@@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
{
auto eosTransition = Transition("EOS b.0");
eosTransition.apply(baseConfig);
baseConfig.addToHistory(eosTransition.getName());
if (producer.has_value())
producer.value().addConfigToSequence(baseConfig);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription()
"Size of the beam during beam search")
("beamThreshold", po::value<float>()->default_value(0.1),
"Minimal probability an action must have to be considered in the beam search")
("outputSES", po::value<std::string>()->default_value(""),
"Output file for enriched SES")
("help,h", "Produce this help message");
desc.add(req).add(opt);
......@@ -80,6 +82,7 @@ int MacaonDecode::main()
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : "";
auto outputSES = variables.count("outputSES") ? variables["outputSES"].as<std::string>() : "";
auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
......@@ -88,6 +91,8 @@ int MacaonDecode::main()
auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>();
auto noProducer = std::optional<Producer>();
torch::globalContext().setBenchmarkCuDNN(true);
Submodule::setReloadPretrained(reloadPretrained);
......@@ -136,19 +141,21 @@ int MacaonDecode::main()
NeuralNetworkImpl::setDevice(torch::kCPU);
machine.to(NeuralNetworkImpl::getDevice());
std::for_each(std::execution::par, configs.begin(), configs.end(),
[&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
[&decoder, debug, printAdvancement, beamSize, beamThreshold, &noProducer](BaseConfig & config)
{
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, noProducer);
});
}
else
{
if (not inputSES.empty())
{
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES)));
auto producer = std::optional<Producer>(Producer(inputSES, outputSES));
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, producer);
producer.value().writeOutputFile();
}
else
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, noProducer);
}
for (unsigned int i = 0; i < configs.size(); i++)
......
#include "Producer.hpp"
// Builds a Producer by storing the given input and output paths in the
// corresponding members. The constructor body is empty.
Producer::Producer(std::filesystem::path)
Producer::Producer(std::filesystem::path input, std::filesystem::path output) : input(input), output(output)
{
}
// Adds one or more characters to config's rawInput and records the matching
// <addletter "..."> event(s) in sequence.
// Each call increments curNb. Returns true while curNb is still below maxNb
// (events remain) and false once the limit has been reached.
bool Producer::apply(Config & config)
{
// First choice: a period followed by a space.
// NOTE(review): presumably taken with probability 0.05 — confirm against
// util::choiceWithProbability.
if (util::choiceWithProbability(0.05))
{
config.rawInputAdd(".");
config.rawInputAdd(" ");
sequence.push_back("<addletter \".\">");
sequence.push_back("<addletter \" \">");
}
// Second choice: a single lowercase letter in 'a'..'z' picked with rand().
else if (util::choiceWithProbability(0.8))
config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26)));
{
auto letter = fmt::format("{}", (char) ('a'+rand()%26));
config.rawInputAdd(letter);
sequence.push_back(fmt::format("<addletter \"{}\">", letter));
}
// Otherwise: a single space.
else
{
config.rawInputAdd(" ");
sequence.push_back("<addletter \" \">");
}
// Count this call and tell the caller whether more events may be produced.
curNb++;
return curNb < maxNb;
}
// Appends an <action "..."> event to the sequence, built from the entry that
// config.getHistory(0) returns.
// NOTE(review): presumably index 0 is the most recently applied transition —
// confirm against Config::getHistory.
void Producer::addConfigToSequence(const Config & config)
{
  const auto & historyEntry = config.getHistory(0);
  sequence.emplace_back(fmt::format("<action \"{}\">", historyEntry));
}
// Writes the entire sequence to the output file.
void Producer::writeOutputFile() const
{
std::FILE * outputFile = output.empty() ? stdout : std::fopen(output.c_str(), "w");
for (auto & event : sequence)
fmt::print(outputFile, "{}\n", event);
}
......@@ -151,6 +151,8 @@ int MacaonTrain::main()
std::srand(seed);
torch::manual_seed(seed);
auto noProducer = std::optional<Producer>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr);
torch::globalContext().setBenchmarkCuDNN(true);
......@@ -331,16 +333,16 @@ int MacaonTrain::main()
NeuralNetworkImpl::setDevice(torch::kCPU);
machine.to(NeuralNetworkImpl::getDevice());
std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
[&decoder, debug, printAdvancement](BaseConfig & devConfig)
[&decoder, debug, printAdvancement, &noProducer](BaseConfig & devConfig)
{
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>());
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, noProducer);
});
NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
machine.to(NeuralNetworkImpl::getDevice());
}
else
{
decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>());
decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, noProducer);
}
std::vector<const Config *> devConfigsPtrs;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment